diff --git a/README.md b/README.md index de494024742b1c66f24a3a9ae3547e1d8e6dfed3..3882afe676eb021d13f83259991caa16dc277cc4 100644 --- a/README.md +++ b/README.md @@ -1,103 +1,123 @@ -# BOMLIP-CSP - -An open-source Python framework that integrates machine learning interatomic -potentials (MLIPs) with a tailored batched optimization strategy, enabling rapid, -unbiased structure prediction across the full density range - -## Perform the complete CSP process - -```sh -git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP - -conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP -cd BOMLIP-CSP -top_dir=$(pwd) -cd $top_dir/mace-bench -./reproduce/init_mace.sh && source util/env.sh -sudo ./util/mps_start.sh - -cd $top_dir -./csp.sh - -sudo ./util/mps_clean.sh -``` -## Reproduce mace batch opt speedup. - -```sh -#!/bin/bash - -git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP -conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP -cd BOMLIP-CSP/mace-bench - -# initialize mace env. -./reproduce/init_mace.sh && source util/env.sh -sudo ./util/mps_start.sh -cd reproduce - -# run baseline sub-test -./subtest_baseline.sh - -# run baseline mixed test -cd perf_v2_base -./run_mace.sh - -# run BOMLIP_CSP sub-test -cd ../ -./subtest.sh - -# run BOMLIP_CSP mixed test -cd perf_v2_batch -./opt.sh - -# clean mps -./util/mps_clean.sh - -``` - -## If you want to configure the 7net environment. - -```sh -#!/bin/bash -conda create -n 7net-cueq python=3.10 -y && conda activate 7net-cueq -./reproduce/init_7net.sh && source util/env.sh - -# Use a fixed batch size for structural optimization -python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" \ - --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 4 \ - --batch_size 2 --max_steps 3000 --filter1 UnitCellFilter \ - --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS \ - --num_threads 2 --cueq true --use_ordered_files true --model sevennet -``` - -## License - -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. - -### Third-party Dependencies - -This project includes dependencies with various licenses: -- **MACE**: MIT License (compatible) -- **FairChem**: MIT License (compatible) -- **SevenNet**: GPL v3 License (Note: GPL is a copyleft license) - -### License Compatibility Notice - -**Important**: This project can run completely without relying on SevenNet. -This project includes SevenNet as an optional dependency, which is licensed under GPL v3. -If you use SevenNet functionality, you should be aware of the GPL licensing requirements. -For commercial use or to avoid GPL restrictions, consider using only the MACE calculator -functionality. 
-
-## Citation
-
-If you use this code in your research, please cite:
-
-```bibtex
-@software{BOMLIP_CSP,
-  author = {Chengxi Zhao, Zhaojia Ma, Dingrui Fan},
-  title = {BOMLIP_CSP: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction},
-  year = {2025},
-  url = {https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP}
-}
+# BOMLIP-CSP
+
+An open-source Python framework that integrates machine learning interatomic
+potentials (MLIPs) with a tailored batched optimization strategy, enabling rapid,
+unbiased structure prediction across the full density range.
+
+## Perform a complete CSP process
+
+```sh
+git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP
+
+conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP
+cd mace-bench
+./reproduce/init_mace.sh && source util/env.sh
+sudo ./util/mps_start.sh
+
+cd ..
+./csp.sh
+
+sudo ./util/mps_clean.sh
+```
+
+## Perform conformer search / structure generation / structure optimization separately
+
+In csp.sh, the --mode argument controls which stages are run.
+Use conformer_only to run only the conformer search:
+```sh
+python "${TOP_DIR}/main.py" --path "${TAR_DIR}" --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \
+    --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16 \
+    --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode conformer_only > generate.log 2>&1
+```
+Or use structure_only to run only structure generation:
+```sh
+python "${TOP_DIR}/main.py" --path "${TAR_DIR}" --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \
+    --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16 \
+    --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode structure_only > generate.log 2>&1
+```
+Structure optimization is performed by a separate command:
+```sh
+python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" ...
+```
+Comment out this command in csp.sh if you do not want to run structure optimization.
+
+## Reproduce the MACE batch-optimization speedup
+
+```sh
+#!/bin/bash
+
+git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP
+conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP
+cd mace-bench
+
+# initialize the mace env.
+./reproduce/init_mace.sh && source util/env.sh
+sudo ./util/mps_start.sh
+cd reproduce
+
+# run the baseline sub-test
+./subtest_baseline.sh
+
+# run the baseline mixed test
+cd perf_v2_base
+./run_mace.sh
+
+# run the BOMLIP_CSP sub-test
+cd ../
+./subtest.sh
+
+# run the BOMLIP_CSP mixed test
+cd perf_v2_batch
+./opt.sh
+
+# clean mps (run from the mace-bench directory)
+cd ../..
+sudo ./util/mps_clean.sh
+```
+
+## Configure the SevenNet (7net) environment (optional)
+
+```sh
+#!/bin/bash
+conda create -n 7net-cueq python=3.10 -y && conda activate 7net-cueq
+./reproduce/init_7net.sh && source util/env.sh
+
+# Use a fixed batch size for structural optimization
+python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" \
+    --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 4 \
+    --batch_size 2 --max_steps 3000 --filter1 UnitCellFilter \
+    --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS \
+    --num_threads 2 --cueq true --use_ordered_files true --model sevennet
+```
+
+## License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+### Third-party Dependencies
+
+This project includes dependencies with various licenses:
+- **MACE**: MIT License (compatible)
+- **FairChem**: MIT License (compatible)
+- **SevenNet**: GPL v3 License (Note: GPL is a copyleft license)
+
+### License Compatibility Notice
+
+**Important**: This project can run completely without relying on SevenNet.
+This project includes SevenNet as an optional dependency, which is licensed under GPL v3.
+If you use SevenNet functionality, you should be aware of the GPL licensing requirements.
+For commercial use or to avoid GPL restrictions, consider using only the MACE calculator
+functionality.
+
+## Citation
+
+If you use this code in your research, please cite:
+
+```bibtex
+@software{BOMLIP_CSP,
+  author = {Chengxi Zhao and Zhaojia Ma and Dingrui Fan},
+  title = {BOMLIP_CSP: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction},
+  year = {2025},
+  url = {https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP}
+}
```
\ No newline at end of file
diff --git a/basic_function/CSP_generator_normal.py b/basic_function/CSP_generator_normal.py
index 146cc70b3e684443b6b21cc8ccddd015955d3a3b..3c70e189405701bab11ee0ecb0da8e5c95590adc 100644
--- a/basic_function/CSP_generator_normal.py
+++ b/basic_function/CSP_generator_normal.py
@@ -1,615 +1,615 @@
-"""
-This module provides the CrystalGenerator class for crystal structure prediction (CSP).
-
-It uses a Sobol sequence-based random search to generate candidate crystal
-structures for a given set of molecules and space group, followed by a crude
-packing minimization.
-"""
-
-# Standard library imports
-import itertools
-from typing import List, Tuple, Optional, Any
-
-# Third-party imports
-import numpy as np
-from scipy.spatial import cKDTree
-from scipy.stats import qmc
-
-# Local application/library specific imports
-from basic_function import chemical_knowledge
-from basic_function import operation
-from basic_function import data_classes
-
-# Module-level constants for better readability and maintenance
-_VDW_CLASH_FACTOR = 0.9  # Scaling factor for van der Waals radii in collision checks
-_SUPERCELL_RANGE = np.arange(-2, 3)  # Range for generating supercell translations
-
-
-class CrystalGenerator:
-    """
-    Generates candidate crystal structures for Crystal Structure Prediction (CSP).
-
-    The generator takes a list of unique molecules and a space group, then searches
-    the conformational space of cell parameters and molecular orientations to
-    produce tightly packed, sterically plausible crystal structures.
-    """
-
-    def __init__(self,
-                 molecules: list[data_classes.Molecule],
-                 space_group: int = 1,
-                 angles: tuple[float, float] = (45.0, 135.0)):
-        """
-        Initializes the CrystalGenerator.
-
-        Args:
-            molecules: A list of molecule objects (from data_classes) that will form
-                the asymmetric unit.
-            space_group: The international space group number (e.g., 1 for P1).
-            angles: A tuple (min, max) defining the range for sampling cell angles in degrees.
- """ - if not (0 < space_group <= 230): - raise ValueError("Space group must be an integer between 1 and 230.") - - self.molecules = molecules - self.space_group_number = space_group - self.angle_sampling_range = angles - - # Derived properties from the space group - self.symmetry_ops = chemical_knowledge.space_group[self.space_group_number][0] - self.point_group = chemical_knowledge.space_group[self.space_group_number][2] - - # Calculate counts and dimensions - self.num_asym_molecules = len(self.molecules) - self.num_total_molecules = len(self.symmetry_ops) * self.num_asym_molecules - self.atomic_counts_per_molecule = self._calculate_atomic_counts() - - # Determine search space dimensionality - self.search_dimensions, self.search_dimension_shape = self._determine_search_dimensions() - - # Pre-calculate molecular and crystal properties - self.max_vdw_radius = self._find_max_vdw_radius() - self.estimated_packed_volume = self._calculate_estimated_packed_volume() - self._orient_molecules() - - # Pre-generate supercell translation vectors, sorted by distance from origin - self.supercell_frac_translations = np.array( - sorted(list(itertools.product(_SUPERCELL_RANGE, repeat=3)), - key=lambda p: p[0]**2 + p[1]**2 + p[2]**2) - ) - - def _calculate_atomic_counts(self) -> list[int]: - """Calculates the number of atoms for each molecule in the asymmetric unit.""" - return [len(mol.atoms) for mol in self.molecules] - - def _orient_molecules(self) -> None: - """ - Orients each molecule to a standardized principal axis frame. - This reduces the rotational search space. For details, see: http://sobereva.com/426 - """ - for i, molecule in enumerate(self.molecules): - if len(molecule.atoms) > 1: - self.molecules[i] = operation.orient_molecule(molecule) - - def _find_max_vdw_radius(self) -> float: - """Finds the maximum van der Waals radius among all atoms in all molecules.""" - vdw_max = 0.0 - for molecule in self.molecules: - elements, _ = molecule.get_ele_and_cart() - for ele in set(elements): - vdw_max = max(vdw_max, chemical_knowledge.element_vdw_radii[ele]) - return vdw_max - - def _determine_search_dimensions(self) -> tuple[int, list[int]]: - """ - Determines the dimensionality of the search space. - - The search space consists of: - - 3 dimensions for cell angles (alpha, beta, gamma) - - 3 dimensions for cell lengths (a, b, c) - - 3 * N dimensions for molecular translations (x, y, z for each of N molecules) - - 3 * N dimensions for molecular rotations (Euler angles for each of N molecules) - - Returns: - A tuple containing the total dimension count and a list detailing the - breakdown of dimensions. - """ - dim_cell_lengths = 3 - dim_cell_angles = 3 - dim_translations = 3 * self.num_asym_molecules - dim_rotations = 3 * self.num_asym_molecules - total_dimension = dim_cell_lengths + dim_cell_angles + dim_translations + dim_rotations - shape = [dim_cell_lengths, dim_cell_angles, dim_translations, dim_rotations] - return total_dimension, shape - - def _calculate_estimated_packed_volume(self) -> float: - """ - Estimates the total volume of all molecules in the unit cell based on their - van der Waals radii. This is used for heuristics during generation. 
- """ - total_volume = 0.0 - for molecule in self.molecules: - elements, _ = molecule.get_ele_and_cart() - vdws = np.array([chemical_knowledge.element_vdw_radii[x] for x in elements]) - volumes = (4 / 3) * np.pi * vdws**3 - total_volume += np.sum(volumes) - return total_volume * len(self.symmetry_ops) # Multiply by Z - - def _map_random_to_angle(self, value: float) -> float: - """ - Maps a random number from [0, 1] to an angle in the specified range. - - This uses an arcsin distribution to more densely sample angles near the - midpoint of the range, which can be more efficient if orthogonal angles - are more likely. - """ - min_angle, max_angle = self.angle_sampling_range - angle_range = max_angle - min_angle - # A non-linear mapping to bias sampling - a = np.arcsin(2 * value - 1.0) / np.pi - return (0.5 + a) * angle_range + min_angle - - def _get_cell_angles_from_vector(self, vector: np.ndarray) -> tuple[float, float, float]: - """ - Determines the three cell angles based on a 3D random vector, respecting - the constraints of the crystal's point group. - """ - angle_candidates = [self._map_random_to_angle(v) for v in vector] - - if self.point_group == "Triclinic": - return angle_candidates[0], angle_candidates[1], angle_candidates[2] - if self.point_group == "Monoclinic": - return 90.0, angle_candidates[1], 90.0 - if self.point_group in ["Orthorhombic", "Tetragonal", "Cubic"]: - return 90.0, 90.0, 90.0 - if self.point_group == "Hexagonal": - return 90.0, 90.0, 120.0 - if self.point_group == "Trigonal": - # For rhombohedral lattices described in hexagonal axes, angles are fixed. - # This assumes a rhombohedral setting where angles are variable and equal. - return angle_candidates[0], angle_candidates[0], angle_candidates[0] - # Fallback for safety, though should be covered by above cases - return 90.0, 90.0, 90.0 - - - def _get_cell_lengths_from_vector(self, - vector: np.ndarray, - cell_angles: list[float], - rotated_molecules_cart: list[np.ndarray] - ) -> tuple[float, float, float]: - """ - Determines the three cell lengths based on a 3D random vector and molecule size. - - The method first calculates the minimum bounding box for the rotated molecules, - then scales the lengths based on the random vector to explore larger volumes. - """ - # Estimate minimum cell lengths to avoid self-collision within a molecule - min_lengths = np.zeros(3) - conversion_matrix = operation.c2f_matrix([[1, 1, 1], cell_angles]) - for cart_coords in rotated_molecules_cart: - frac_coords = cart_coords @ conversion_matrix - max_vals = np.max(frac_coords, axis=0) - min_vals = np.min(frac_coords, axis=0) - min_lengths = np.maximum(min_lengths, max_vals - min_vals) - - # Add a buffer based on the largest VdW radius - min_lengths += self.max_vdw_radius * 2 - - # Scale the lengths using the random vector to explore the search space - a = min_lengths[0] + vector[0] * (self.num_total_molecules * min_lengths[0]) - b = min_lengths[1] + vector[1] * (self.num_total_molecules * min_lengths[1]) - c = min_lengths[2] + vector[2] * (self.num_total_molecules * min_lengths[2]) - - # Apply constraints based on the point group - if self.point_group in ["Tetragonal", "Hexagonal"]: - return a, a, c - if self.point_group in ["Trigonal", "Cubic"]: - return a, a, a - return a, b, c - - def _check_for_collisions(self, - atom_elements: np.ndarray, - atom_cart_coords: np.ndarray - ) -> bool: - """ - Performs a steric clash test for the generated structure. 
- - It checks for intermolecular distances that are smaller than the sum of - the van der Waals radii (with a tolerance factor). - - Args: - atom_elements: A numpy array of element symbols for all atoms in the supercell. - atom_cart_coords: A numpy array of Cartesian coordinates for all atoms. - - Returns: - True if a collision is detected, False otherwise. - """ - vdw_radii = np.array([chemical_knowledge.element_vdw_radii[el.item()] for el in atom_elements]) - - start_index = 0 - for i in range(self.num_asym_molecules): - # Define the asymmetric unit molecule to check against its environment - num_atoms_in_mol = self.atomic_counts_per_molecule[i] - end_index = start_index + num_atoms_in_mol - - asym_mol_coords = atom_cart_coords[start_index:end_index] - asym_mol_vdws = vdw_radii[start_index:end_index] - - # The rest of the atoms form the environment - neighbor_coords = atom_cart_coords[end_index:] - neighbor_vdws = vdw_radii[end_index:] - - # A coarse filter using a bounding box around the asymmetric molecule - mol_min = np.min(asym_mol_coords, axis=0) - self.max_vdw_radius * 2 - mol_max = np.max(asym_mol_coords, axis=0) + self.max_vdw_radius * 2 - box_indices = np.all((neighbor_coords > mol_min) & (neighbor_coords < mol_max), axis=1) - - if not np.any(box_indices): - # Move to the next molecule in the asymmetric unit - num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) - start_index += num_atoms_in_supercell_mol - continue - - nearby_neighbor_coords = neighbor_coords[box_indices] - nearby_neighbor_vdws = neighbor_vdws[box_indices] - - # Use KD-Trees for efficient nearest-neighbor search - tree_asym = cKDTree(asym_mol_coords, compact_nodes=False, balanced_tree=False) - tree_neighbors = cKDTree(nearby_neighbor_coords, compact_nodes=False, balanced_tree=False) - - # Find all pairs of atoms within the maximum possible interaction distance - possible_contacts = tree_asym.query_ball_tree(tree_neighbors, self.max_vdw_radius * 2) - - for j, neighbor_indices in enumerate(possible_contacts): - if not neighbor_indices: - continue - - # Check precise distances for potential contacts - diff = asym_mol_coords[j] - nearby_neighbor_coords[neighbor_indices] - # einsum is a fast way to compute squared norms row-wise - distances = np.sqrt(np.einsum('ij,ij->i', diff, diff)) - - sum_radii = (asym_mol_vdws[j] + nearby_neighbor_vdws[neighbor_indices]) * _VDW_CLASH_FACTOR - - if np.any(distances < sum_radii): - return True # Collision detected - - # Update start index for the next asymmetric molecule - num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) - start_index += num_atoms_in_supercell_mol - - return False # No collisions found - - - def _shrink_cell_dimensions(self, a: float, b: float, c: float, locked_dims: list[bool] - ) -> tuple[float, float, float, list[int]]: - """ - Shrinks the crystal cell along the longest unlocked dimension by 1 Angstrom. - This is a crude optimization step to pack the molecules more tightly. - - Args: - a, b, c: Current cell lengths. - locked_dims: A boolean list [a, b, c] where True means the dimension - cannot be shrunk further. - - Returns: - A tuple of (new_a, new_b, new_c, last_change_indices). 
- """ - lengths = [val for val, is_locked in zip([a, b, c], locked_dims) if not is_locked] - if not lengths: - return a, b, c, [] # All dimensions are locked - - max_length = max(lengths) - last_change = [] - - # Logic to shrink the largest dimension(s) while respecting point group constraints - if self.point_group in ["Triclinic", "Monoclinic", "Orthorhombic"]: - if a == max_length and not locked_dims[0]: - a -= 1.0 - last_change = [0] - elif b == max_length and not locked_dims[1]: - b -= 1.0 - last_change = [1] - elif c == max_length and not locked_dims[2]: - c -= 1.0 - last_change = [2] - elif self.point_group in ["Tetragonal", "Hexagonal"]: - if (a == max_length or b == max_length) and not locked_dims[0]: - a -= 1.0 - b -= 1.0 - last_change = [0, 1] - elif c == max_length and not locked_dims[2]: - c -= 1.0 - last_change = [2] - elif self.point_group in ["Trigonal", "Cubic"]: - if (a == max_length or b == max_length or c == max_length) and not locked_dims[0]: - a -= 1.0 - b -= 1.0 - c -= 1.0 - last_change = [0, 1, 2] - - return a, b, c, last_change - - def _setup_crystal_from_vector(self, vector: np.ndarray - ) -> tuple[Optional[list], Optional[list[np.ndarray]], Optional[list[Any]]]: - """ - Performs the initial setup of a crystal structure from a random vector. - This includes setting angles, rotating molecules, and setting initial lengths. - This helper is used by both `generate` and `_generate_from_vector`. - """ - # Unpack the Sobol vector into its components for cell parameters and molecules - # Slicing indices based on the defined search space shape - s = self.search_dimension_shape - cell_angle_seed = vector[0:s[1]] - cell_length_seed = vector[s[1]:s[1]+s[0]] - move_part_seed = vector[s[1]+s[0] : s[1]+s[0]+s[2]] - rotate_part_seed = vector[s[1]+s[0]+s[2]:] - - # 1. Set cell angles - alpha, beta, gamma = self._get_cell_angles_from_vector(cell_angle_seed) - cell_angles = [alpha, beta, gamma] - - # Check for valid cell matrix from angles - ca, cb, cg = np.cos(np.deg2rad([alpha, beta, gamma])) - volume_sqrt_term = 1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg - if volume_sqrt_term <= 0: - print("Failed: Invalid angles cannot form a valid parallelepiped.") - return None, None, None - - # 2. Rotate molecules - rotated_molecules_cart = [] - rotated_molecules_ele = [] - rotate_vectors = rotate_part_seed.reshape(-1, 3) - for r_vec, molecule in zip(rotate_vectors, self.molecules): - elements, cart_coords = molecule.get_ele_and_cart() - rotation_matrix = operation.get_rotate_matrix(r_vec) - rotated_cart = cart_coords @ rotation_matrix - rotated_molecules_cart.append(rotated_cart) - rotated_molecules_ele.append(elements) - - # 3. Set initial cell lengths - a, b, c = self._get_cell_lengths_from_vector(cell_length_seed, cell_angles, rotated_molecules_cart) - cell_lengths = [a, b, c] - - crystal_params = [cell_lengths, cell_angles, move_part_seed, rotated_molecules_cart, rotated_molecules_ele] - - return crystal_params, volume_sqrt_term, rotate_part_seed - - def _build_supercell_for_clash_test(self, - cell_params: list, - rotated_molecules_cart: list[np.ndarray], - rotated_molecules_ele: list[list[str]], - move_part_seed: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, list, list]: - """ - Builds a supercell and returns all atomic elements and coordinates for clash testing. - This version correctly handles asymmetric units with multiple, different-sized molecules. 
- """ - f2c_matrix = operation.f2c_matrix(cell_params) - c2f_matrix = operation.c2f_matrix(cell_params) - supercell_cart_translations = self.supercell_frac_translations @ f2c_matrix - - all_asym_frac_coords = [] - all_asym_elements = [] - - # Use lists to collect 2D blocks of coordinates and elements. This is efficient. - sc_cart_blocks = [] - sc_ele_blocks = [] - - for i, cart_coords in enumerate(rotated_molecules_cart): - # Apply translation vector to this molecule's fractional coordinates - trans_vector = move_part_seed[i * 3:(i + 1) * 3] - frac_coords = cart_coords @ c2f_matrix + trans_vector - - all_asym_frac_coords.append(frac_coords) - all_asym_elements.append(rotated_molecules_ele[i]) - - # Apply symmetry operations - symm_cart_coords = operation.apply_SYMM(frac_coords, self.symmetry_ops) @ f2c_matrix - symm_elements_list = [rotated_molecules_ele[i]] * len(self.symmetry_ops) - - # Center molecules that were moved across periodic boundaries - centroid_frac = np.mean(frac_coords, axis=0) - centroids_all_symm = operation.apply_SYMM(centroid_frac, self.symmetry_ops) - for j, cent in enumerate(centroids_all_symm): - move_to_center = (np.mod(cent, 1) - cent) @ f2c_matrix - symm_cart_coords[j] += move_to_center - - # --- Core Correction Logic --- - # 1. Create the full block of atoms for the current molecule type by applying all - # supercell translations. - mol_block_cart_temp = [] - for translation_vec in supercell_cart_translations: - # Adding the translation vector to all symmetry-equivalent molecules - translated_coords = symm_cart_coords + translation_vec - # Reshape to a flat (N_atoms * N_symm, 3) 2D array and append - mol_block_cart_temp.append(translated_coords.reshape(-1, 3)) - - # 2. Stack all translated blocks for this molecule type into a single 2D array - sc_cart_blocks.append(np.vstack(mol_block_cart_temp)) - - # 3. Handle the corresponding elements, ensuring they are flattened correctly - num_translations = len(self.supercell_frac_translations) - ele_block = np.array(symm_elements_list * num_translations).reshape(-1, 1) - sc_ele_blocks.append(ele_block) - - # After iterating through all molecule types, stack their respective complete blocks - final_sc_cart = np.vstack(sc_cart_blocks) - final_sc_ele = np.vstack(sc_ele_blocks) - - return final_sc_cart, final_sc_ele, all_asym_frac_coords, all_asym_elements - - def _create_final_crystal_object(self, - cell_params: list, - asym_frac_coords: list, - asym_elements: list, - seed: Any - ) -> data_classes.Crystal: - """Creates the final Crystal object from the successful structure.""" - - flat_elements = np.concatenate(asym_elements, axis=0).reshape(-1, 1) - flat_frac_coords = np.concatenate(asym_frac_coords, axis=0).reshape(-1, 3) - - atoms = [] - for ele, frac in zip(flat_elements, flat_frac_coords): - atoms.append(data_classes.Atom(element=ele.item(), frac_xyz=frac)) - - return data_classes.Crystal( - cell_para=cell_params, - atoms=atoms, - comment=str(seed), - system_name=str(seed), - space_group=self.space_group_number, - SYMM=self.symmetry_ops - ) - - def generate(self, - seed: Any = "unknown", - test: bool = False, - densely_pack_method: bool = False, - frame_tolerance: float = 1.5 - ) -> Optional[data_classes.Crystal]: - """ - The main generation method. - - Uses a Sobol sequence to get a random vector, then attempts to build and - pack a crystal structure through an iterative shrinking process. - - Args: - seed: A seed for the Sobol sequence generator. If "unknown", an error is raised. 
- test: A flag for enabling verbose test-mode output (prints cycle number). - densely_pack_method: If True, applies a heuristic to shrink very large - initial volumes. - frame_tolerance: Tolerance for checking if the final structure is a 2D slab. - - Returns: - A `data_classes.Crystal` object if a valid structure is found, otherwise `None`. - """ - if seed == "unknown": - raise ValueError("A seed must be provided for the Sobol generator.") - - sobol_gen = qmc.Sobol(d=self.search_dimensions, seed=seed) - initial_vector = sobol_gen.random(n=1).flatten() - - setup_result, volume_sqrt_term, _ = self._setup_crystal_from_vector(initial_vector) - if setup_result is None: - return None # Invalid initial angles - - cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result - a, b, c = cell_lengths - alpha, beta, gamma = cell_angles - - # Heuristic to shrink extremely sparse initial structures - if densely_pack_method: - crystal_volume = a * b * c * np.sqrt(volume_sqrt_term) - if crystal_volume > self.estimated_packed_volume * 20: - c = self.estimated_packed_volume * 20 / (a * b * np.sqrt(volume_sqrt_term)) - - locked_dims = [False, False, False] - old_a, old_b, old_c = a, b, c - - for cycle_no in range(1001): - if cycle_no == 1001: - print(f"Stopping: Max optimization cycles reached. Seed: {seed}") - return None - - if a < 0 or b < 0 or c < 0: - print(f"BUG: Negative cell dimension. sg={self.space_group_number}, seed={seed}") - return None - - if test: - print(f"Cycle: {cycle_no}") - - cell_params = [[a, b, c], [alpha, beta, gamma]] - - sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( - cell_params, rot_carts, rot_eles, move_part_seed - ) - - has_collision = self._check_for_collisions(sc_ele, sc_cart) - - if has_collision: - if cycle_no == 0: - print(f"Failed: Initial structure has collisions. Seed: {seed}") - return None - - # Collision occurred, so revert to last good state and lock the changed dimension - a, b, c = old_a, old_b, old_c - for dim_idx in last_change: - locked_dims[dim_idx] = True - else: - # No collision, this is a valid (though maybe not dense) structure. - # Check if optimization is finished (all dimensions are locked). - if cycle_no > 0 and all(locked_dims): - final_crystal = self._create_final_crystal_object(cell_params, asym_fracs, asym_eles, seed) - - # Final check to filter out 2D slab-like structures - if not operation.detect_is_frame_vdw_new(final_crystal, tolerance=frame_tolerance): - print(f"Failed: Generated structure is a 2D slab. Seed: {seed}") - return None - - print(f"Success: Generated a valid crystal structure. Seed: {seed}") - return final_crystal - - # If no collision and not finished, save current state and shrink further - old_a, old_b, old_c = a, b, c - a, b, c, last_change = self._shrink_cell_dimensions(a, b, c, locked_dims) - - # ============================================================================== - # Test-related functions, kept for compatibility, marked as internal. - # ============================================================================== - - def _generate_from_vector(self, - seed_vector: np.ndarray, - frame_tolerance: float = 1.5 - ) -> Optional[data_classes.Crystal]: - """ - Generates a single crystal structure directly from a vector, without optimization. - This is an internal method intended for testing and analysis. - Original name: generate_by_vector_2. - - Args: - seed_vector: A numpy array of shape (self.search_dimensions,) defining the structure. 
- frame_tolerance: Tolerance for checking if the final structure is a 2D slab. - - Returns: - A `data_classes.Crystal` object if valid, otherwise `None`. - """ - if not isinstance(seed_vector, np.ndarray): - raise TypeError("seed_vector must be a numpy array.") - - expected_len = self.search_dimensions - if len(seed_vector) != expected_len: - raise ValueError(f"Length of seed_vector must be {expected_len}, got {len(seed_vector)}.") - - setup_result, _, _ = self._setup_crystal_from_vector(seed_vector) - if setup_result is None: - return None # Invalid initial angles - - cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result - cell_params = [cell_lengths, cell_angles] - - sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( - cell_params, rot_carts, rot_eles, move_part_seed - ) - - if self._check_for_collisions(sc_ele, sc_cart): - print("Failed: Structure from vector has collisions.") - return None - - generated_crystal = self._create_final_crystal_object( - cell_params, asym_fracs, asym_eles, seed="from_vector" - ) - - # Optional: Keep the slab check for consistency - # if not operation.detect_is_frame_vdw_new(generated_crystal, tolerance=frame_tolerance): - # print("Failed: Generated structure is a 2D slab.") - # return None - - return generated_crystal - - def _is_valid_vector(self, seed_vector: np.ndarray) -> bool: - """ - Checks if a given vector produces a valid, collision-free structure. - Internal method for testing. - """ +""" +This module provides the CrystalGenerator class for crystal structure prediction (CSP). + +It uses a Sobol sequence-based random search to generate candidate crystal +structures for a given set of molecules and space group, followed by a crude +packing minimization. +""" + +# Standard library imports +import itertools +from typing import List, Tuple, Optional, Any + +# Third-party imports +import numpy as np +from scipy.spatial import cKDTree +from scipy.stats import qmc + +# Local application/library specific imports +from basic_function import chemical_knowledge +from basic_function import operation +from basic_function import data_classes + +# Module-level constants for better readability and maintenance +_VDW_CLASH_FACTOR = 0.9 # Scaling factor for van der Waals radii in collision checks +_SUPERCELL_RANGE = np.arange(-2, 3) # Range for generating supercell translations + + +class CrystalGenerator: + """ + Generates candidate crystal structures for Crystal Structure Prediction (CSP). + + The generator takes a list of unique molecules and a space group, then searches + the conformational space of cell parameters and molecular orientations to + produce tightly packed, sterically plausible crystal structures. + """ + + def __init__(self, + molecules: list[data_classes.Molecule], + space_group: int = 1, + angles: tuple[float, float] = (45.0, 135.0)): + """ + Initializes the CrystalGenerator. + + Args: + molecules: A list of molecule objects (from data_classes) that will form + the asymmetric unit. + space_group: The international space group number (e.g., 1 for P1). + angles: A tuple (min, max) defining the range for sampling cell angles in degrees. 
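+
+        Example (a sketch; assumes ``mol`` is an already-built
+        data_classes.Molecule object):
+
+            gen = CrystalGenerator(molecules=[mol], space_group=14)
+            crystal = gen.generate(seed=42)  # data_classes.Crystal or None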
+ """ + if not (0 < space_group <= 230): + raise ValueError("Space group must be an integer between 1 and 230.") + + self.molecules = molecules + self.space_group_number = space_group + self.angle_sampling_range = angles + + # Derived properties from the space group + self.symmetry_ops = chemical_knowledge.space_group[self.space_group_number][0] + self.point_group = chemical_knowledge.space_group[self.space_group_number][2] + + # Calculate counts and dimensions + self.num_asym_molecules = len(self.molecules) + self.num_total_molecules = len(self.symmetry_ops) * self.num_asym_molecules + self.atomic_counts_per_molecule = self._calculate_atomic_counts() + + # Determine search space dimensionality + self.search_dimensions, self.search_dimension_shape = self._determine_search_dimensions() + + # Pre-calculate molecular and crystal properties + self.max_vdw_radius = self._find_max_vdw_radius() + self.estimated_packed_volume = self._calculate_estimated_packed_volume() + self._orient_molecules() + + # Pre-generate supercell translation vectors, sorted by distance from origin + self.supercell_frac_translations = np.array( + sorted(list(itertools.product(_SUPERCELL_RANGE, repeat=3)), + key=lambda p: p[0]**2 + p[1]**2 + p[2]**2) + ) + + def _calculate_atomic_counts(self) -> list[int]: + """Calculates the number of atoms for each molecule in the asymmetric unit.""" + return [len(mol.atoms) for mol in self.molecules] + + def _orient_molecules(self) -> None: + """ + Orients each molecule to a standardized principal axis frame. + This reduces the rotational search space. For details, see: http://sobereva.com/426 + """ + for i, molecule in enumerate(self.molecules): + if len(molecule.atoms) > 1: + self.molecules[i] = operation.orient_molecule(molecule) + + def _find_max_vdw_radius(self) -> float: + """Finds the maximum van der Waals radius among all atoms in all molecules.""" + vdw_max = 0.0 + for molecule in self.molecules: + elements, _ = molecule.get_ele_and_cart() + for ele in set(elements): + vdw_max = max(vdw_max, chemical_knowledge.element_vdw_radii[ele]) + return vdw_max + + def _determine_search_dimensions(self) -> tuple[int, list[int]]: + """ + Determines the dimensionality of the search space. + + The search space consists of: + - 3 dimensions for cell angles (alpha, beta, gamma) + - 3 dimensions for cell lengths (a, b, c) + - 3 * N dimensions for molecular translations (x, y, z for each of N molecules) + - 3 * N dimensions for molecular rotations (Euler angles for each of N molecules) + + Returns: + A tuple containing the total dimension count and a list detailing the + breakdown of dimensions. + """ + dim_cell_lengths = 3 + dim_cell_angles = 3 + dim_translations = 3 * self.num_asym_molecules + dim_rotations = 3 * self.num_asym_molecules + total_dimension = dim_cell_lengths + dim_cell_angles + dim_translations + dim_rotations + shape = [dim_cell_lengths, dim_cell_angles, dim_translations, dim_rotations] + return total_dimension, shape + + def _calculate_estimated_packed_volume(self) -> float: + """ + Estimates the total volume of all molecules in the unit cell based on their + van der Waals radii. This is used for heuristics during generation. 
+ """ + total_volume = 0.0 + for molecule in self.molecules: + elements, _ = molecule.get_ele_and_cart() + vdws = np.array([chemical_knowledge.element_vdw_radii[x] for x in elements]) + volumes = (4 / 3) * np.pi * vdws**3 + total_volume += np.sum(volumes) + return total_volume * len(self.symmetry_ops) # Multiply by Z + + def _map_random_to_angle(self, value: float) -> float: + """ + Maps a random number from [0, 1] to an angle in the specified range. + + This uses an arcsin distribution to more densely sample angles near the + midpoint of the range, which can be more efficient if orthogonal angles + are more likely. + """ + min_angle, max_angle = self.angle_sampling_range + angle_range = max_angle - min_angle + # A non-linear mapping to bias sampling + a = np.arcsin(2 * value - 1.0) / np.pi + return (0.5 + a) * angle_range + min_angle + + def _get_cell_angles_from_vector(self, vector: np.ndarray) -> tuple[float, float, float]: + """ + Determines the three cell angles based on a 3D random vector, respecting + the constraints of the crystal's point group. + """ + angle_candidates = [self._map_random_to_angle(v) for v in vector] + + if self.point_group == "Triclinic": + return angle_candidates[0], angle_candidates[1], angle_candidates[2] + if self.point_group == "Monoclinic": + return 90.0, angle_candidates[1], 90.0 + if self.point_group in ["Orthorhombic", "Tetragonal", "Cubic"]: + return 90.0, 90.0, 90.0 + if self.point_group == "Hexagonal": + return 90.0, 90.0, 120.0 + if self.point_group == "Trigonal": + # For rhombohedral lattices described in hexagonal axes, angles are fixed. + # This assumes a rhombohedral setting where angles are variable and equal. + return angle_candidates[0], angle_candidates[0], angle_candidates[0] + # Fallback for safety, though should be covered by above cases + return 90.0, 90.0, 90.0 + + + def _get_cell_lengths_from_vector(self, + vector: np.ndarray, + cell_angles: list[float], + rotated_molecules_cart: list[np.ndarray] + ) -> tuple[float, float, float]: + """ + Determines the three cell lengths based on a 3D random vector and molecule size. + + The method first calculates the minimum bounding box for the rotated molecules, + then scales the lengths based on the random vector to explore larger volumes. + """ + # Estimate minimum cell lengths to avoid self-collision within a molecule + min_lengths = np.zeros(3) + conversion_matrix = operation.c2f_matrix([[1, 1, 1], cell_angles]) + for cart_coords in rotated_molecules_cart: + frac_coords = cart_coords @ conversion_matrix + max_vals = np.max(frac_coords, axis=0) + min_vals = np.min(frac_coords, axis=0) + min_lengths = np.maximum(min_lengths, max_vals - min_vals) + + # Add a buffer based on the largest VdW radius + min_lengths += self.max_vdw_radius * 2 + + # Scale the lengths using the random vector to explore the search space + a = min_lengths[0] + vector[0] * (self.num_total_molecules * min_lengths[0]) + b = min_lengths[1] + vector[1] * (self.num_total_molecules * min_lengths[1]) + c = min_lengths[2] + vector[2] * (self.num_total_molecules * min_lengths[2]) + + # Apply constraints based on the point group + if self.point_group in ["Tetragonal", "Hexagonal"]: + return a, a, c + if self.point_group in ["Trigonal", "Cubic"]: + return a, a, a + return a, b, c + + def _check_for_collisions(self, + atom_elements: np.ndarray, + atom_cart_coords: np.ndarray + ) -> bool: + """ + Performs a steric clash test for the generated structure. 
+ + It checks for intermolecular distances that are smaller than the sum of + the van der Waals radii (with a tolerance factor). + + Args: + atom_elements: A numpy array of element symbols for all atoms in the supercell. + atom_cart_coords: A numpy array of Cartesian coordinates for all atoms. + + Returns: + True if a collision is detected, False otherwise. + """ + vdw_radii = np.array([chemical_knowledge.element_vdw_radii[el.item()] for el in atom_elements]) + + start_index = 0 + for i in range(self.num_asym_molecules): + # Define the asymmetric unit molecule to check against its environment + num_atoms_in_mol = self.atomic_counts_per_molecule[i] + end_index = start_index + num_atoms_in_mol + + asym_mol_coords = atom_cart_coords[start_index:end_index] + asym_mol_vdws = vdw_radii[start_index:end_index] + + # The rest of the atoms form the environment + neighbor_coords = atom_cart_coords[end_index:] + neighbor_vdws = vdw_radii[end_index:] + + # A coarse filter using a bounding box around the asymmetric molecule + mol_min = np.min(asym_mol_coords, axis=0) - self.max_vdw_radius * 2 + mol_max = np.max(asym_mol_coords, axis=0) + self.max_vdw_radius * 2 + box_indices = np.all((neighbor_coords > mol_min) & (neighbor_coords < mol_max), axis=1) + + if not np.any(box_indices): + # Move to the next molecule in the asymmetric unit + num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) + start_index += num_atoms_in_supercell_mol + continue + + nearby_neighbor_coords = neighbor_coords[box_indices] + nearby_neighbor_vdws = neighbor_vdws[box_indices] + + # Use KD-Trees for efficient nearest-neighbor search + tree_asym = cKDTree(asym_mol_coords, compact_nodes=False, balanced_tree=False) + tree_neighbors = cKDTree(nearby_neighbor_coords, compact_nodes=False, balanced_tree=False) + + # Find all pairs of atoms within the maximum possible interaction distance + possible_contacts = tree_asym.query_ball_tree(tree_neighbors, self.max_vdw_radius * 2) + + for j, neighbor_indices in enumerate(possible_contacts): + if not neighbor_indices: + continue + + # Check precise distances for potential contacts + diff = asym_mol_coords[j] - nearby_neighbor_coords[neighbor_indices] + # einsum is a fast way to compute squared norms row-wise + distances = np.sqrt(np.einsum('ij,ij->i', diff, diff)) + + sum_radii = (asym_mol_vdws[j] + nearby_neighbor_vdws[neighbor_indices]) * _VDW_CLASH_FACTOR + + if np.any(distances < sum_radii): + return True # Collision detected + + # Update start index for the next asymmetric molecule + num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) + start_index += num_atoms_in_supercell_mol + + return False # No collisions found + + + def _shrink_cell_dimensions(self, a: float, b: float, c: float, locked_dims: list[bool] + ) -> tuple[float, float, float, list[int]]: + """ + Shrinks the crystal cell along the longest unlocked dimension by 1 Angstrom. + This is a crude optimization step to pack the molecules more tightly. + + Args: + a, b, c: Current cell lengths. + locked_dims: A boolean list [a, b, c] where True means the dimension + cannot be shrunk further. + + Returns: + A tuple of (new_a, new_b, new_c, last_change_indices). 
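+
+        Example: for a Triclinic cell with (a, b, c) = (12.0, 9.0, 7.0) and no
+        locked dimensions, the longest axis is reduced and the call returns
+        (11.0, 9.0, 7.0, [0]).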
+ """ + lengths = [val for val, is_locked in zip([a, b, c], locked_dims) if not is_locked] + if not lengths: + return a, b, c, [] # All dimensions are locked + + max_length = max(lengths) + last_change = [] + + # Logic to shrink the largest dimension(s) while respecting point group constraints + if self.point_group in ["Triclinic", "Monoclinic", "Orthorhombic"]: + if a == max_length and not locked_dims[0]: + a -= 1.0 + last_change = [0] + elif b == max_length and not locked_dims[1]: + b -= 1.0 + last_change = [1] + elif c == max_length and not locked_dims[2]: + c -= 1.0 + last_change = [2] + elif self.point_group in ["Tetragonal", "Hexagonal"]: + if (a == max_length or b == max_length) and not locked_dims[0]: + a -= 1.0 + b -= 1.0 + last_change = [0, 1] + elif c == max_length and not locked_dims[2]: + c -= 1.0 + last_change = [2] + elif self.point_group in ["Trigonal", "Cubic"]: + if (a == max_length or b == max_length or c == max_length) and not locked_dims[0]: + a -= 1.0 + b -= 1.0 + c -= 1.0 + last_change = [0, 1, 2] + + return a, b, c, last_change + + def _setup_crystal_from_vector(self, vector: np.ndarray + ) -> tuple[Optional[list], Optional[list[np.ndarray]], Optional[list[Any]]]: + """ + Performs the initial setup of a crystal structure from a random vector. + This includes setting angles, rotating molecules, and setting initial lengths. + This helper is used by both `generate` and `_generate_from_vector`. + """ + # Unpack the Sobol vector into its components for cell parameters and molecules + # Slicing indices based on the defined search space shape + s = self.search_dimension_shape + cell_angle_seed = vector[0:s[1]] + cell_length_seed = vector[s[1]:s[1]+s[0]] + move_part_seed = vector[s[1]+s[0] : s[1]+s[0]+s[2]] + rotate_part_seed = vector[s[1]+s[0]+s[2]:] + + # 1. Set cell angles + alpha, beta, gamma = self._get_cell_angles_from_vector(cell_angle_seed) + cell_angles = [alpha, beta, gamma] + + # Check for valid cell matrix from angles + ca, cb, cg = np.cos(np.deg2rad([alpha, beta, gamma])) + volume_sqrt_term = 1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg + if volume_sqrt_term <= 0: + print("Failed: Invalid angles cannot form a valid parallelepiped.") + return None, None, None + + # 2. Rotate molecules + rotated_molecules_cart = [] + rotated_molecules_ele = [] + rotate_vectors = rotate_part_seed.reshape(-1, 3) + for r_vec, molecule in zip(rotate_vectors, self.molecules): + elements, cart_coords = molecule.get_ele_and_cart() + rotation_matrix = operation.get_rotate_matrix(r_vec) + rotated_cart = cart_coords @ rotation_matrix + rotated_molecules_cart.append(rotated_cart) + rotated_molecules_ele.append(elements) + + # 3. Set initial cell lengths + a, b, c = self._get_cell_lengths_from_vector(cell_length_seed, cell_angles, rotated_molecules_cart) + cell_lengths = [a, b, c] + + crystal_params = [cell_lengths, cell_angles, move_part_seed, rotated_molecules_cart, rotated_molecules_ele] + + return crystal_params, volume_sqrt_term, rotate_part_seed + + def _build_supercell_for_clash_test(self, + cell_params: list, + rotated_molecules_cart: list[np.ndarray], + rotated_molecules_ele: list[list[str]], + move_part_seed: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, list, list]: + """ + Builds a supercell and returns all atomic elements and coordinates for clash testing. + This version correctly handles asymmetric units with multiple, different-sized molecules. 
+ """ + f2c_matrix = operation.f2c_matrix(cell_params) + c2f_matrix = operation.c2f_matrix(cell_params) + supercell_cart_translations = self.supercell_frac_translations @ f2c_matrix + + all_asym_frac_coords = [] + all_asym_elements = [] + + # Use lists to collect 2D blocks of coordinates and elements. This is efficient. + sc_cart_blocks = [] + sc_ele_blocks = [] + + for i, cart_coords in enumerate(rotated_molecules_cart): + # Apply translation vector to this molecule's fractional coordinates + trans_vector = move_part_seed[i * 3:(i + 1) * 3] + frac_coords = cart_coords @ c2f_matrix + trans_vector + + all_asym_frac_coords.append(frac_coords) + all_asym_elements.append(rotated_molecules_ele[i]) + + # Apply symmetry operations + symm_cart_coords = operation.apply_SYMM(frac_coords, self.symmetry_ops) @ f2c_matrix + symm_elements_list = [rotated_molecules_ele[i]] * len(self.symmetry_ops) + + # Center molecules that were moved across periodic boundaries + centroid_frac = np.mean(frac_coords, axis=0) + centroids_all_symm = operation.apply_SYMM(centroid_frac, self.symmetry_ops) + for j, cent in enumerate(centroids_all_symm): + move_to_center = (np.mod(cent, 1) - cent) @ f2c_matrix + symm_cart_coords[j] += move_to_center + + # --- Core Correction Logic --- + # 1. Create the full block of atoms for the current molecule type by applying all + # supercell translations. + mol_block_cart_temp = [] + for translation_vec in supercell_cart_translations: + # Adding the translation vector to all symmetry-equivalent molecules + translated_coords = symm_cart_coords + translation_vec + # Reshape to a flat (N_atoms * N_symm, 3) 2D array and append + mol_block_cart_temp.append(translated_coords.reshape(-1, 3)) + + # 2. Stack all translated blocks for this molecule type into a single 2D array + sc_cart_blocks.append(np.vstack(mol_block_cart_temp)) + + # 3. Handle the corresponding elements, ensuring they are flattened correctly + num_translations = len(self.supercell_frac_translations) + ele_block = np.array(symm_elements_list * num_translations).reshape(-1, 1) + sc_ele_blocks.append(ele_block) + + # After iterating through all molecule types, stack their respective complete blocks + final_sc_cart = np.vstack(sc_cart_blocks) + final_sc_ele = np.vstack(sc_ele_blocks) + + return final_sc_cart, final_sc_ele, all_asym_frac_coords, all_asym_elements + + def _create_final_crystal_object(self, + cell_params: list, + asym_frac_coords: list, + asym_elements: list, + seed: Any + ) -> data_classes.Crystal: + """Creates the final Crystal object from the successful structure.""" + + flat_elements = np.concatenate(asym_elements, axis=0).reshape(-1, 1) + flat_frac_coords = np.concatenate(asym_frac_coords, axis=0).reshape(-1, 3) + + atoms = [] + for ele, frac in zip(flat_elements, flat_frac_coords): + atoms.append(data_classes.Atom(element=ele.item(), frac_xyz=frac)) + + return data_classes.Crystal( + cell_para=cell_params, + atoms=atoms, + comment=str(seed), + system_name=str(seed), + space_group=self.space_group_number, + SYMM=self.symmetry_ops + ) + + def generate(self, + seed: Any = "unknown", + test: bool = False, + densely_pack_method: bool = False, + frame_tolerance: float = 1.5 + ) -> Optional[data_classes.Crystal]: + """ + The main generation method. + + Uses a Sobol sequence to get a random vector, then attempts to build and + pack a crystal structure through an iterative shrinking process. + + Args: + seed: A seed for the Sobol sequence generator. If "unknown", an error is raised. 
+            test: A flag for enabling verbose test-mode output (prints cycle number).
+            densely_pack_method: If True, applies a heuristic to shrink very large
+                initial volumes.
+            frame_tolerance: Tolerance for checking if the final structure is a 2D slab.
+
+        Returns:
+            A `data_classes.Crystal` object if a valid structure is found, otherwise `None`.
+        """
+        if seed == "unknown":
+            raise ValueError("A seed must be provided for the Sobol generator.")
+
+        sobol_gen = qmc.Sobol(d=self.search_dimensions, seed=seed)
+        initial_vector = sobol_gen.random(n=1).flatten()
+
+        setup_result, volume_sqrt_term, _ = self._setup_crystal_from_vector(initial_vector)
+        if setup_result is None:
+            return None  # Invalid initial angles
+
+        cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result
+        a, b, c = cell_lengths
+        alpha, beta, gamma = cell_angles
+
+        # Heuristic to shrink extremely sparse initial structures
+        if densely_pack_method:
+            crystal_volume = a * b * c * np.sqrt(volume_sqrt_term)
+            if crystal_volume > self.estimated_packed_volume * 20:
+                c = self.estimated_packed_volume * 20 / (a * b * np.sqrt(volume_sqrt_term))
+
+        locked_dims = [False, False, False]
+        old_a, old_b, old_c = a, b, c
+
+        for cycle_no in range(1001):
+            # range(1001) yields cycle numbers 0..1000, so 1000 is the final cycle
+            if cycle_no == 1000:
+                print(f"Stopping: Max optimization cycles reached. Seed: {seed}")
+                return None
+
+            if a < 0 or b < 0 or c < 0:
+                print(f"BUG: Negative cell dimension. sg={self.space_group_number}, seed={seed}")
+                return None
+
+            if test:
+                print(f"Cycle: {cycle_no}")
+
+            cell_params = [[a, b, c], [alpha, beta, gamma]]
+
+            sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test(
+                cell_params, rot_carts, rot_eles, move_part_seed
+            )
+
+            has_collision = self._check_for_collisions(sc_ele, sc_cart)
+
+            if has_collision:
+                if cycle_no == 0:
+                    print(f"Failed: Initial structure has collisions. Seed: {seed}")
+                    return None
+
+                # Collision occurred, so revert to last good state and lock the changed dimension
+                a, b, c = old_a, old_b, old_c
+                for dim_idx in last_change:
+                    locked_dims[dim_idx] = True
+            else:
+                # No collision, this is a valid (though maybe not dense) structure.
+                # Check if optimization is finished (all dimensions are locked).
+                if cycle_no > 0 and all(locked_dims):
+                    final_crystal = self._create_final_crystal_object(cell_params, asym_fracs, asym_eles, seed)
+
+                    # Final check to filter out 2D slab-like structures
+                    if not operation.detect_is_frame_vdw_new(final_crystal, tolerance=frame_tolerance):
+                        print(f"Failed: Generated structure is a 2D slab. Seed: {seed}")
+                        return None
+
+                    print(f"Success: Generated a valid crystal structure. Seed: {seed}")
+                    return final_crystal
+
+                # If no collision and not finished, save current state and shrink further
+                old_a, old_b, old_c = a, b, c
+                a, b, c, last_change = self._shrink_cell_dimensions(a, b, c, locked_dims)
+
+    # ==============================================================================
+    # Test-related functions, kept for compatibility, marked as internal.
+    # ==============================================================================
+
+    def _generate_from_vector(self,
+                              seed_vector: np.ndarray,
+                              frame_tolerance: float = 1.5
+                              ) -> Optional[data_classes.Crystal]:
+        """
+        Generates a single crystal structure directly from a vector, without optimization.
+        This is an internal method intended for testing and analysis.
+        Original name: generate_by_vector_2.
+
+        Args:
+            seed_vector: A numpy array of shape (self.search_dimensions,) defining the structure.
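+                Components are expected in [0, 1) and are consumed in the order
+                used by _setup_crystal_from_vector: 3 cell angles, 3 cell
+                lengths, 3*N translations, then 3*N rotations for N asymmetric
+                molecules.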
+ frame_tolerance: Tolerance for checking if the final structure is a 2D slab. + + Returns: + A `data_classes.Crystal` object if valid, otherwise `None`. + """ + if not isinstance(seed_vector, np.ndarray): + raise TypeError("seed_vector must be a numpy array.") + + expected_len = self.search_dimensions + if len(seed_vector) != expected_len: + raise ValueError(f"Length of seed_vector must be {expected_len}, got {len(seed_vector)}.") + + setup_result, _, _ = self._setup_crystal_from_vector(seed_vector) + if setup_result is None: + return None # Invalid initial angles + + cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result + cell_params = [cell_lengths, cell_angles] + + sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( + cell_params, rot_carts, rot_eles, move_part_seed + ) + + if self._check_for_collisions(sc_ele, sc_cart): + print("Failed: Structure from vector has collisions.") + return None + + generated_crystal = self._create_final_crystal_object( + cell_params, asym_fracs, asym_eles, seed="from_vector" + ) + + # Optional: Keep the slab check for consistency + # if not operation.detect_is_frame_vdw_new(generated_crystal, tolerance=frame_tolerance): + # print("Failed: Generated structure is a 2D slab.") + # return None + + return generated_crystal + + def _is_valid_vector(self, seed_vector: np.ndarray) -> bool: + """ + Checks if a given vector produces a valid, collision-free structure. + Internal method for testing. + """ return self._generate_from_vector(seed_vector) is not None \ No newline at end of file diff --git a/basic_function/__pycache__/CSP_function.cpython-310.pyc b/basic_function/__pycache__/CSP_function.cpython-310.pyc deleted file mode 100644 index 6b80d351bd5c8f8361a22ef4112c988465283110..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_function.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_function.cpython-311.pyc b/basic_function/__pycache__/CSP_function.cpython-311.pyc deleted file mode 100644 index ba4e2e82f4acb503b75c35068c00cbe94f590075..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_function.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_function.cpython-313.pyc b/basic_function/__pycache__/CSP_function.cpython-313.pyc deleted file mode 100644 index 3e1cc991a9a0fa43b4921f8ef6b936e1f4cf0f68..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_function.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_function.cpython-38.pyc b/basic_function/__pycache__/CSP_function.cpython-38.pyc deleted file mode 100644 index 85247bd224342e1d7f33cd98b6a50da40be75bd4..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_function.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_function.cpython-39.pyc b/basic_function/__pycache__/CSP_function.cpython-39.pyc deleted file mode 100644 index 06c90f6ca6a49748033ebe85bf99c5fe737b0573..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_function.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-310.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-310.pyc deleted file mode 100644 index 1242c103e4a9bd2f1e77e83f2cdeb64272d04a3a..0000000000000000000000000000000000000000 Binary files 
a/basic_function/__pycache__/CSP_generator_normal.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc deleted file mode 100644 index 5f70bc4a231a585fccb4ebe1a15bb1a3306868e7..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc deleted file mode 100644 index 90c3d72f1a63b723e914900de2a9ad63daa8df46..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc deleted file mode 100644 index 4eac67ff271c8e16a847713e1674a10f17243250..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-39.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-39.pyc deleted file mode 100644 index 98575bb94abf75b301951e4fc12eadaa9492c9cc..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/CSP_generator_normal.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-310.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-310.pyc deleted file mode 100644 index d21992527cb67bae5f57c9ce784b9f1740b9f8cb..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/chemical_knowledge.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-311.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-311.pyc deleted file mode 100644 index 963bc37b65997d89f3d17d8ec9db98bf02ef9527..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/chemical_knowledge.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-313.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-313.pyc deleted file mode 100644 index 1b0f8b2b2cfd6d5144d1bb929748e5f409bde610..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/chemical_knowledge.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-38.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-38.pyc deleted file mode 100644 index eee87999297cf343e10f70862b9b3776c278dba9..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/chemical_knowledge.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-39.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-39.pyc deleted file mode 100644 index 513e6fcb056eab5499db4fd6ec3c3525870b8358..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/chemical_knowledge.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/conformer_search.cpython-310.pyc b/basic_function/__pycache__/conformer_search.cpython-310.pyc deleted file mode 100644 index 2c51dcd1cd052eaa87d4c9de27bff35e72993366..0000000000000000000000000000000000000000 Binary files 
a/basic_function/__pycache__/conformer_search.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/conformer_search.cpython-311.pyc b/basic_function/__pycache__/conformer_search.cpython-311.pyc deleted file mode 100644 index 8673545a0a78603330b152008bfcab2535679d92..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/conformer_search.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/conformer_search.cpython-313.pyc b/basic_function/__pycache__/conformer_search.cpython-313.pyc deleted file mode 100644 index 41e00fec48d2e4ff92b25041e65a438b875182f1..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/conformer_search.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/conformer_search.cpython-38.pyc b/basic_function/__pycache__/conformer_search.cpython-38.pyc deleted file mode 100644 index b440bde215ae1aa8e6552fc879beecf847201e42..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/conformer_search.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/conformer_search.cpython-39.pyc b/basic_function/__pycache__/conformer_search.cpython-39.pyc deleted file mode 100644 index 05b66e4192783036227ca675516b91f854fd434f..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/conformer_search.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/data_classes.cpython-310.pyc b/basic_function/__pycache__/data_classes.cpython-310.pyc deleted file mode 100644 index f4abbefd3107b3109fdce332739469a127cb391b..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/data_classes.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/data_classes.cpython-311.pyc b/basic_function/__pycache__/data_classes.cpython-311.pyc deleted file mode 100644 index a8512e072ae1ca290b0d22f1478940cd4d26b6d6..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/data_classes.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/data_classes.cpython-313.pyc b/basic_function/__pycache__/data_classes.cpython-313.pyc deleted file mode 100644 index a3013a8d19c26b97a7473bef3feb945e83bce8bb..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/data_classes.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/data_classes.cpython-38.pyc b/basic_function/__pycache__/data_classes.cpython-38.pyc deleted file mode 100644 index dd9cc53000a3187c182af2d692f0d4387cff74c3..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/data_classes.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/data_classes.cpython-39.pyc b/basic_function/__pycache__/data_classes.cpython-39.pyc deleted file mode 100644 index e64c38e4989c54e6509cddca70560d2b66cffa77..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/data_classes.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/descriptor.cpython-39.pyc b/basic_function/__pycache__/descriptor.cpython-39.pyc deleted file mode 100644 index 5747b6525e62dce90131d2b01ae4e2431f6bcb90..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/descriptor.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/format_parser.cpython-310.pyc 
b/basic_function/__pycache__/format_parser.cpython-310.pyc deleted file mode 100644 index 5a52493d629441ce292559aa763c2ddec7530ac8..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/format_parser.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/format_parser.cpython-311.pyc b/basic_function/__pycache__/format_parser.cpython-311.pyc deleted file mode 100644 index d2cb79e5634ebefc5ae6c888d0c397951ae20b42..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/format_parser.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/format_parser.cpython-313.pyc b/basic_function/__pycache__/format_parser.cpython-313.pyc deleted file mode 100644 index 483d4d73676e3d746e0cd79f14b4ec2c499d209a..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/format_parser.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/format_parser.cpython-38.pyc b/basic_function/__pycache__/format_parser.cpython-38.pyc deleted file mode 100644 index f13219069742d4945b050cacebdcf4bba3c3f806..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/format_parser.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/format_parser.cpython-39.pyc b/basic_function/__pycache__/format_parser.cpython-39.pyc deleted file mode 100644 index 5dbb0279147a5b1c4a994f8ce42ac3a59cc3aaa5..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/format_parser.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation.cpython-310.pyc b/basic_function/__pycache__/operation.cpython-310.pyc deleted file mode 100644 index 89df7b9d5215982543a19f24f09b8b7813b65328..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation.cpython-311.pyc b/basic_function/__pycache__/operation.cpython-311.pyc deleted file mode 100644 index 9e94698aedab2d946e7d2261dde52b42e18939fd..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation.cpython-313.pyc b/basic_function/__pycache__/operation.cpython-313.pyc deleted file mode 100644 index 9cc28ad459db11bc76eb9ddea996da626a2fe4dd..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation.cpython-38.pyc b/basic_function/__pycache__/operation.cpython-38.pyc deleted file mode 100644 index 62d500d9fe786633c5ff73d4ecfc0e8e7275946f..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation.cpython-39.pyc b/basic_function/__pycache__/operation.cpython-39.pyc deleted file mode 100644 index 9255c515f9d29632cb43e4761cda921244e6e1c4..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation_new.cpython-310.pyc b/basic_function/__pycache__/operation_new.cpython-310.pyc deleted file mode 100644 index ba8434ce6cdfe9eb37fe62a693c22be7c68969bf..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation_new.cpython-310.pyc and /dev/null differ 
diff --git a/basic_function/__pycache__/operation_new.cpython-313.pyc b/basic_function/__pycache__/operation_new.cpython-313.pyc deleted file mode 100644 index de1a351b415b894fd81f809fe4ea847b1844b9cf..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation_new.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation_new.cpython-38.pyc b/basic_function/__pycache__/operation_new.cpython-38.pyc deleted file mode 100644 index 6b878d7e801c09b61695e984efab052285e58f68..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation_new.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/operation_new.cpython-39.pyc b/basic_function/__pycache__/operation_new.cpython-39.pyc deleted file mode 100644 index 46e609252f9f8023be8402c94270be444a8ccad8..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/operation_new.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/others.cpython-310.pyc b/basic_function/__pycache__/others.cpython-310.pyc deleted file mode 100644 index 2bf4f34a432a4270531ccd0b0eb2b3ce7c09169d..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/others.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/others.cpython-311.pyc b/basic_function/__pycache__/others.cpython-311.pyc deleted file mode 100644 index 167d18603fad440640727a0719ce514ffa283430..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/others.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/others.cpython-313.pyc b/basic_function/__pycache__/others.cpython-313.pyc deleted file mode 100644 index 4ade1f669eccc13606e596eb17f247fa38c9ce6f..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/others.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/others.cpython-38.pyc b/basic_function/__pycache__/others.cpython-38.pyc deleted file mode 100644 index 079ed77b05d464291c9b0f54fe7291ed1b5aad75..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/others.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/others.cpython-39.pyc b/basic_function/__pycache__/others.cpython-39.pyc deleted file mode 100644 index a8c6dc5baae0cd5e340db5880ef525d8b572e120..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/others.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/packaged_function.cpython-310.pyc b/basic_function/__pycache__/packaged_function.cpython-310.pyc deleted file mode 100644 index 457e7cec1b07deaa9daa78a128184375542910d6..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/packaged_function.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/packaged_function.cpython-311.pyc b/basic_function/__pycache__/packaged_function.cpython-311.pyc deleted file mode 100644 index 3872d15b924db393b18526bf5b9e35c12c5eec02..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/packaged_function.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/packaged_function.cpython-313.pyc b/basic_function/__pycache__/packaged_function.cpython-313.pyc deleted file mode 100644 index 63de0ed62a394e06a0b64e5bd5132be31e89b0b4..0000000000000000000000000000000000000000 Binary files 
a/basic_function/__pycache__/packaged_function.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/packaged_function.cpython-38.pyc b/basic_function/__pycache__/packaged_function.cpython-38.pyc deleted file mode 100644 index f7b1945997a31188b31731bede672acb172a8a67..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/packaged_function.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/packaged_function.cpython-39.pyc b/basic_function/__pycache__/packaged_function.cpython-39.pyc deleted file mode 100644 index 8d0d21fed048b07685f6620bdca6a43daafd21b1..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/packaged_function.cpython-39.pyc and /dev/null differ diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-310.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-310.pyc deleted file mode 100644 index d82f489d384a12c0fd3469c7ff8f2468dffb8cd1..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/unit_cell_parser.cpython-310.pyc and /dev/null differ diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-311.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-311.pyc deleted file mode 100644 index fcff2bcbad0ad0c1a7eac96c8cfbb9ef049a3c52..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/unit_cell_parser.cpython-311.pyc and /dev/null differ diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-313.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-313.pyc deleted file mode 100644 index ba4c6646677fb1a3416be0b56394ecc4ea066a6e..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/unit_cell_parser.cpython-313.pyc and /dev/null differ diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-38.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-38.pyc deleted file mode 100644 index 76cc51130a1b0e061560ffce92bdda0c7ed027d7..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/unit_cell_parser.cpython-38.pyc and /dev/null differ diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-39.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-39.pyc deleted file mode 100644 index e8dd302a4653b024739ec4919001f7df54a4adc9..0000000000000000000000000000000000000000 Binary files a/basic_function/__pycache__/unit_cell_parser.cpython-39.pyc and /dev/null differ diff --git a/basic_function/chemical_knowledge.py b/basic_function/chemical_knowledge.py index 3862146602d6b3fd8a567ce08ab1b78b999b4ea7..3e0d791cc92f3a68542847eb463f845f1f068954 100644 --- a/basic_function/chemical_knowledge.py +++ b/basic_function/chemical_knowledge.py @@ -1,140 +1,140 @@ -element_vdw_radii = { - # First Period - 'H': 1.20, - 'He': 1.40, - # Second Period - 'Li': 1.82, 'Be': 1.53, 'B': 1.92, 'C': 1.70, 'N': 1.55, 'O': 1.52, 'F': 1.35, 'Ne': 1.54, - # Third Period - 'Na': 2.27, 'Mg': 1.73, 'Al': 1.84, 'Si': 2.10, 'P': 1.80, 'S': 1.80, 'Cl': 1.75, 'Ar': 1.88, - # Fourth Period - 'K': 2.75, 'Ca': 2.31, 'Sc': 2.11, 'Ti': 1.87, 'V': 1.79, 'Cr': 1.89, 'Mn': 1.97, 'Fe': 1.94, 'Co': 1.92, - 'Ni': 1.63, 'Cu': 1.40, 'Zn': 1.39, 'Ga': 1.87, 'Ge': 2.11, 'As': 1.85, 'Se': 1.90, 'Br': 1.83, 'Kr': 2.02, - # Fifth Period - 'Rb': None, 'Sr': None, 'Y': None, 'Zr': None, 'Nb': None, 'Mo': None, 'Tc': None, 'Ru': None, 'Rh': None, - 'Pd': None, 'Ag': None, 'Cd': None, 'In': None, 'Sn': None, 'Sb': None, 'Te': 2.06, 'I': 1.98, 'Xe': 2.16} -# comes 
from https://pubchem.ncbi.nlm.nih.gov/ptable/atomic-radius/ - -element_covalent_radii = { - # First Period - #'H': 0.31, - 'H': 0.31, - 'He': 0.28, - # Second Period - 'Li': 1.28, 'Be': 0.96, 'B': 0.84, 'C': 0.76, 'N': 0.71, 'O': 0.66, 'F': 0.57, 'Ne': 0.58, - # Third Period - 'Na': 1.66, 'Mg': 1.41, 'Al': 1.21, 'Si': 1.11, 'P': 1.07, 'S': 1.05, 'Cl': 1.02, 'Ar': 1.06, - # Fourth Period - 'K': 2.03, 'Ca': 1.76, 'Sc': 1.70, 'Ti': 1.60, 'V': 1.53, 'Cr': 1.39, 'Mn': 1.61, 'Fe': 1.52, 'Co': 1.50, - 'Ni': 1.24, 'Cu': 1.32, 'Zn': 1.22, 'Ga': 1.22, 'Ge': 1.20, 'As': 1.19, 'Se': 1.20, 'Br': 1.20, 'Kr': 1.16, - # Fifth Period - 'Rb': 2.20, 'Sr': 1.95, 'Y': 1.90, 'Zr': 1.75, 'Nb': 1.64, 'Mo': 1.54, 'Tc': 1.47, 'Ru': 1.46, 'Rh': 1.42, - 'Pd': 1.39, 'Ag': 1.45, 'Cd': 1.44, 'In': 1.42, 'Sn': 1.39, 'Sb': 1.39, 'Te': 1.38, 'I': 1.39, 'Xe': 1.40} -# comes from DOI: 10.1039/b801115j - -element_masses = { - # First Period - 'H': 1.0079, 'He': 4.0026, - # Second Period - 'Li': 6.941, 'Be': 9.0122, 'B': 10.811, 'C': 12.0107, 'N': 14.0067, 'O': 15.9994, 'F': 18.9984, 'Ne': 20.1797, - # Third Period - 'Na': 22.9897, 'Mg': 24.305, 'Al': 26.9815, 'Si': 28.0855, 'P': 30.9738, 'S': 32.065, 'Cl': 35.453, 'Ar': 39.948, - # Fourth Period - 'K': 39.0983, 'Ca': 40.078, 'Sc': 44.9559, 'Ti': 47.867, 'V': 50.9415, 'Cr': 51.9961, 'Mn': 54.938, 'Fe': 55.845, - 'Co': 58.9332, 'Ni': 58.6934, 'Cu': 63.546, 'Zn': 65.39, 'Ga': 69.723, 'Ge': 72.64, 'As': 74.9216, 'Se': 78.96, - 'Br': 79.904, 'Kr': 83.8, - # Fifth Period - 'Rb': 85.4678, 'Sr': 87.62, 'Y': 88.9059, 'Zr': 91.224, 'Nb': 92.9064, 'Mo': 95.94, 'Tc': 98, 'Ru': 101.07, - 'Rh': 102.9055, 'Pd': 106.42, 'Ag': 107.8682, 'Cd': 112.411, 'In': 114.818, 'Sn': 118.71, 'Sb': 121.76, 'Te': 127.6, - 'I': 126.9045, 'Xe': 131.293, - # Sixth Period - 'Cs': 132.9055, 'Ba': 137.327, 'La': 138.9055, 'Ce': 140.116, 'Pr': 140.9077, 'Nd': 144.24, 'Pm': 145, 'Sm': 150.36, - 'Eu': 151.964, 'Gd': 157.25, 'Tb': 158.9253, 'Dy': 162.5, 'Ho': 164.9303, 'Er': 167.259, 'Tm': 168.9342, - 'Yb': 173.04, 'Lu': 174.967, 'Hf': 178.49, 'Ta': 180.9479, 'W': 183.84, 'Re': 186.207, 'Os': 190.23, 'Ir': 192.217, - 'Pt': 195.078, 'Au': 196.9665, 'Hg': 200.59, 'Tl': 204.3833, 'Pb': 207.2, 'Bi': 208.9804, 'Po': 209, 'At': 210, - 'Rn': 222, - # Seventh Period - 'Fr': 223, 'Ra': 226, 'Ac': 227, 'Th': 232.0381, 'Pa': 231.0359, 'U': 238.0289, 'Np': 237, 'Pu': 244, 'Am': 243, - 'Cm': 247, 'Bk': 247, 'Cf': 251, 'Es': 252, 'Fm': 257, 'Md': 258, 'No': 259, 'Lr': 262, 'Rf': 261, 'Db': 262, - 'Sg': 266, 'Bh': 264, 'Hs': 277, 'Mt': 268} - -periodic_table_list = { - # First Period - 'H': 1, 'He': 2, - # Second Period - 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10, - # Third Period - 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, - # Fourth Period - 'K': 19, 'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25, 'Fe': 26, 'Co': 27, - 'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31, 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, - # Fifth Period - 'Rb': 37, 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43, 'Ru': 44, 'Rh': 45, - 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, 'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, - # Sixth Period - 'Cs': 55, 'Ba': 56, 'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61, 'Sm': 62, 'Eu': 63, - 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, 'Er': 68, 'Tm': 69, 'Yb': 70, 'Lu': 71, 'Hf': 72, - 'Ta': 73, 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79, 'Hg': 80, 'Tl': 81, - 'Pb': 82, 'Bi': 83, 'Po': 84, 'At': 85, 'Rn': 86, - # Seventh Period 
- 'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91, 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, - 'Cm': 96, 'Bk': 97, 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103, 'Rf': 104, - 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108, 'Mt': 109} - - -def sort_by_atomic_number(all_element_type): - current_order = [] - for element in all_element_type: - assert element in periodic_table_list, '{} element is not exist'.format(element) - current_order.append(periodic_table_list[element]) - - sorted_elements = [ELEMENT for ORDER, ELEMENT in sorted(zip(current_order, all_element_type))] - - return sorted_elements - -periodic_table_std = {value: key for key, value in periodic_table_list.items()} - -# following these website to get help: -# http://img.chem.ucl.ac.uk/sgp/LARGE/sgp.htm -# https://en.wikipedia.org/wiki/Space_group -space_group = { - 1:[["x,y,z"],'P1',"Triclinic"], - 2:[["x,y,z","-x,-y,-z"],'P-1',"Triclinic"], - 4:[["x,y,z","-x,y+1/2,-z"],'P21',"Monoclinic"], - 5:[["x,y,z","-x,y,-z","x+1/2,y+1/2,z","-x+1/2,y+1/2,-z"],'C2',"Monoclinic"], - 7:[["x,y,z","x,-y,1/2+z"],"Pc","Monoclinic"], - 9:[["x,y,z","x,-y,1/2+z","1/2+x,1/2+y,z","1/2+x,1/2-y,1/2+z"],"CC","Monoclinic"], - 12:[["x,y,z","-x,y,-z","1/2+x,1/2+y,z","1/2-x,1/2+y,-z","-x,-y,-z","x,-y,z","1/2-x,1/2-y,-z","1/2+x,1/2-y,z"],"C2/M","Monoclinic"], - 11:[["x,y,z","-x,y+1/2,-z","-x,-y,-z","x,1/2-y,z"],"P21/m","Monoclinic"], - 12:[["x,y,z","x,-y,z","-x,y,-z","-x,-y,-z","1/2+x,1/2+y,z","1/2+x,1/2-y,z","1/2-x,1/2+y,-z","1/2-x,1/2-y,-z"],"C2/m","Monoclinic"], - 13:[["x,y,z","-x,y,1/2-z","-x,-y,-z","x,-y,1/2+z"],"P2/c","Monoclinic"], - 13:[["x,y,z", "-x,y,-z+1/2", "-x,-y,-z", "x,-y,z+1/2"],'P2/c',"Monoclinic"], - 14:[["x,y,z", "-x,y+1/2,-z+1/2", "-x,-y,-z", "x,-y+1/2,z+1/2"],'P21/C',"Monoclinic"], - 15:[["x,y,z","-x,y,1/2-z","1/2+x,1/2+y,z","1/2-x,1/2+y,1/2-z","-x,-y,-z","x,-y,1/2+z","1/2-x,1/2-y,-z","1/2+x,1/2-y,1/2+z"],"C2/C","Monoclinic"], - 18:[["x,y,z","1/2+x,1/2-y,-z","1/2-x,1/2+y,-z","-x,-y,z"],"P21212","Orthorhombic"], - 19:[["x,y,z","1/2+x,1/2-y,-z","-x,1/2+y,1/2-z","1/2-x,-y,1/2+z"],"P212121","Orthorhombic"], - 29:[["x,y,z","1/2-x,y,1/2+z","1/2+x,-y,z","-x,-y,1/2+z"],"PCA21","Orthorhombic"], - 33:[["x,y,z","1/2-x,1/2+y,1/2+z","1/2+x,1/2-y,z","-x,-y,1/2+z"],"PNA21","Orthorhombic"], - 43:[["x,y,z","1/4-x,1/4+y,1/4+z","1/4+x,1/4-y,1/4+z","-x,-y,z","x,y+1/2,z+1/2","1/4-x,3/4+y,3/4+z","1/4+x,3/4-y,3/4+z","-x,-y+1/2,z+1/2", - "x+1/2,y,z+1/2","3/4-x,1/4+y,3/4+z","3/4+x,1/4-y,3/4+z","-x+1/2,-y,z+1/2","x+1/2,y+1/2,z","3/4-x,3/4+y,1/4+z","3/4+x,3/4-y,1/4+z","-x+1/2,-y+1/2,z"], "Fdd2", "Orthorhombic"], - 56:[["x,y,z","1/2-x,y,1/2+z","x,1/2-y,1/2+z","1/2+x,1/2+y,-z","-x,-y,-z","1/2+x,-y,1/2-z","-x,1/2+y,1/2-z","1/2-x,1/2-y,z"], "Pccn", "Orthorhombic"], - 60:[["x,y,z","1/2-x,1/2+y,z","x,-y,1/2+z","1/2+x,1/2+y,1/2-z","-x,-y,-z","1/2+x,1/2-y,-z","-x,y,1/2-z","1/2-x,1/2-y,1/2+z"],"Pbcn","Orthorhombic"], - 61:[["x,y,z","1/2-x,1/2+y,z","x,1/2-y,1/2+z","1/2+x,y,1/2-z","-x,-y,-z","1/2+x,1/2-y,-z","-x,1/2+y,1/2-z","1/2-x,-y,1/2+z"],"PBCA","Orthorhombic"], - 62:[["x,y,z","x+1/2,-y+1/2,-z+1/2","-x,y+1/2,-z","-x+1/2,-y,z+1/2","-x,-y,-z","-x+1/2,y+1/2,z+1/2","x,-y+1/2,z","x+1/2,y,-z+1/2"],"Pnma","Orthorhombic"], - 77:[["x,y,z","-x,-y,z","-y,x,1/2+z","y,-x,1/2+z"],"P42","Tetragonal"], - 88:[["x,y,z","-x,-y,z","-y,1/2+x,1/4+z","y,1/2-x,1/4+z","-x,1/2-y,1/4-z","x,1/2+y,1/4-z","y,-x,-z","-y,x,-z", - 
"1/2+x,1/2+y,1/2+z","1/2-x,1/2-y,1/2+z","1/2-y,x,3/4+z","1/2+y,-x,3/4+z","1/2-x,-y,3/4-z","1/2+x,y,3/4-z","1/2+y,1/2-x,1/2-z","1/2-y,1/2+x,1/2-z"],"I41/a","Tetragonal"], - 96:[["x,y,z","-x,-y,1/2+z","1/2-y,1/2+x,3/4+z","1/2+y,1/2-x,1/4+z","1/2+x,1/2-y,1/4-z","1/2-x,1/2+y,3/4-z","-y,-x,1/2-z","y,x,-z"],"P43212","Tetragonal"], - 143:[["x,y,z","-y,x-y,z","-x+y,-x,z"],"P3","Hexagonal"], - 147:[["x,y,z","-y,x-y,z","-x+y,-x,z","-x,-y,-z","y,-x+y,-z","x-y,x,-z"],"P-3","Hexagonal"], - 148:[["x,y,z","z,x,y","y,z,x","-x,-y,-z","-z,-x,-y","-y,-z,-x"],"R-3","Trigonal"], - 169:[["x,y,z","-y,x-y,1/3+z","-x+y,-x,2/3+z","-x,-y,1/2+z","x-y,x,1/6+z","y,-x+y,5/6+z"],"P61","Hexagonal"] -} - -point_group = {"Triclinic":[["a","b","c"],["alpha","beta","gamma"]], - "Monoclinic":[["a","b","c"],[90,"beta",90]], - "Orthorhombic":[["a","b","c"],[90,90,90]], - "Tetragonal":[["a","a","c"],[90,90,90]], - "Trigonal":[["a","a","a"],["alpha","alpha","alpha"]], - "Hexagonal":[["a","a","c"],[90,90,120]], - "Cubic":[["a","a","a"],[90,90,90]]} - +element_vdw_radii = { + # First Period + 'H': 1.20, + 'He': 1.40, + # Second Period + 'Li': 1.82, 'Be': 1.53, 'B': 1.92, 'C': 1.70, 'N': 1.55, 'O': 1.52, 'F': 1.35, 'Ne': 1.54, + # Third Period + 'Na': 2.27, 'Mg': 1.73, 'Al': 1.84, 'Si': 2.10, 'P': 1.80, 'S': 1.80, 'Cl': 1.75, 'Ar': 1.88, + # Fourth Period + 'K': 2.75, 'Ca': 2.31, 'Sc': 2.11, 'Ti': 1.87, 'V': 1.79, 'Cr': 1.89, 'Mn': 1.97, 'Fe': 1.94, 'Co': 1.92, + 'Ni': 1.63, 'Cu': 1.40, 'Zn': 1.39, 'Ga': 1.87, 'Ge': 2.11, 'As': 1.85, 'Se': 1.90, 'Br': 1.83, 'Kr': 2.02, + # Fifth Period + 'Rb': None, 'Sr': None, 'Y': None, 'Zr': None, 'Nb': None, 'Mo': None, 'Tc': None, 'Ru': None, 'Rh': None, + 'Pd': None, 'Ag': None, 'Cd': None, 'In': None, 'Sn': None, 'Sb': None, 'Te': 2.06, 'I': 1.98, 'Xe': 2.16} +# comes from https://pubchem.ncbi.nlm.nih.gov/ptable/atomic-radius/ + +element_covalent_radii = { + # First Period + #'H': 0.31, + 'H': 0.31, + 'He': 0.28, + # Second Period + 'Li': 1.28, 'Be': 0.96, 'B': 0.84, 'C': 0.76, 'N': 0.71, 'O': 0.66, 'F': 0.57, 'Ne': 0.58, + # Third Period + 'Na': 1.66, 'Mg': 1.41, 'Al': 1.21, 'Si': 1.11, 'P': 1.07, 'S': 1.05, 'Cl': 1.02, 'Ar': 1.06, + # Fourth Period + 'K': 2.03, 'Ca': 1.76, 'Sc': 1.70, 'Ti': 1.60, 'V': 1.53, 'Cr': 1.39, 'Mn': 1.61, 'Fe': 1.52, 'Co': 1.50, + 'Ni': 1.24, 'Cu': 1.32, 'Zn': 1.22, 'Ga': 1.22, 'Ge': 1.20, 'As': 1.19, 'Se': 1.20, 'Br': 1.20, 'Kr': 1.16, + # Fifth Period + 'Rb': 2.20, 'Sr': 1.95, 'Y': 1.90, 'Zr': 1.75, 'Nb': 1.64, 'Mo': 1.54, 'Tc': 1.47, 'Ru': 1.46, 'Rh': 1.42, + 'Pd': 1.39, 'Ag': 1.45, 'Cd': 1.44, 'In': 1.42, 'Sn': 1.39, 'Sb': 1.39, 'Te': 1.38, 'I': 1.39, 'Xe': 1.40} +# comes from DOI: 10.1039/b801115j + +element_masses = { + # First Period + 'H': 1.0079, 'He': 4.0026, + # Second Period + 'Li': 6.941, 'Be': 9.0122, 'B': 10.811, 'C': 12.0107, 'N': 14.0067, 'O': 15.9994, 'F': 18.9984, 'Ne': 20.1797, + # Third Period + 'Na': 22.9897, 'Mg': 24.305, 'Al': 26.9815, 'Si': 28.0855, 'P': 30.9738, 'S': 32.065, 'Cl': 35.453, 'Ar': 39.948, + # Fourth Period + 'K': 39.0983, 'Ca': 40.078, 'Sc': 44.9559, 'Ti': 47.867, 'V': 50.9415, 'Cr': 51.9961, 'Mn': 54.938, 'Fe': 55.845, + 'Co': 58.9332, 'Ni': 58.6934, 'Cu': 63.546, 'Zn': 65.39, 'Ga': 69.723, 'Ge': 72.64, 'As': 74.9216, 'Se': 78.96, + 'Br': 79.904, 'Kr': 83.8, + # Fifth Period + 'Rb': 85.4678, 'Sr': 87.62, 'Y': 88.9059, 'Zr': 91.224, 'Nb': 92.9064, 'Mo': 95.94, 'Tc': 98, 'Ru': 101.07, + 'Rh': 102.9055, 'Pd': 106.42, 'Ag': 107.8682, 'Cd': 112.411, 'In': 114.818, 'Sn': 118.71, 'Sb': 121.76, 'Te': 127.6, + 'I': 126.9045, 'Xe': 
131.293,
+    # Sixth Period
+    'Cs': 132.9055, 'Ba': 137.327, 'La': 138.9055, 'Ce': 140.116, 'Pr': 140.9077, 'Nd': 144.24, 'Pm': 145, 'Sm': 150.36,
+    'Eu': 151.964, 'Gd': 157.25, 'Tb': 158.9253, 'Dy': 162.5, 'Ho': 164.9303, 'Er': 167.259, 'Tm': 168.9342,
+    'Yb': 173.04, 'Lu': 174.967, 'Hf': 178.49, 'Ta': 180.9479, 'W': 183.84, 'Re': 186.207, 'Os': 190.23, 'Ir': 192.217,
+    'Pt': 195.078, 'Au': 196.9665, 'Hg': 200.59, 'Tl': 204.3833, 'Pb': 207.2, 'Bi': 208.9804, 'Po': 209, 'At': 210,
+    'Rn': 222,
+    # Seventh Period
+    'Fr': 223, 'Ra': 226, 'Ac': 227, 'Th': 232.0381, 'Pa': 231.0359, 'U': 238.0289, 'Np': 237, 'Pu': 244, 'Am': 243,
+    'Cm': 247, 'Bk': 247, 'Cf': 251, 'Es': 252, 'Fm': 257, 'Md': 258, 'No': 259, 'Lr': 262, 'Rf': 261, 'Db': 262,
+    'Sg': 266, 'Bh': 264, 'Hs': 277, 'Mt': 268}
+
+periodic_table_list = {
+    # First Period
+    'H': 1, 'He': 2,
+    # Second Period
+    'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10,
+    # Third Period
+    'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18,
+    # Fourth Period
+    'K': 19, 'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25, 'Fe': 26, 'Co': 27,
+    'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31, 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36,
+    # Fifth Period
+    'Rb': 37, 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43, 'Ru': 44, 'Rh': 45,
+    'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, 'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54,
+    # Sixth Period
+    'Cs': 55, 'Ba': 56, 'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61, 'Sm': 62, 'Eu': 63,
+    'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, 'Er': 68, 'Tm': 69, 'Yb': 70, 'Lu': 71, 'Hf': 72,
+    'Ta': 73, 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79, 'Hg': 80, 'Tl': 81,
+    'Pb': 82, 'Bi': 83, 'Po': 84, 'At': 85, 'Rn': 86,
+    # Seventh Period
+    'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91, 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95,
+    'Cm': 96, 'Bk': 97, 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103, 'Rf': 104,
+    'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108, 'Mt': 109}
+
+
+def sort_by_atomic_number(all_element_type):
+    current_order = []
+    for element in all_element_type:
+        assert element in periodic_table_list, 'element {} does not exist'.format(element)
+        current_order.append(periodic_table_list[element])
+
+    sorted_elements = [ELEMENT for ORDER, ELEMENT in sorted(zip(current_order, all_element_type))]
+
+    return sorted_elements
+
+periodic_table_std = {value: key for key, value in periodic_table_list.items()}
+
+# see these websites for reference:
+# http://img.chem.ucl.ac.uk/sgp/LARGE/sgp.htm
+# https://en.wikipedia.org/wiki/Space_group
+space_group = {
+    1:[["x,y,z"],'P1',"Triclinic"],
+    2:[["x,y,z","-x,-y,-z"],'P-1',"Triclinic"],
+    4:[["x,y,z","-x,y+1/2,-z"],'P21',"Monoclinic"],
+    5:[["x,y,z","-x,y,-z","x+1/2,y+1/2,z","-x+1/2,y+1/2,-z"],'C2',"Monoclinic"],
+    7:[["x,y,z","x,-y,1/2+z"],"Pc","Monoclinic"],
+    9:[["x,y,z","x,-y,1/2+z","1/2+x,1/2+y,z","1/2+x,1/2-y,1/2+z"],"CC","Monoclinic"],
+    11:[["x,y,z","-x,y+1/2,-z","-x,-y,-z","x,1/2-y,z"],"P21/m","Monoclinic"],
+    # duplicate 12 and 13 keys removed: Python keeps only the last definition of a
+    # repeated dict key, so the entries kept below match the previous runtime behavior
+    12:[["x,y,z","x,-y,z","-x,y,-z","-x,-y,-z","1/2+x,1/2+y,z","1/2+x,1/2-y,z","1/2-x,1/2+y,-z","1/2-x,1/2-y,-z"],"C2/m","Monoclinic"],
+    13:[["x,y,z", "-x,y,-z+1/2", "-x,-y,-z", "x,-y,z+1/2"],'P2/c',"Monoclinic"],
+    14:[["x,y,z", "-x,y+1/2,-z+1/2", "-x,-y,-z", 
"x,-y+1/2,z+1/2"],'P21/C',"Monoclinic"], + 15:[["x,y,z","-x,y,1/2-z","1/2+x,1/2+y,z","1/2-x,1/2+y,1/2-z","-x,-y,-z","x,-y,1/2+z","1/2-x,1/2-y,-z","1/2+x,1/2-y,1/2+z"],"C2/C","Monoclinic"], + 18:[["x,y,z","1/2+x,1/2-y,-z","1/2-x,1/2+y,-z","-x,-y,z"],"P21212","Orthorhombic"], + 19:[["x,y,z","1/2+x,1/2-y,-z","-x,1/2+y,1/2-z","1/2-x,-y,1/2+z"],"P212121","Orthorhombic"], + 29:[["x,y,z","1/2-x,y,1/2+z","1/2+x,-y,z","-x,-y,1/2+z"],"PCA21","Orthorhombic"], + 33:[["x,y,z","1/2-x,1/2+y,1/2+z","1/2+x,1/2-y,z","-x,-y,1/2+z"],"PNA21","Orthorhombic"], + 43:[["x,y,z","1/4-x,1/4+y,1/4+z","1/4+x,1/4-y,1/4+z","-x,-y,z","x,y+1/2,z+1/2","1/4-x,3/4+y,3/4+z","1/4+x,3/4-y,3/4+z","-x,-y+1/2,z+1/2", + "x+1/2,y,z+1/2","3/4-x,1/4+y,3/4+z","3/4+x,1/4-y,3/4+z","-x+1/2,-y,z+1/2","x+1/2,y+1/2,z","3/4-x,3/4+y,1/4+z","3/4+x,3/4-y,1/4+z","-x+1/2,-y+1/2,z"], "Fdd2", "Orthorhombic"], + 56:[["x,y,z","1/2-x,y,1/2+z","x,1/2-y,1/2+z","1/2+x,1/2+y,-z","-x,-y,-z","1/2+x,-y,1/2-z","-x,1/2+y,1/2-z","1/2-x,1/2-y,z"], "Pccn", "Orthorhombic"], + 60:[["x,y,z","1/2-x,1/2+y,z","x,-y,1/2+z","1/2+x,1/2+y,1/2-z","-x,-y,-z","1/2+x,1/2-y,-z","-x,y,1/2-z","1/2-x,1/2-y,1/2+z"],"Pbcn","Orthorhombic"], + 61:[["x,y,z","1/2-x,1/2+y,z","x,1/2-y,1/2+z","1/2+x,y,1/2-z","-x,-y,-z","1/2+x,1/2-y,-z","-x,1/2+y,1/2-z","1/2-x,-y,1/2+z"],"PBCA","Orthorhombic"], + 62:[["x,y,z","x+1/2,-y+1/2,-z+1/2","-x,y+1/2,-z","-x+1/2,-y,z+1/2","-x,-y,-z","-x+1/2,y+1/2,z+1/2","x,-y+1/2,z","x+1/2,y,-z+1/2"],"Pnma","Orthorhombic"], + 77:[["x,y,z","-x,-y,z","-y,x,1/2+z","y,-x,1/2+z"],"P42","Tetragonal"], + 88:[["x,y,z","-x,-y,z","-y,1/2+x,1/4+z","y,1/2-x,1/4+z","-x,1/2-y,1/4-z","x,1/2+y,1/4-z","y,-x,-z","-y,x,-z", + "1/2+x,1/2+y,1/2+z","1/2-x,1/2-y,1/2+z","1/2-y,x,3/4+z","1/2+y,-x,3/4+z","1/2-x,-y,3/4-z","1/2+x,y,3/4-z","1/2+y,1/2-x,1/2-z","1/2-y,1/2+x,1/2-z"],"I41/a","Tetragonal"], + 96:[["x,y,z","-x,-y,1/2+z","1/2-y,1/2+x,3/4+z","1/2+y,1/2-x,1/4+z","1/2+x,1/2-y,1/4-z","1/2-x,1/2+y,3/4-z","-y,-x,1/2-z","y,x,-z"],"P43212","Tetragonal"], + 143:[["x,y,z","-y,x-y,z","-x+y,-x,z"],"P3","Hexagonal"], + 147:[["x,y,z","-y,x-y,z","-x+y,-x,z","-x,-y,-z","y,-x+y,-z","x-y,x,-z"],"P-3","Hexagonal"], + 148:[["x,y,z","z,x,y","y,z,x","-x,-y,-z","-z,-x,-y","-y,-z,-x"],"R-3","Trigonal"], + 169:[["x,y,z","-y,x-y,1/3+z","-x+y,-x,2/3+z","-x,-y,1/2+z","x-y,x,1/6+z","y,-x+y,5/6+z"],"P61","Hexagonal"] +} + +point_group = {"Triclinic":[["a","b","c"],["alpha","beta","gamma"]], + "Monoclinic":[["a","b","c"],[90,"beta",90]], + "Orthorhombic":[["a","b","c"],[90,90,90]], + "Tetragonal":[["a","a","c"],[90,90,90]], + "Trigonal":[["a","a","a"],["alpha","alpha","alpha"]], + "Hexagonal":[["a","a","c"],[90,90,120]], + "Cubic":[["a","a","a"],[90,90,90]]} + diff --git a/basic_function/conformer_search.py b/basic_function/conformer_search.py index 736a5969f8f9e1fe126fbc612c37e8525c95287b..145a44d85055b4bb2eb0e5ccb3bad7c6ff3453e5 100644 --- a/basic_function/conformer_search.py +++ b/basic_function/conformer_search.py @@ -1,59 +1,59 @@ -from rdkit import Chem -from rdkit.Chem import AllChem -import os - - -def generate_conformers(molecule, num_conformers=10, max_attempts=1000, rms_thresh=0.2): - """ - Generate molecular conformers. - - Parameters: - molecule (RDKit Mol object): The input molecule. - num_conformers (int): Number of conformers to generate. - max_attempts (int): Maximum number of attempts. - rms_thresh (float): RMSD threshold for considering conformers as duplicates. - - Returns: - list: A list of generated conformers. 
- """ - params = AllChem.ETKDG() - params.numThreads = 0 - params.maxAttempts = max_attempts - params.pruneRmsThresh = rms_thresh - conformer_ids = AllChem.EmbedMultipleConfs(molecule, numConfs=num_conformers, params=params) - results = AllChem.UFFOptimizeMolecule(molecule) - - return conformer_ids, results - -def conformer_search(smiles, out_path, num_conformers=1000, max_attempts=10000, rms_thresh=0.2): - - try: - os.makedirs("{}/conformers".format(out_path)) - except: - print("Warning, these is already an structures folder in this path, skip mkdir") - - mol = Chem.MolFromSmiles(smiles) - mol = Chem.AddHs(mol) # add H atoms - - # conformer generate - conformer_ids, results = generate_conformers(mol, num_conformers=num_conformers, max_attempts=max_attempts, rms_thresh=rms_thresh) - - # print info - for i, conf in enumerate(conformer_ids): - print(f'Conformer {i}:') - - xyz_file = [] - xyz_file.append("{}\n".format(mol.GetNumAtoms())) - xyz_file.append("conformer_{}\n".format(i)) - - for j in range(mol.GetNumAtoms()): - atom = mol.GetAtomWithIdx(j) - symbol = atom.GetSymbol() - pos = mol.GetConformer(conf).GetAtomPosition(j) - # print(f' Atom {j} ({symbol}): x={pos.x:.3f}, y={pos.y:.3f}, z={pos.z:.3f}') - - xyz_file.append("{:6} {:16.8f} {:16.8f} {:16.8f}\n".format(symbol, pos.x, pos.y, pos.z)) - target = open("{}/conformers/conformer_{}.xyz".format(out_path,i), 'w') - target.writelines(xyz_file) - target.close() - +from rdkit import Chem +from rdkit.Chem import AllChem +import os + + +def generate_conformers(molecule, num_conformers=10, max_attempts=1000, rms_thresh=0.2): + """ + Generate molecular conformers. + + Parameters: + molecule (RDKit Mol object): The input molecule. + num_conformers (int): Number of conformers to generate. + max_attempts (int): Maximum number of attempts. + rms_thresh (float): RMSD threshold for considering conformers as duplicates. + + Returns: + list: A list of generated conformers. 
+ """ + params = AllChem.ETKDG() + params.numThreads = 0 + params.maxAttempts = max_attempts + params.pruneRmsThresh = rms_thresh + conformer_ids = AllChem.EmbedMultipleConfs(molecule, numConfs=num_conformers, params=params) + results = AllChem.UFFOptimizeMolecule(molecule) + + return conformer_ids, results + +def conformer_search(smiles, out_path, num_conformers=1000, max_attempts=10000, rms_thresh=0.2): + + try: + os.makedirs("{}/conformers".format(out_path)) + except: + print("Warning, these is already an structures folder in this path, skip mkdir") + + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) # add H atoms + + # conformer generate + conformer_ids, results = generate_conformers(mol, num_conformers=num_conformers, max_attempts=max_attempts, rms_thresh=rms_thresh) + + # print info + for i, conf in enumerate(conformer_ids): + print(f'Conformer {i}:') + + xyz_file = [] + xyz_file.append("{}\n".format(mol.GetNumAtoms())) + xyz_file.append("conformer_{}\n".format(i)) + + for j in range(mol.GetNumAtoms()): + atom = mol.GetAtomWithIdx(j) + symbol = atom.GetSymbol() + pos = mol.GetConformer(conf).GetAtomPosition(j) + # print(f' Atom {j} ({symbol}): x={pos.x:.3f}, y={pos.y:.3f}, z={pos.z:.3f}') + + xyz_file.append("{:6} {:16.8f} {:16.8f} {:16.8f}\n".format(symbol, pos.x, pos.y, pos.z)) + target = open("{}/conformers/conformer_{}.xyz".format(out_path,i), 'w') + target.writelines(xyz_file) + target.close() + diff --git a/basic_function/data_classes.py b/basic_function/data_classes.py index 6facbb84e647721497fcb7498906a13f03c1b0f9..dd9cdbe45c9796484e3e968006722581187eb0ed 100644 --- a/basic_function/data_classes.py +++ b/basic_function/data_classes.py @@ -1,614 +1,614 @@ -""" -This module defines the core data structures for representing atomic structures: -Atom, Crystal, and Molecule. - -These classes store information about atomic coordinates, lattice parameters, -and other physical properties, providing a foundational toolkit for geometric -and structural analysis in materials science simulations. -""" -import copy -from typing import List, Tuple, Union, Any, Optional, Dict - -import numpy as np -import fractions -import re -from scipy.spatial.distance import cdist -from tqdm import tqdm - -from basic_function import unit_cell_parser -from basic_function import chemical_knowledge -from basic_function import operation - - -class Atom: - """ - Represents a single atom in a chemical structure. - - Attributes: - element (str): The chemical symbol of the atom (e.g., 'H', 'C', 'O'). - cart_xyz (np.ndarray): Cartesian coordinates [x, y, z] in Angstroms. - frac_xyz (np.ndarray): Fractional coordinates [u, v, w] with respect to a lattice. - atom_id (int): A unique identifier for the atom within a larger structure. - force (np.ndarray): Force vector [fx, fy, fz] acting on the atom. - atom_charge (float): Partial charge of the atom. - atom_energy (float): Site potential energy of the atom. - molecule (int): Identifier for the molecule this atom belongs to. - bonded_atom (list): A list of IDs of atoms bonded to this one. - descriptor (any): A placeholder for feature vectors or other descriptors. - comment (dict): A dictionary for storing arbitrary metadata. - """ - - def __init__(self, **kwargs: Any): - """ - Initializes an Atom object. - - Args: - **kwargs: Keyword arguments to set atom attributes. - Required: 'element' and one of 'cart_xyz' or 'frac_xyz'. - Optional: 'atom_id', 'force_xyz', 'atom_charge', 'atom_energy', etc. 
- """ - self.element: str = kwargs.get("element", "unknown") - self.cart_xyz: Union[str, np.ndarray] = kwargs.get('cart_xyz', "unknown") - self.frac_xyz: Union[str, np.ndarray] = kwargs.get('frac_xyz', "unknown") - self.atom_id: Union[str, int] = kwargs.get("atom_id", "unknown") - self.force: Union[str, np.ndarray] = kwargs.get('force_xyz', 'unknown') - self.atom_charge: Union[str, float] = kwargs.get('atom_charge', 'unknown') - self.atom_energy: Union[str, float] = kwargs.get('atom_energy', 'unknown') - self.molecule: Union[str, int] = kwargs.get('molecule', 'unknown') - self.bonded_atom: list = kwargs.get('bonded_atom', []) - self.descriptor: Any = kwargs.get("descriptor", "unknown") - self.comment: dict = kwargs.get("comment", {}) - - def info(self) -> None: - """Prints all attributes of the atom to the console.""" - for key, value in self.__dict__.items(): - print(f"{key}: {value}") - - def check(self) -> None: - """ - Performs basic sanity checks on the atom's attributes. - - Raises: - AssertionError: If element is not defined, or if neither cartesian - nor fractional coordinates are provided. - """ - assert self.element != "unknown", "Atom must have an element type." - has_cart = isinstance(self.cart_xyz, (np.ndarray, list)) - has_frac = isinstance(self.frac_xyz, (np.ndarray, list)) - assert has_cart or has_frac, "Atom needs either cart_xyz or frac_xyz." - - -class Crystal: - """ - Represents a periodic crystal structure. - - Contains lattice information (cell parameters or vectors) and a list of atoms - that constitute the structure within the unit cell. - """ - - def __init__(self, **kwargs: Any): - """ - Initializes a Crystal object. - - The constructor requires lattice and atom information. It will automatically - calculate derived properties like volume and density. - - Args: - **kwargs: Keyword arguments to set crystal attributes. - Required: 'atoms' and one of 'cell_vect' or 'cell_para'. - - cell_vect (list): 3x3 list or array of lattice vectors. - - cell_para (list): [[a, b, c], [alpha, beta, gamma]]. - - atoms (List[Atom]): A list of Atom objects. - Optional: 'energy', 'comment', 'system_name', 'space_group', 'SYMM', etc. - """ - self.cell_vect: Union[str, np.ndarray] = kwargs.get("cell_vect", "unknown") - self.cell_para: Union[str, list] = kwargs.get("cell_para", "unknown") - self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") - self.energy: Union[str, float] = kwargs.get("energy", "unknown") - self.comment: Any = kwargs.get("comment", "unknown") - self.descriptor: Any = kwargs.get("descriptor", "unknown") - self.molecule_number: Union[str, int] = kwargs.get("molecule_number", "unknown") - self.system_name: str = kwargs.get("system_name", "unknown") - self.virial: Any = kwargs.get("virial", "unknown") - self.SYMM: list = kwargs.get("SYMM", ["x,y,z"]) - self.space_group: int = kwargs.get("space_group", 1) - self.other_properties: dict = {} - - # This method completes the initialization. - self.lattice_and_atom_complete() - - def lattice_and_atom_complete(self) -> None: - """ - Completes the initialization by ensuring consistency between cell representations, - atom coordinates, and calculating derived properties. - """ - # --- 1. 
Finalize Lattice Representation --- - has_vect = isinstance(self.cell_vect, (np.ndarray, list)) - has_para = isinstance(self.cell_para, (np.ndarray, list)) - - if has_vect and has_para: - # If both are provided, check for consistency - derived_para = np.array(unit_cell_parser.cell_vect_to_para(self.cell_vect)).flatten() - provided_para = np.array(self.cell_para).flatten() - assert np.allclose(derived_para, provided_para, atol=1e-3), \ - "Provided cell_para and cell_vect are inconsistent." - elif has_vect and not has_para: - self.cell_para = unit_cell_parser.cell_vect_to_para(self.cell_vect) - elif not has_vect and has_para: - self.cell_vect = unit_cell_parser.cell_para_to_vect(self.cell_para) - else: - raise ValueError("Crystal lattice is not defined. Provide 'cell_vect' or 'cell_para'.") - - # --- 2. Finalize Atom Coordinates --- - if self.atoms == "unknown" or not self.atoms: - print("Warning: Crystal initialized with no atoms.") - self.atoms = [] - else: - for atom in self.atoms: - has_cart = isinstance(atom.cart_xyz, (np.ndarray, list)) - has_frac = isinstance(atom.frac_xyz, (np.ndarray, list)) - if not (has_cart and has_frac): - if has_frac: - atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - elif has_cart: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) - else: - raise ValueError(f"Atom {atom.element} {atom.atom_id} has no coordinate information.") - - # --- 3. Calculate Derived Properties --- - self.volume = unit_cell_parser.calculate_volume(self.cell_para) - if self.atoms: - total_mass = sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) - # Density in g/cm^3 - self.density = total_mass / (self.volume * 1e-24) / (6.022140857e23) - else: - self.density = 0.0 - - def update_cart_by_frac(self) -> None: - """Updates all atom cartesian coordinates from their fractional coordinates.""" - for atom in self.atoms: - atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - - def update_frac_by_cart(self) -> None: - """Updates all atom fractional coordinates from their cartesian coordinates.""" - for atom in self.atoms: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) - - def check(self) -> None: - """Performs consistency checks on the crystal structure.""" - print("Performing consistency checks...") - # Check lattice consistency - self.lattice_and_atom_complete() - - # Check atom coordinate consistency - for atom in self.atoms: - derived_cart = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - assert np.allclose(atom.cart_xyz, derived_cart, atol=1e-3), \ - f"Atom {atom.atom_id} cartesian and fractional coordinates do not match." - - # Check atom IDs - if all(atom.atom_id != "unknown" for atom in self.atoms): - print("All atoms have IDs.") - else: - print("Warning: Not all atoms have IDs. Use .give_atom_id_forced() to assign them.") - print("Checks passed.") - - def give_atom_id_forced(self) -> None: - """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" - print("Warning: Resetting all atom IDs and bonding information!") - for i, atom in enumerate(self.atoms): - atom.atom_id = i - atom.bonded_atom = [] - - def move_atom_into_cell(self) -> None: - """ - Moves all atoms into the primary unit cell [0, 1) in fractional coordinates. 
- """ - for atom in self.atoms: - # Use modulo for a more direct and efficient way to wrap coordinates - atom.frac_xyz = np.mod(atom.frac_xyz, 1.0) - self.update_cart_by_frac() - - def find_molecule(self, tolerance: float = 1.15) -> None: - """ - Identifies molecules within the crystal based on bonding distances. - - This method performs a graph search (BFS) on the atoms, connecting them - based on scaled covalent radii. It populates the `atom.molecule` and - `self.molecule_number` attributes. - - Args: - tolerance: A scaling factor for covalent radii to determine bonding. - A bond is formed if dist(A, B) < (radius(A) + radius(B)) * tolerance. - """ - self.move_atom_into_cell() - atoms_to_visit = list(range(len(self.atoms))) - molecule_id = 0 - - while atoms_to_visit: - molecule_id += 1 - # Start a Breadth-First Search (BFS) from the first unvisited atom - q = [atoms_to_visit[0]] - visited_in_molecule = {atoms_to_visit[0]} - - head = 0 - while head < len(q): - current_atom_idx = q[head] - head += 1 - self.atoms[current_atom_idx].molecule = molecule_id - - # Check for bonds with all other atoms - for other_atom_idx in range(len(self.atoms)): - if current_atom_idx == other_atom_idx: - continue - - # is_bonding_crystal handles periodic boundaries - is_bonded, _ = operation.is_bonding_crystal( - self.atoms[current_atom_idx], - self.atoms[other_atom_idx], - self.cell_vect, - tolerance=tolerance, - update_atom2=False # Do not modify coordinates during search - ) - - if is_bonded and other_atom_idx not in visited_in_molecule: - visited_in_molecule.add(other_atom_idx) - q.append(other_atom_idx) - - # Remove all atoms found in the new molecule from the list to visit - atoms_to_visit = [idx for idx in atoms_to_visit if idx not in visited_in_molecule] - - self.molecule_number = molecule_id - - def get_element(self) -> List[str]: - """Returns a sorted list of unique element symbols in the crystal.""" - return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) - - def get_element_amount(self) -> List[int]: - """Returns the count of each element, sorted by atomic number.""" - all_elements = [atom.element for atom in self.atoms] - return [all_elements.count(element) for element in self.get_element()] - - - def make_p1(self) -> None: - """ - Expands the asymmetric unit to the full P1 cell using symmetry operations. - - The crystal's space group is set to 1 (P1) and SYMM is reset. This - implementation is robustly designed to ensure the final coordinate array - is always 2-dimensional, preventing downstream errors. 
- """ - all_ele, all_frac = self.get_ele_and_frac() - all_reflect_position = [] - all_matrix_M = [] - all_matrix_C = [] - for sym_opt in self.SYMM: - sym_opt_ele = sym_opt.lower().replace(" ", "").split(",") - # assert len(sym_opt_ele) == 3, "sym {} could not be treat".format(sym_opt_ele) - matrix_M = np.zeros((3, 3)) - matrix_C = np.zeros((1, 3)) - for idx, word in enumerate(sym_opt_ele): - sym_opt_ele_split = re.findall(r".*?([+-]*[xyz0-9\/\.]+)", word) - for sym_opt_frag in sym_opt_ele_split: - if sym_opt_frag == 'x' or sym_opt_frag == '+x': - matrix_M[0][idx] = 1 - elif str(sym_opt_frag) == '-x': - matrix_M[0][idx] = -1 - elif sym_opt_frag == 'y' or sym_opt_frag == '+y': - matrix_M[1][idx] = 1 - elif sym_opt_frag == '-y': - matrix_M[1][idx] = -1 - elif sym_opt_frag == 'z' or sym_opt_frag == '+z': - matrix_M[2][idx] = 1 - elif sym_opt_frag == '-z': - matrix_M[2][idx] = -1 - elif operation.is_number(sym_opt_frag) is True: - matrix_C[0][idx] = float(fractions.Fraction(sym_opt_frag)) - else: - raise Exception("wrong sym opt of" + sym_opt_frag) - - all_matrix_M.append(matrix_M) - all_matrix_C.append(matrix_C) - - for j in range(0, len(all_matrix_M)): - new_positions = np.dot(np.array([all_frac]), all_matrix_M[j]) + all_matrix_C[j] - all_reflect_position.append(new_positions.squeeze()) - all_ele = all_ele*len(self.SYMM) - - new_atoms = [] - idx=0 - for element, frac_xyz in zip(all_ele, np.array(all_reflect_position).reshape(-1,3)): - new_atoms.append(Atom(element=element, - frac_xyz=frac_xyz, - atom_id=idx)) - idx+=1 - - self.SYMM = "[x,y,z]" - self.space_group = 1 - self.atoms = new_atoms - self.update_cart_by_frac() - - def sort_by_element(self) -> None: - """Sorts the atoms list based on atomic number.""" - self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) - - def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their cartesian coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_carts = np.array([atom.cart_xyz for atom in self.atoms]) - return all_ele, all_carts - - def get_ele_and_frac(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their fractional coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_fracs = np.array([atom.frac_xyz for atom in self.atoms]) - return all_ele, all_fracs - - def info(self, all_info: bool = False) -> None: - """ - Prints a formatted summary of the crystal structure. - - Args: - all_info: If True, prints an extended table including fractional - coordinates, forces, and other properties. 
- """ - print("--- Crystal System ---") - print(f"Name: {self.system_name}") - print("Lattice Vectors (Angstrom):") - for vec in self.cell_vect: - print(f"{vec[0]:16.8f} {vec[1]:16.8f} {vec[2]:16.8f}") - print("Lattice Parameters:") - print(f"a, b, c (A): {self.cell_para[0][0]:.4f}, {self.cell_para[0][1]:.4f}, {self.cell_para[0][2]:.4f}") - print(f"alpha, beta, gamma (deg): {self.cell_para[1][0]:.4f}, {self.cell_para[1][1]:.4f}, {self.cell_para[1][2]:.4f}") - print(f"Volume (A^3): {self.volume:.4f} | Density (g/cm^3): {self.density:.4f}") - print(f"\n--- Atomic Coordinates (Total: {len(self.atoms)}) ---") - - if not all_info: - print(f"{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") - print("-" * 58) - for atom in self.atoms: - print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") - else: - header = ( - f"{'ID':<5} {'Elem':<6} " - f"{'Frac X':>10} {'Frac Y':>10} {'Frac Z':>10} | " - f"{'Cart X':>12} {'Cart Y':>12} {'Cart Z':>12}" - ) - print(header) - print("-" * len(header)) - for atom in self.atoms: - aid = str(atom.atom_id) if atom.atom_id != 'unknown' else '-' - print( - f"{aid:<5} {atom.element:<6} " - f"{atom.frac_xyz[0]:10.6f} {atom.frac_xyz[1]:10.6f} {atom.frac_xyz[2]:10.6f} | " - f"{atom.cart_xyz[0]:12.6f} {atom.cart_xyz[1]:12.6f} {atom.cart_xyz[2]:12.6f}" - ) - - print("\n--- Other Properties ---") - print(f"Energy: {self.energy}") - print(f"Comment: {self.comment}") - print(f"Virial: {self.virial}") - - -class Molecule: - """Represents a non-periodic molecule (a collection of atoms).""" - - def __init__(self, **kwargs: Any): - """ - Initializes a Molecule object. - - Args: - **kwargs: Keyword arguments to set molecule attributes. - Required: 'atoms' (List[Atom]). - Optional: 'energy', 'comment', 'name', 'system_name'. 
- """ - self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") - self.energy: Union[str, float] = kwargs.get("energy", "unknown") - self.comment: Any = kwargs.get("comment", "unknown") - self.descriptor: Any = kwargs.get("descriptor", "unknown") - self.name: str = kwargs.get("name", "unknown") - self.system_name: str = kwargs.get("system_name", "unknown") - - if self.atoms == "unknown": - print("Warning: Molecule initialized with no atoms.") - self.atoms = [] - - def give_atom_id_forced(self) -> None: - """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" - print("Warning: Resetting all atom IDs and bonding information!") - for i, atom in enumerate(self.atoms): - atom.atom_id = i - atom.bonded_atom = [] - - def get_element(self) -> List[str]: - """Returns a sorted list of unique element symbols in the molecule.""" - if not self.atoms: return [] - return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) - - def get_element_amount(self) -> List[int]: - """Returns the count of each element, sorted by atomic number.""" - if not self.atoms: return [] - all_elements = [atom.element for atom in self.atoms] - return [all_elements.count(element) for element in self.get_element()] - - def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their cartesian coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_carts = np.array([atom.cart_xyz for atom in self.atoms]) - return all_ele, all_carts - - def put_ele_cart_back(self, all_ele: List[str], all_carts: np.ndarray) -> None: - """Updates the molecule's atoms from lists of elements and coordinates.""" - for i, atom in enumerate(self.atoms): - atom.element = all_ele[i] - atom.cart_xyz = all_carts[i] - - def build_molecules_by_ele_cart(self, all_ele: List[str], all_carts: np.ndarray) -> None: - """Rebuilds the molecule's atoms list from elements and coordinates.""" - assert len(all_ele) == len(all_carts), "Element and coordinate lists must have the same length." 
- self.atoms = [ - Atom(element=ele, cart_xyz=cart, atom_id=i) - for i, (ele, cart) in enumerate(zip(all_ele, all_carts)) - ] - - def get_mass(self) -> float: - """Calculates the total mass of the molecule.""" - if not self.atoms: return 0.0 - return sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) - - def get_center_of_mass(self) -> np.ndarray: - """Calculates the center of mass of the molecule.""" - if not self.atoms: return np.zeros(3) - - all_ele, all_carts = self.get_ele_and_cart() - masses = np.array([chemical_knowledge.element_masses[x] for x in all_ele]) - total_mass = np.sum(masses) - - if total_mass == 0: return np.zeros(3) - return np.sum(all_carts * masses[:, np.newaxis], axis=0) / total_mass - - def sort_by_element(self) -> None: - """Sorts the atoms list based on atomic number.""" - self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) - - def sort_by_id(self) -> None: - """Sorts the atoms list based on their atom_id.""" - self.atoms.sort(key=lambda atom: atom.atom_id) - - def info(self) -> None: - """Prints a formatted summary of the molecule.""" - print(f"--- Molecule ---") - print(f"Name: {self.name} | System: {self.system_name}") - print(f"Number of atoms: {len(self.atoms)}") - print(f"Total Mass (amu): {self.get_mass():.4f}") - print(f"Energy: {self.energy}") - print(f"Comment: {self.comment}") - print(f"\n{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") - print("-" * 58) - if self.atoms: - for atom in self.atoms: - print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") - - def find_fragment(self, tolerance: float = 1.15) -> Dict[int, List[int]]: - """ - Identifies covalently bonded fragments within the molecule. - - This is useful for molecules that are actually composed of several - disconnected components (e.g., salts, solvent shells). - - Args: - tolerance: Scaling factor for covalent radii to determine bonding. - - Returns: - A dictionary mapping a fragment ID (starting from 1) to a list of - atom indices belonging to that fragment. 
- """ - if not self.atoms: return {} - - num_atoms = len(self.atoms) - cart_matrix = np.array([atom.cart_xyz for atom in self.atoms]) - radii = np.array([chemical_knowledge.element_covalent_radii[atom.element] for atom in self.atoms]) - - # Create a matrix of bond thresholds (r_i + r_j) - bond_threshold_matrix = (radii[:, np.newaxis] + radii) * tolerance - - # True where distance is less than the bond threshold - dist_matrix = cdist(cart_matrix, cart_matrix) - adj_matrix = dist_matrix < bond_threshold_matrix - np.fill_diagonal(adj_matrix, False) - - # Graph traversal (DFS) to find connected components - visited = [False] * num_atoms - groups = {} - group_id = 0 - for i in range(num_atoms): - if not visited[i]: - group_id += 1 - groups[group_id] = [] - stack = [i] - while stack: - atom_idx = stack.pop() - if not visited[atom_idx]: - visited[atom_idx] = True - groups[group_id].append(atom_idx) - # Find neighbors and add to stack - neighbors = np.where(adj_matrix[atom_idx])[0] - stack.extend(neighbors) - return groups - - def give_molecule_id(self, tolerance: float = 1.15) -> None: - """Assigns a molecule ID to each atom based on fragment analysis.""" - fragments = self.find_fragment(tolerance=tolerance) - for group_id, atom_indices in fragments.items(): - for atom_idx in atom_indices: - self.atoms[atom_idx].molecule = group_id - - def take_out_fragment(self, tolerance: float = 1.15) -> List['Molecule']: - """ - Splits the current molecule into a list of new Molecule objects, - one for each disconnected fragment. - """ - if not self.atoms: return [] - - self.give_atom_id_forced() # Ensure IDs are set for lookup - fragments = self.find_fragment(tolerance=tolerance) - new_molecules = [] - - for i, atom_indices in fragments.items(): - fragment_atoms = [self.atoms[j] for j in atom_indices] - new_mol = Molecule( - atoms=copy.deepcopy(fragment_atoms), - name=f"{self.name}_frag{i}", - system_name=f"{self.system_name}_frag{i}" - ) - new_molecules.append(new_mol) - return new_molecules - - def calculate_frac_xyz_by_cell_para(self, cell_para: list) -> None: - """Calculates fractional coordinates for all atoms given cell parameters.""" - for atom in self.atoms: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_para(atom.cart_xyz, cell_para) - - def molecule_volume(self, num_samples: int = 100000) -> float: - """ - Calculates the van der Waals volume using a Monte Carlo integration method. - - This method samples points in a bounding box around the molecule and - determines the ratio of points that fall within any atom's vdW sphere. - - Args: - num_samples: The number of random points to sample. More points - yield a more accurate volume at the cost of performance. - - Returns: - The estimated van der Waals volume in cubic Angstroms. 
- """ - if not self.atoms: return 0.0 - - elements, coords = self.get_ele_and_cart() - radii = np.array([chemical_knowledge.element_vdw_radii[el] for el in elements]) - - # Determine bounding box for sampling - min_bounds = np.min(coords, axis=0) - np.max(radii) - max_bounds = np.max(coords, axis=0) + np.max(radii) - bounding_box_volume = np.prod(max_bounds - min_bounds) - - # Generate random sample points within the bounding box - random_points = np.random.uniform(min_bounds, max_bounds, (num_samples, 3)) - - # Check for each point if it's inside ANY sphere - count_inside = 0 - for rp in tqdm(random_points, desc="Monte Carlo Volume", leave=False): - # Calculate squared distances from the point to all atom centers - dist_sq = np.sum((coords - rp)**2, axis=1) - # If any distance is within the radius, the point is inside - if np.any(dist_sq <= radii**2): - count_inside += 1 - +""" +This module defines the core data structures for representing atomic structures: +Atom, Crystal, and Molecule. + +These classes store information about atomic coordinates, lattice parameters, +and other physical properties, providing a foundational toolkit for geometric +and structural analysis in materials science simulations. +""" +import copy +from typing import List, Tuple, Union, Any, Optional, Dict + +import numpy as np +import fractions +import re +from scipy.spatial.distance import cdist +from tqdm import tqdm + +from basic_function import unit_cell_parser +from basic_function import chemical_knowledge +from basic_function import operation + + +class Atom: + """ + Represents a single atom in a chemical structure. + + Attributes: + element (str): The chemical symbol of the atom (e.g., 'H', 'C', 'O'). + cart_xyz (np.ndarray): Cartesian coordinates [x, y, z] in Angstroms. + frac_xyz (np.ndarray): Fractional coordinates [u, v, w] with respect to a lattice. + atom_id (int): A unique identifier for the atom within a larger structure. + force (np.ndarray): Force vector [fx, fy, fz] acting on the atom. + atom_charge (float): Partial charge of the atom. + atom_energy (float): Site potential energy of the atom. + molecule (int): Identifier for the molecule this atom belongs to. + bonded_atom (list): A list of IDs of atoms bonded to this one. + descriptor (any): A placeholder for feature vectors or other descriptors. + comment (dict): A dictionary for storing arbitrary metadata. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes an Atom object. + + Args: + **kwargs: Keyword arguments to set atom attributes. + Required: 'element' and one of 'cart_xyz' or 'frac_xyz'. + Optional: 'atom_id', 'force_xyz', 'atom_charge', 'atom_energy', etc. 
+ """ + self.element: str = kwargs.get("element", "unknown") + self.cart_xyz: Union[str, np.ndarray] = kwargs.get('cart_xyz', "unknown") + self.frac_xyz: Union[str, np.ndarray] = kwargs.get('frac_xyz', "unknown") + self.atom_id: Union[str, int] = kwargs.get("atom_id", "unknown") + self.force: Union[str, np.ndarray] = kwargs.get('force_xyz', 'unknown') + self.atom_charge: Union[str, float] = kwargs.get('atom_charge', 'unknown') + self.atom_energy: Union[str, float] = kwargs.get('atom_energy', 'unknown') + self.molecule: Union[str, int] = kwargs.get('molecule', 'unknown') + self.bonded_atom: list = kwargs.get('bonded_atom', []) + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.comment: dict = kwargs.get("comment", {}) + + def info(self) -> None: + """Prints all attributes of the atom to the console.""" + for key, value in self.__dict__.items(): + print(f"{key}: {value}") + + def check(self) -> None: + """ + Performs basic sanity checks on the atom's attributes. + + Raises: + AssertionError: If element is not defined, or if neither cartesian + nor fractional coordinates are provided. + """ + assert self.element != "unknown", "Atom must have an element type." + has_cart = isinstance(self.cart_xyz, (np.ndarray, list)) + has_frac = isinstance(self.frac_xyz, (np.ndarray, list)) + assert has_cart or has_frac, "Atom needs either cart_xyz or frac_xyz." + + +class Crystal: + """ + Represents a periodic crystal structure. + + Contains lattice information (cell parameters or vectors) and a list of atoms + that constitute the structure within the unit cell. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes a Crystal object. + + The constructor requires lattice and atom information. It will automatically + calculate derived properties like volume and density. + + Args: + **kwargs: Keyword arguments to set crystal attributes. + Required: 'atoms' and one of 'cell_vect' or 'cell_para'. + - cell_vect (list): 3x3 list or array of lattice vectors. + - cell_para (list): [[a, b, c], [alpha, beta, gamma]]. + - atoms (List[Atom]): A list of Atom objects. + Optional: 'energy', 'comment', 'system_name', 'space_group', 'SYMM', etc. + """ + self.cell_vect: Union[str, np.ndarray] = kwargs.get("cell_vect", "unknown") + self.cell_para: Union[str, list] = kwargs.get("cell_para", "unknown") + self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") + self.energy: Union[str, float] = kwargs.get("energy", "unknown") + self.comment: Any = kwargs.get("comment", "unknown") + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.molecule_number: Union[str, int] = kwargs.get("molecule_number", "unknown") + self.system_name: str = kwargs.get("system_name", "unknown") + self.virial: Any = kwargs.get("virial", "unknown") + self.SYMM: list = kwargs.get("SYMM", ["x,y,z"]) + self.space_group: int = kwargs.get("space_group", 1) + self.other_properties: dict = {} + + # This method completes the initialization. + self.lattice_and_atom_complete() + + def lattice_and_atom_complete(self) -> None: + """ + Completes the initialization by ensuring consistency between cell representations, + atom coordinates, and calculating derived properties. + """ + # --- 1. 
Finalize Lattice Representation --- + has_vect = isinstance(self.cell_vect, (np.ndarray, list)) + has_para = isinstance(self.cell_para, (np.ndarray, list)) + + if has_vect and has_para: + # If both are provided, check for consistency + derived_para = np.array(unit_cell_parser.cell_vect_to_para(self.cell_vect)).flatten() + provided_para = np.array(self.cell_para).flatten() + assert np.allclose(derived_para, provided_para, atol=1e-3), \ + "Provided cell_para and cell_vect are inconsistent." + elif has_vect and not has_para: + self.cell_para = unit_cell_parser.cell_vect_to_para(self.cell_vect) + elif not has_vect and has_para: + self.cell_vect = unit_cell_parser.cell_para_to_vect(self.cell_para) + else: + raise ValueError("Crystal lattice is not defined. Provide 'cell_vect' or 'cell_para'.") + + # --- 2. Finalize Atom Coordinates --- + if self.atoms == "unknown" or not self.atoms: + print("Warning: Crystal initialized with no atoms.") + self.atoms = [] + else: + for atom in self.atoms: + has_cart = isinstance(atom.cart_xyz, (np.ndarray, list)) + has_frac = isinstance(atom.frac_xyz, (np.ndarray, list)) + if not (has_cart and has_frac): + if has_frac: + atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + elif has_cart: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) + else: + raise ValueError(f"Atom {atom.element} {atom.atom_id} has no coordinate information.") + + # --- 3. Calculate Derived Properties --- + self.volume = unit_cell_parser.calculate_volume(self.cell_para) + if self.atoms: + total_mass = sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) + # Density in g/cm^3 + self.density = total_mass / (self.volume * 1e-24) / (6.022140857e23) + else: + self.density = 0.0 + + def update_cart_by_frac(self) -> None: + """Updates all atom cartesian coordinates from their fractional coordinates.""" + for atom in self.atoms: + atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + + def update_frac_by_cart(self) -> None: + """Updates all atom fractional coordinates from their cartesian coordinates.""" + for atom in self.atoms: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) + + def check(self) -> None: + """Performs consistency checks on the crystal structure.""" + print("Performing consistency checks...") + # Check lattice consistency + self.lattice_and_atom_complete() + + # Check atom coordinate consistency + for atom in self.atoms: + derived_cart = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + assert np.allclose(atom.cart_xyz, derived_cart, atol=1e-3), \ + f"Atom {atom.atom_id} cartesian and fractional coordinates do not match." + + # Check atom IDs + if all(atom.atom_id != "unknown" for atom in self.atoms): + print("All atoms have IDs.") + else: + print("Warning: Not all atoms have IDs. Use .give_atom_id_forced() to assign them.") + print("Checks passed.") + + def give_atom_id_forced(self) -> None: + """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" + print("Warning: Resetting all atom IDs and bonding information!") + for i, atom in enumerate(self.atoms): + atom.atom_id = i + atom.bonded_atom = [] + + def move_atom_into_cell(self) -> None: + """ + Moves all atoms into the primary unit cell [0, 1) in fractional coordinates. 
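+
+        Example (illustrative): an atom at frac_xyz [1.25, -0.10, 0.50] is
+        wrapped to [0.25, 0.90, 0.50], since coordinates are mapped with
+        np.mod(x, 1.0).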
+ """ + for atom in self.atoms: + # Use modulo for a more direct and efficient way to wrap coordinates + atom.frac_xyz = np.mod(atom.frac_xyz, 1.0) + self.update_cart_by_frac() + + def find_molecule(self, tolerance: float = 1.15) -> None: + """ + Identifies molecules within the crystal based on bonding distances. + + This method performs a graph search (BFS) on the atoms, connecting them + based on scaled covalent radii. It populates the `atom.molecule` and + `self.molecule_number` attributes. + + Args: + tolerance: A scaling factor for covalent radii to determine bonding. + A bond is formed if dist(A, B) < (radius(A) + radius(B)) * tolerance. + """ + self.move_atom_into_cell() + atoms_to_visit = list(range(len(self.atoms))) + molecule_id = 0 + + while atoms_to_visit: + molecule_id += 1 + # Start a Breadth-First Search (BFS) from the first unvisited atom + q = [atoms_to_visit[0]] + visited_in_molecule = {atoms_to_visit[0]} + + head = 0 + while head < len(q): + current_atom_idx = q[head] + head += 1 + self.atoms[current_atom_idx].molecule = molecule_id + + # Check for bonds with all other atoms + for other_atom_idx in range(len(self.atoms)): + if current_atom_idx == other_atom_idx: + continue + + # is_bonding_crystal handles periodic boundaries + is_bonded, _ = operation.is_bonding_crystal( + self.atoms[current_atom_idx], + self.atoms[other_atom_idx], + self.cell_vect, + tolerance=tolerance, + update_atom2=False # Do not modify coordinates during search + ) + + if is_bonded and other_atom_idx not in visited_in_molecule: + visited_in_molecule.add(other_atom_idx) + q.append(other_atom_idx) + + # Remove all atoms found in the new molecule from the list to visit + atoms_to_visit = [idx for idx in atoms_to_visit if idx not in visited_in_molecule] + + self.molecule_number = molecule_id + + def get_element(self) -> List[str]: + """Returns a sorted list of unique element symbols in the crystal.""" + return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) + + def get_element_amount(self) -> List[int]: + """Returns the count of each element, sorted by atomic number.""" + all_elements = [atom.element for atom in self.atoms] + return [all_elements.count(element) for element in self.get_element()] + + + def make_p1(self) -> None: + """ + Expands the asymmetric unit to the full P1 cell using symmetry operations. + + The crystal's space group is set to 1 (P1) and SYMM is reset. This + implementation is robustly designed to ensure the final coordinate array + is always 2-dimensional, preventing downstream errors. 
+ """ + all_ele, all_frac = self.get_ele_and_frac() + all_reflect_position = [] + all_matrix_M = [] + all_matrix_C = [] + for sym_opt in self.SYMM: + sym_opt_ele = sym_opt.lower().replace(" ", "").split(",") + # assert len(sym_opt_ele) == 3, "sym {} could not be treat".format(sym_opt_ele) + matrix_M = np.zeros((3, 3)) + matrix_C = np.zeros((1, 3)) + for idx, word in enumerate(sym_opt_ele): + sym_opt_ele_split = re.findall(r".*?([+-]*[xyz0-9\/\.]+)", word) + for sym_opt_frag in sym_opt_ele_split: + if sym_opt_frag == 'x' or sym_opt_frag == '+x': + matrix_M[0][idx] = 1 + elif str(sym_opt_frag) == '-x': + matrix_M[0][idx] = -1 + elif sym_opt_frag == 'y' or sym_opt_frag == '+y': + matrix_M[1][idx] = 1 + elif sym_opt_frag == '-y': + matrix_M[1][idx] = -1 + elif sym_opt_frag == 'z' or sym_opt_frag == '+z': + matrix_M[2][idx] = 1 + elif sym_opt_frag == '-z': + matrix_M[2][idx] = -1 + elif operation.is_number(sym_opt_frag) is True: + matrix_C[0][idx] = float(fractions.Fraction(sym_opt_frag)) + else: + raise Exception("wrong sym opt of" + sym_opt_frag) + + all_matrix_M.append(matrix_M) + all_matrix_C.append(matrix_C) + + for j in range(0, len(all_matrix_M)): + new_positions = np.dot(np.array([all_frac]), all_matrix_M[j]) + all_matrix_C[j] + all_reflect_position.append(new_positions.squeeze()) + all_ele = all_ele*len(self.SYMM) + + new_atoms = [] + idx=0 + for element, frac_xyz in zip(all_ele, np.array(all_reflect_position).reshape(-1,3)): + new_atoms.append(Atom(element=element, + frac_xyz=frac_xyz, + atom_id=idx)) + idx+=1 + + self.SYMM = "[x,y,z]" + self.space_group = 1 + self.atoms = new_atoms + self.update_cart_by_frac() + + def sort_by_element(self) -> None: + """Sorts the atoms list based on atomic number.""" + self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) + + def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their cartesian coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_carts = np.array([atom.cart_xyz for atom in self.atoms]) + return all_ele, all_carts + + def get_ele_and_frac(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their fractional coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_fracs = np.array([atom.frac_xyz for atom in self.atoms]) + return all_ele, all_fracs + + def info(self, all_info: bool = False) -> None: + """ + Prints a formatted summary of the crystal structure. + + Args: + all_info: If True, prints an extended table including fractional + coordinates, forces, and other properties. 
+ """ + print("--- Crystal System ---") + print(f"Name: {self.system_name}") + print("Lattice Vectors (Angstrom):") + for vec in self.cell_vect: + print(f"{vec[0]:16.8f} {vec[1]:16.8f} {vec[2]:16.8f}") + print("Lattice Parameters:") + print(f"a, b, c (A): {self.cell_para[0][0]:.4f}, {self.cell_para[0][1]:.4f}, {self.cell_para[0][2]:.4f}") + print(f"alpha, beta, gamma (deg): {self.cell_para[1][0]:.4f}, {self.cell_para[1][1]:.4f}, {self.cell_para[1][2]:.4f}") + print(f"Volume (A^3): {self.volume:.4f} | Density (g/cm^3): {self.density:.4f}") + print(f"\n--- Atomic Coordinates (Total: {len(self.atoms)}) ---") + + if not all_info: + print(f"{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") + print("-" * 58) + for atom in self.atoms: + print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") + else: + header = ( + f"{'ID':<5} {'Elem':<6} " + f"{'Frac X':>10} {'Frac Y':>10} {'Frac Z':>10} | " + f"{'Cart X':>12} {'Cart Y':>12} {'Cart Z':>12}" + ) + print(header) + print("-" * len(header)) + for atom in self.atoms: + aid = str(atom.atom_id) if atom.atom_id != 'unknown' else '-' + print( + f"{aid:<5} {atom.element:<6} " + f"{atom.frac_xyz[0]:10.6f} {atom.frac_xyz[1]:10.6f} {atom.frac_xyz[2]:10.6f} | " + f"{atom.cart_xyz[0]:12.6f} {atom.cart_xyz[1]:12.6f} {atom.cart_xyz[2]:12.6f}" + ) + + print("\n--- Other Properties ---") + print(f"Energy: {self.energy}") + print(f"Comment: {self.comment}") + print(f"Virial: {self.virial}") + + +class Molecule: + """Represents a non-periodic molecule (a collection of atoms).""" + + def __init__(self, **kwargs: Any): + """ + Initializes a Molecule object. + + Args: + **kwargs: Keyword arguments to set molecule attributes. + Required: 'atoms' (List[Atom]). + Optional: 'energy', 'comment', 'name', 'system_name'. 
+ """ + self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") + self.energy: Union[str, float] = kwargs.get("energy", "unknown") + self.comment: Any = kwargs.get("comment", "unknown") + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.name: str = kwargs.get("name", "unknown") + self.system_name: str = kwargs.get("system_name", "unknown") + + if self.atoms == "unknown": + print("Warning: Molecule initialized with no atoms.") + self.atoms = [] + + def give_atom_id_forced(self) -> None: + """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" + print("Warning: Resetting all atom IDs and bonding information!") + for i, atom in enumerate(self.atoms): + atom.atom_id = i + atom.bonded_atom = [] + + def get_element(self) -> List[str]: + """Returns a sorted list of unique element symbols in the molecule.""" + if not self.atoms: return [] + return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) + + def get_element_amount(self) -> List[int]: + """Returns the count of each element, sorted by atomic number.""" + if not self.atoms: return [] + all_elements = [atom.element for atom in self.atoms] + return [all_elements.count(element) for element in self.get_element()] + + def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their cartesian coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_carts = np.array([atom.cart_xyz for atom in self.atoms]) + return all_ele, all_carts + + def put_ele_cart_back(self, all_ele: List[str], all_carts: np.ndarray) -> None: + """Updates the molecule's atoms from lists of elements and coordinates.""" + for i, atom in enumerate(self.atoms): + atom.element = all_ele[i] + atom.cart_xyz = all_carts[i] + + def build_molecules_by_ele_cart(self, all_ele: List[str], all_carts: np.ndarray) -> None: + """Rebuilds the molecule's atoms list from elements and coordinates.""" + assert len(all_ele) == len(all_carts), "Element and coordinate lists must have the same length." 
+ self.atoms = [ + Atom(element=ele, cart_xyz=cart, atom_id=i) + for i, (ele, cart) in enumerate(zip(all_ele, all_carts)) + ] + + def get_mass(self) -> float: + """Calculates the total mass of the molecule.""" + if not self.atoms: return 0.0 + return sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) + + def get_center_of_mass(self) -> np.ndarray: + """Calculates the center of mass of the molecule.""" + if not self.atoms: return np.zeros(3) + + all_ele, all_carts = self.get_ele_and_cart() + masses = np.array([chemical_knowledge.element_masses[x] for x in all_ele]) + total_mass = np.sum(masses) + + if total_mass == 0: return np.zeros(3) + return np.sum(all_carts * masses[:, np.newaxis], axis=0) / total_mass + + def sort_by_element(self) -> None: + """Sorts the atoms list based on atomic number.""" + self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) + + def sort_by_id(self) -> None: + """Sorts the atoms list based on their atom_id.""" + self.atoms.sort(key=lambda atom: atom.atom_id) + + def info(self) -> None: + """Prints a formatted summary of the molecule.""" + print(f"--- Molecule ---") + print(f"Name: {self.name} | System: {self.system_name}") + print(f"Number of atoms: {len(self.atoms)}") + print(f"Total Mass (amu): {self.get_mass():.4f}") + print(f"Energy: {self.energy}") + print(f"Comment: {self.comment}") + print(f"\n{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") + print("-" * 58) + if self.atoms: + for atom in self.atoms: + print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") + + def find_fragment(self, tolerance: float = 1.15) -> Dict[int, List[int]]: + """ + Identifies covalently bonded fragments within the molecule. + + This is useful for molecules that are actually composed of several + disconnected components (e.g., salts, solvent shells). + + Args: + tolerance: Scaling factor for covalent radii to determine bonding. + + Returns: + A dictionary mapping a fragment ID (starting from 1) to a list of + atom indices belonging to that fragment. 
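+
+        Example (illustrative; a monohydrate would split into two fragments):
+            fragments = mol.find_fragment()  # e.g. {1: [0, 1, ..., 11], 2: [12, 13, 14]}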
+ """ + if not self.atoms: return {} + + num_atoms = len(self.atoms) + cart_matrix = np.array([atom.cart_xyz for atom in self.atoms]) + radii = np.array([chemical_knowledge.element_covalent_radii[atom.element] for atom in self.atoms]) + + # Create a matrix of bond thresholds (r_i + r_j) + bond_threshold_matrix = (radii[:, np.newaxis] + radii) * tolerance + + # True where distance is less than the bond threshold + dist_matrix = cdist(cart_matrix, cart_matrix) + adj_matrix = dist_matrix < bond_threshold_matrix + np.fill_diagonal(adj_matrix, False) + + # Graph traversal (DFS) to find connected components + visited = [False] * num_atoms + groups = {} + group_id = 0 + for i in range(num_atoms): + if not visited[i]: + group_id += 1 + groups[group_id] = [] + stack = [i] + while stack: + atom_idx = stack.pop() + if not visited[atom_idx]: + visited[atom_idx] = True + groups[group_id].append(atom_idx) + # Find neighbors and add to stack + neighbors = np.where(adj_matrix[atom_idx])[0] + stack.extend(neighbors) + return groups + + def give_molecule_id(self, tolerance: float = 1.15) -> None: + """Assigns a molecule ID to each atom based on fragment analysis.""" + fragments = self.find_fragment(tolerance=tolerance) + for group_id, atom_indices in fragments.items(): + for atom_idx in atom_indices: + self.atoms[atom_idx].molecule = group_id + + def take_out_fragment(self, tolerance: float = 1.15) -> List['Molecule']: + """ + Splits the current molecule into a list of new Molecule objects, + one for each disconnected fragment. + """ + if not self.atoms: return [] + + self.give_atom_id_forced() # Ensure IDs are set for lookup + fragments = self.find_fragment(tolerance=tolerance) + new_molecules = [] + + for i, atom_indices in fragments.items(): + fragment_atoms = [self.atoms[j] for j in atom_indices] + new_mol = Molecule( + atoms=copy.deepcopy(fragment_atoms), + name=f"{self.name}_frag{i}", + system_name=f"{self.system_name}_frag{i}" + ) + new_molecules.append(new_mol) + return new_molecules + + def calculate_frac_xyz_by_cell_para(self, cell_para: list) -> None: + """Calculates fractional coordinates for all atoms given cell parameters.""" + for atom in self.atoms: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_para(atom.cart_xyz, cell_para) + + def molecule_volume(self, num_samples: int = 100000) -> float: + """ + Calculates the van der Waals volume using a Monte Carlo integration method. + + This method samples points in a bounding box around the molecule and + determines the ratio of points that fall within any atom's vdW sphere. + + Args: + num_samples: The number of random points to sample. More points + yield a more accurate volume at the cost of performance. + + Returns: + The estimated van der Waals volume in cubic Angstroms. 
+ """ + if not self.atoms: return 0.0 + + elements, coords = self.get_ele_and_cart() + radii = np.array([chemical_knowledge.element_vdw_radii[el] for el in elements]) + + # Determine bounding box for sampling + min_bounds = np.min(coords, axis=0) - np.max(radii) + max_bounds = np.max(coords, axis=0) + np.max(radii) + bounding_box_volume = np.prod(max_bounds - min_bounds) + + # Generate random sample points within the bounding box + random_points = np.random.uniform(min_bounds, max_bounds, (num_samples, 3)) + + # Check for each point if it's inside ANY sphere + count_inside = 0 + for rp in tqdm(random_points, desc="Monte Carlo Volume", leave=False): + # Calculate squared distances from the point to all atom centers + dist_sq = np.sum((coords - rp)**2, axis=1) + # If any distance is within the radius, the point is inside + if np.any(dist_sq <= radii**2): + count_inside += 1 + return (count_inside / num_samples) * bounding_box_volume \ No newline at end of file diff --git a/basic_function/format_parser.py b/basic_function/format_parser.py index c17226e0bb4f5107b47c87aa330f21cf88f1db83..6a343774dcba0c627dbec90399532109ba97e4c0 100644 --- a/basic_function/format_parser.py +++ b/basic_function/format_parser.py @@ -1,333 +1,333 @@ -import re -from basic_function import operation -from basic_function import data_classes -from basic_function import chemical_knowledge -import copy - - -def read_xyz_file(file_path): - input_file = open(file_path, 'r') - lines = input_file.readlines() - number_of_atoms = int(lines[0]) - name = str(lines[1][:-1]) - atoms = [] - for index,line in enumerate(lines): - split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - if len(split_line)==4 and operation.is_number(split_line[1]) and \ - operation.is_number(split_line[2]) and operation.is_number(split_line[3]): - atoms.append(data_classes.Atom(element=split_line[0], - cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])], - atom_id=index-2)) - - if number_of_atoms!=len(atoms): - print("Warning! 
The length of atoms don't match the number of atoms given") - - molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name) - - return molecule - -def write_cif_file(crystal, sym=False, name="zcx"): - """ - Accept crystal class, give the cif file out - :param crystal: crystal class - :param coordinates: frac or cart - :param sym: False:give all atoms out; True:with symmetry - :param name: file name - :return: cif_out - cif file in list format should be print using the following function: - target=open("D:\\zcx.cif",'w') - target.writelines(cif_out) - target.close() - """ - - if crystal.system_name!="unknown": - name = crystal.system_name - cif_file = [] - - cif_file.append("data_"+str(name)+"\n") - if sym==False: - if crystal.space_group==1: - crystal_temp = crystal - else: - crystal_temp = copy.deepcopy(crystal) - crystal_temp.make_p1() - cif_file.append("_symmetry_space_group_name_H-M \'P1\'"+"\n") - cif_file.append("_symmetry_Int_Tables_number 1"+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_symmetry_equiv_pos_site_id"+"\n") - cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") - cif_file.append("1 x,y,z"+"\n") - cif_file.append("_cell_length_a "+str(crystal_temp.cell_para[0][0])+"\n") - cif_file.append("_cell_length_b "+str(crystal_temp.cell_para[0][1])+"\n") - cif_file.append("_cell_length_c "+str(crystal_temp.cell_para[0][2])+"\n") - cif_file.append("_cell_angle_alpha "+str(crystal_temp.cell_para[1][0])+"\n") - cif_file.append("_cell_angle_beta "+str(crystal_temp.cell_para[1][1])+"\n") - cif_file.append("_cell_angle_gamma "+str(crystal_temp.cell_para[1][2])+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_atom_site_label"+"\n") - cif_file.append("_atom_site_type_symbol"+"\n") - cif_file.append("_atom_site_fract_x"+"\n") - cif_file.append("_atom_site_fract_y"+"\n") - cif_file.append("_atom_site_fract_z"+"\n") - for i in range(0,len(crystal_temp.atoms)): - cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" - .format(i+1,crystal_temp.atoms[i].element,crystal_temp.atoms[i].frac_xyz[0], - crystal_temp.atoms[i].frac_xyz[1],crystal_temp.atoms[i].frac_xyz[2])) - - return cif_file - - elif sym==True: - cif_file.append("_symmetry_space_group_name_H-M \'{}\'".format(chemical_knowledge.space_group[crystal.space_group][1])+"\n") - cif_file.append("_symmetry_Int_Tables_number {}".format(crystal.space_group)+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_symmetry_equiv_pos_site_id"+"\n") - cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") - for idx, SYMM in enumerate(crystal.SYMM): - cif_file.append("{} {}".format(idx+1,SYMM)+"\n") - cif_file.append("_cell_length_a "+str(crystal.cell_para[0][0])+"\n") - cif_file.append("_cell_length_b "+str(crystal.cell_para[0][1])+"\n") - cif_file.append("_cell_length_c "+str(crystal.cell_para[0][2])+"\n") - cif_file.append("_cell_angle_alpha "+str(crystal.cell_para[1][0])+"\n") - cif_file.append("_cell_angle_beta "+str(crystal.cell_para[1][1])+"\n") - cif_file.append("_cell_angle_gamma "+str(crystal.cell_para[1][2])+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_atom_site_label"+"\n") - cif_file.append("_atom_site_type_symbol"+"\n") - cif_file.append("_atom_site_fract_x"+"\n") - cif_file.append("_atom_site_fract_y"+"\n") - cif_file.append("_atom_site_fract_z"+"\n") - for i in range(0,len(crystal.atoms)): - cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" - .format(i+1,crystal.atoms[i].element,crystal.atoms[i].frac_xyz[0], - crystal.atoms[i].frac_xyz[1],crystal.atoms[i].frac_xyz[2])) - 
- return cif_file - - -def write_cifs_file(crystals, sym=False, name="zcx"): - cifs_file = [] - for crystal in crystals: - single_cif = write_cif_file(crystal,sym=sym, name=name) - cifs_file.extend(single_cif) - return cifs_file - - -def read_cif_file(file_path,on_sym_check=False,shut_up=False,system_name="unknown",comment_name="unknown"): - input_file = open(file_path, 'r') - lines = input_file.readlines() - step_pickle = [] - crystal_all = [] - if system_name=="unknown": - no_name = True - else: - no_name = False - # first time scan - for index,line in enumerate(lines): - # find out all the step pickle - if line.startswith("data_"): - step_pickle.append(index) - step_pickle.append(len(lines)) - - # treat every step and return a crystal - for m in range(0,len(step_pickle)-1): - atoms = [] - atoms_P1 = [] - SYMM = [] - cell_para = [["unknown","unknown","unknown"],["unknown","unknown","unknown"]] - - for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): - split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - if line.startswith("#"): - pass - elif len(split_line)==0: - pass - # read the loop of symmetry - # elif split_line[0]=="loop_" and lines[step_pickle[m]+index+1]=="_symmetry_equiv_pos_as_xyz\n": - elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n": - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+1+temp_number]: - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - temp_number+=1 - if not operation.is_number(split_line_temp[0]): - SYMM.append(split_line_temp[0]) - else: - SYMM.append(split_line_temp[1]) - elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ")=="_symmetry_equiv_pos_as_xyz\n": - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+2+temp_number]: - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number]))) - temp_number+=1 - SYMM.append("".join(split_line_temp[1:])) - elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]: - # ase format - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+1+temp_number]: - if lines[step_pickle[m]+index+1+temp_number]=="\n": - temp_number += 1 - continue - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - temp_number+=1 - if not operation.is_number(split_line_temp[0]): - SYMM.append("".join(split_line_temp)) - - # read the loop of atoms: - elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \ - (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"): - temp_number = 0 - while "_" in lines[step_pickle[m]+index+1+temp_number]: - if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n": - ele_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n": - x_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n": - y_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n": - z_pos = temp_number - temp_number+=1 - how_long = temp_number - - while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long: - split_line_temp = list(filter(lambda 
x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - atoms.append(data_classes.Atom(element=split_line_temp[ele_pos], - frac_xyz=[float(split_line_temp[x_pos]),float(split_line_temp[y_pos]), - float(split_line_temp[z_pos])])) - temp_number += 1 - if step_pickle[m]+index+1+temp_number==len(lines): - break - - elif split_line[0] == "_cell_length_a": - cell_para[0][0] = float(split_line[1]) - elif split_line[0] == "_cell_length_b": - cell_para[0][1] = float(split_line[1]) - elif split_line[0] == "_cell_length_c": - cell_para[0][2] = float(split_line[1]) - elif split_line[0] == "_cell_angle_alpha": - cell_para[1][0] = float(split_line[1]) - elif split_line[0] == "_cell_angle_beta": - cell_para[1][1] = float(split_line[1]) - elif split_line[0] == "_cell_angle_gamma": - cell_para[1][2] = float(split_line[1]) - elif "data_" in line: - if no_name == True: - system_name = line[5:] - system_name = system_name.replace(" ","_") - system_name = system_name.replace("\\", "_") - system_name = system_name.replace("\n", "") - for atom in atoms: - all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM) - for new_position in all_reflect_position: - atoms_P1.append(data_classes.Atom(element=atom.element, - frac_xyz=[new_position[0], new_position[1], new_position[2]])) - crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name)) - if on_sym_check == True: - raise Exception("Not finished part, TODO in code") - if shut_up==False: - if m%100 == 0: - print("{} structures have been treated".format(m)) - - return crystal_all - - -def write_poscar_file(crystal, coordinates = 'frac', name = "parser_zcx_create"): - - vasp_file = [] - - vasp_file.append('{}\n'.format(name)) - vasp_file.append('1.0\n') - cell_vect = crystal.cell_vect - for vect in cell_vect: - vasp_file.append("{:16.8f} {:16.8f} {:16.8f}\n".format(vect[0],vect[1],vect[2])) - crystal.sort_by_element() - vasp_file.append("".join("{:>6s}".format(x) for x in crystal.get_element()) + "\n") - vasp_file.append("".join("{:>6.0f}".format(x) for x in crystal.get_element_amount()) + "\n") - if coordinates == 'frac': - vasp_file.append('Direct\n') - for ELEMENT in crystal.get_element(): - for ATOM in crystal.atoms: - if ATOM.element == ELEMENT: - vasp_file.append( - "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.frac_xyz[0], ATOM.frac_xyz[1], ATOM.frac_xyz[2])) - elif coordinates == 'cart': - vasp_file.append('Cartesian\n') - for ELEMENT in crystal.get_element(): - for ATOM in crystal.atoms: - if ATOM.element == ELEMENT: - vasp_file.append( - "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.cart_xyz[0], ATOM.cart_xyz[1], ATOM.cart_xyz[2])) - else: - raise Exception("Wrong coordinates type: {}".format(coordinates)) - - return vasp_file - - -def read_ase_pbc_file(file_path,shut_up=False): - input_file = open(file_path, 'r') - lines = input_file.readlines()[2:] - step_pickle = [] - crystal_all = [] - - # first time scan - for index,line in enumerate(lines): - # find out all the step pickle - if line.startswith("Step "): - step_pickle.append(index) - step_pickle.append(len(lines)) - - # treat every step and return a crystal - for m in range(0,len(step_pickle)-1): - atoms_P1 = [] - force_matrix = [] - position_matrix = [] - in_forces = False - in_positions = False - - - for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): - # split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - line = line.strip() - # check Forces part - if 
line.startswith("Forces:"): - in_forces = True - in_positions = False - continue - - # check Positions part - if line.startswith("Positions:"): - in_positions = True - in_forces = False - continue - - if in_forces and line.startswith("[") and line.endswith("]"): - line = line.replace("[", "").replace("]", "") - force_matrix.append([float(x) for x in line.split()]) - - # analyse Positions part - if in_positions and line.startswith("[") and line.endswith("]"): - line = line.replace("[", "").replace("]", "") - position_matrix.append([float(x) for x in line.split()]) - - if line.startswith("Elements:"): - elements_string = line.strip().split(":", 1)[-1].strip() - elements_string = elements_string[1:-1] - - elements = [elem.strip().strip("'") for elem in elements_string.split(",")] - - if line.startswith("cell:"): - matrix_string = line[len("cell: Cell("):-1] - rows = matrix_string.split("], [") - cell_vect = [ - [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")] - for row in rows - ] - - for i in range(0,len(elements)): - atoms_P1.append(data_classes.Atom(element=elements[i], cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]])) - - crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1)) - - return crystal_all +import re +from basic_function import operation +from basic_function import data_classes +from basic_function import chemical_knowledge +import copy + + +def read_xyz_file(file_path): + input_file = open(file_path, 'r') + lines = input_file.readlines() + number_of_atoms = int(lines[0]) + name = str(lines[1][:-1]) + atoms = [] + for index,line in enumerate(lines): + split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + if len(split_line)==4 and operation.is_number(split_line[1]) and \ + operation.is_number(split_line[2]) and operation.is_number(split_line[3]): + atoms.append(data_classes.Atom(element=split_line[0], + cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])], + atom_id=index-2)) + + if number_of_atoms!=len(atoms): + print("Warning! 
The number of atoms read doesn't match the count declared in the header")
+
+    molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name)
+
+    return molecule
+
+def write_cif_file(crystal, sym=False, name="zcx"):
+    """
+    Accept a Crystal object and return its content as a CIF file.
+    :param crystal: the Crystal object to export
+    :param sym: False: expand to P1 and write all atoms explicitly; True: write
+                the stored atoms together with the symmetry operations
+    :param name: data block name (overridden by crystal.system_name when known)
+    :return: cif_out - the CIF file as a list of lines, to be written out with:
+        target=open("D:\\zcx.cif",'w')
+        target.writelines(cif_out)
+        target.close()
+    """
+
+    if crystal.system_name!="unknown":
+        name = crystal.system_name
+    cif_file = []
+
+    cif_file.append("data_"+str(name)+"\n")
+    if sym==False:
+        if crystal.space_group==1:
+            crystal_temp = crystal
+        else:
+            crystal_temp = copy.deepcopy(crystal)
+            crystal_temp.make_p1()
+        cif_file.append("_symmetry_space_group_name_H-M \'P1\'"+"\n")
+        cif_file.append("_symmetry_Int_Tables_number 1"+"\n")
+
+        cif_file.append("loop_"+"\n")
+        cif_file.append("_symmetry_equiv_pos_site_id"+"\n")
+        cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n")
+        cif_file.append("1 x,y,z"+"\n")
+        cif_file.append("_cell_length_a    "+str(crystal_temp.cell_para[0][0])+"\n")
+        cif_file.append("_cell_length_b    "+str(crystal_temp.cell_para[0][1])+"\n")
+        cif_file.append("_cell_length_c    "+str(crystal_temp.cell_para[0][2])+"\n")
+        cif_file.append("_cell_angle_alpha "+str(crystal_temp.cell_para[1][0])+"\n")
+        cif_file.append("_cell_angle_beta  "+str(crystal_temp.cell_para[1][1])+"\n")
+        cif_file.append("_cell_angle_gamma "+str(crystal_temp.cell_para[1][2])+"\n")
+
+        cif_file.append("loop_"+"\n")
+        cif_file.append("_atom_site_label"+"\n")
+        cif_file.append("_atom_site_type_symbol"+"\n")
+        cif_file.append("_atom_site_fract_x"+"\n")
+        cif_file.append("_atom_site_fract_y"+"\n")
+        cif_file.append("_atom_site_fract_z"+"\n")
+        for i in range(0,len(crystal_temp.atoms)):
+            cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n"
+                            .format(i+1,crystal_temp.atoms[i].element,crystal_temp.atoms[i].frac_xyz[0],
+                                    crystal_temp.atoms[i].frac_xyz[1],crystal_temp.atoms[i].frac_xyz[2]))
+
+        return cif_file
+
+    elif sym==True:
+        cif_file.append("_symmetry_space_group_name_H-M \'{}\'".format(chemical_knowledge.space_group[crystal.space_group][1])+"\n")
+        cif_file.append("_symmetry_Int_Tables_number {}".format(crystal.space_group)+"\n")
+
+        cif_file.append("loop_"+"\n")
+        cif_file.append("_symmetry_equiv_pos_site_id"+"\n")
+        cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n")
+        for idx, SYMM in enumerate(crystal.SYMM):
+            cif_file.append("{} {}".format(idx+1,SYMM)+"\n")
+        cif_file.append("_cell_length_a    "+str(crystal.cell_para[0][0])+"\n")
+        cif_file.append("_cell_length_b    "+str(crystal.cell_para[0][1])+"\n")
+        cif_file.append("_cell_length_c    "+str(crystal.cell_para[0][2])+"\n")
+        cif_file.append("_cell_angle_alpha "+str(crystal.cell_para[1][0])+"\n")
+        cif_file.append("_cell_angle_beta  "+str(crystal.cell_para[1][1])+"\n")
+        cif_file.append("_cell_angle_gamma "+str(crystal.cell_para[1][2])+"\n")
+
+        cif_file.append("loop_"+"\n")
+        cif_file.append("_atom_site_label"+"\n")
+        cif_file.append("_atom_site_type_symbol"+"\n")
+        cif_file.append("_atom_site_fract_x"+"\n")
+        cif_file.append("_atom_site_fract_y"+"\n")
+        cif_file.append("_atom_site_fract_z"+"\n")
+        for i in range(0,len(crystal.atoms)):
+            cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n"
+                            .format(i+1,crystal.atoms[i].element,crystal.atoms[i].frac_xyz[0],
+                                    crystal.atoms[i].frac_xyz[1],crystal.atoms[i].frac_xyz[2]))
+
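+        # Hand the assembled lines back to the caller; nothing is written to disk here.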
+ return cif_file + + +def write_cifs_file(crystals, sym=False, name="zcx"): + cifs_file = [] + for crystal in crystals: + single_cif = write_cif_file(crystal,sym=sym, name=name) + cifs_file.extend(single_cif) + return cifs_file + + +def read_cif_file(file_path,on_sym_check=False,shut_up=False,system_name="unknown",comment_name="unknown"): + input_file = open(file_path, 'r') + lines = input_file.readlines() + step_pickle = [] + crystal_all = [] + if system_name=="unknown": + no_name = True + else: + no_name = False + # first time scan + for index,line in enumerate(lines): + # find out all the step pickle + if line.startswith("data_"): + step_pickle.append(index) + step_pickle.append(len(lines)) + + # treat every step and return a crystal + for m in range(0,len(step_pickle)-1): + atoms = [] + atoms_P1 = [] + SYMM = [] + cell_para = [["unknown","unknown","unknown"],["unknown","unknown","unknown"]] + + for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): + split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + if line.startswith("#"): + pass + elif len(split_line)==0: + pass + # read the loop of symmetry + # elif split_line[0]=="loop_" and lines[step_pickle[m]+index+1]=="_symmetry_equiv_pos_as_xyz\n": + elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n": + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+1+temp_number]: + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + temp_number+=1 + if not operation.is_number(split_line_temp[0]): + SYMM.append(split_line_temp[0]) + else: + SYMM.append(split_line_temp[1]) + elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ")=="_symmetry_equiv_pos_as_xyz\n": + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+2+temp_number]: + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number]))) + temp_number+=1 + SYMM.append("".join(split_line_temp[1:])) + elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]: + # ase format + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+1+temp_number]: + if lines[step_pickle[m]+index+1+temp_number]=="\n": + temp_number += 1 + continue + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + temp_number+=1 + if not operation.is_number(split_line_temp[0]): + SYMM.append("".join(split_line_temp)) + + # read the loop of atoms: + elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \ + (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"): + temp_number = 0 + while "_" in lines[step_pickle[m]+index+1+temp_number]: + if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n": + ele_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n": + x_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n": + y_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n": + z_pos = temp_number + temp_number+=1 + how_long = temp_number + + while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long: + split_line_temp = list(filter(lambda 
x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + atoms.append(data_classes.Atom(element=split_line_temp[ele_pos], + frac_xyz=[float(split_line_temp[x_pos]),float(split_line_temp[y_pos]), + float(split_line_temp[z_pos])])) + temp_number += 1 + if step_pickle[m]+index+1+temp_number==len(lines): + break + + elif split_line[0] == "_cell_length_a": + cell_para[0][0] = float(split_line[1]) + elif split_line[0] == "_cell_length_b": + cell_para[0][1] = float(split_line[1]) + elif split_line[0] == "_cell_length_c": + cell_para[0][2] = float(split_line[1]) + elif split_line[0] == "_cell_angle_alpha": + cell_para[1][0] = float(split_line[1]) + elif split_line[0] == "_cell_angle_beta": + cell_para[1][1] = float(split_line[1]) + elif split_line[0] == "_cell_angle_gamma": + cell_para[1][2] = float(split_line[1]) + elif "data_" in line: + if no_name == True: + system_name = line[5:] + system_name = system_name.replace(" ","_") + system_name = system_name.replace("\\", "_") + system_name = system_name.replace("\n", "") + for atom in atoms: + all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM) + for new_position in all_reflect_position: + atoms_P1.append(data_classes.Atom(element=atom.element, + frac_xyz=[new_position[0], new_position[1], new_position[2]])) + crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name)) + if on_sym_check == True: + raise Exception("Not finished part, TODO in code") + if shut_up==False: + if m%100 == 0: + print("{} structures have been treated".format(m)) + + return crystal_all + + +def write_poscar_file(crystal, coordinates = 'frac', name = "parser_zcx_create"): + + vasp_file = [] + + vasp_file.append('{}\n'.format(name)) + vasp_file.append('1.0\n') + cell_vect = crystal.cell_vect + for vect in cell_vect: + vasp_file.append("{:16.8f} {:16.8f} {:16.8f}\n".format(vect[0],vect[1],vect[2])) + crystal.sort_by_element() + vasp_file.append("".join("{:>6s}".format(x) for x in crystal.get_element()) + "\n") + vasp_file.append("".join("{:>6.0f}".format(x) for x in crystal.get_element_amount()) + "\n") + if coordinates == 'frac': + vasp_file.append('Direct\n') + for ELEMENT in crystal.get_element(): + for ATOM in crystal.atoms: + if ATOM.element == ELEMENT: + vasp_file.append( + "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.frac_xyz[0], ATOM.frac_xyz[1], ATOM.frac_xyz[2])) + elif coordinates == 'cart': + vasp_file.append('Cartesian\n') + for ELEMENT in crystal.get_element(): + for ATOM in crystal.atoms: + if ATOM.element == ELEMENT: + vasp_file.append( + "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.cart_xyz[0], ATOM.cart_xyz[1], ATOM.cart_xyz[2])) + else: + raise Exception("Wrong coordinates type: {}".format(coordinates)) + + return vasp_file + + +def read_ase_pbc_file(file_path,shut_up=False): + input_file = open(file_path, 'r') + lines = input_file.readlines()[2:] + step_pickle = [] + crystal_all = [] + + # first time scan + for index,line in enumerate(lines): + # find out all the step pickle + if line.startswith("Step "): + step_pickle.append(index) + step_pickle.append(len(lines)) + + # treat every step and return a crystal + for m in range(0,len(step_pickle)-1): + atoms_P1 = [] + force_matrix = [] + position_matrix = [] + in_forces = False + in_positions = False + + + for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): + # split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + line = line.strip() + # check Forces part + if 
line.startswith("Forces:"): + in_forces = True + in_positions = False + continue + + # check Positions part + if line.startswith("Positions:"): + in_positions = True + in_forces = False + continue + + if in_forces and line.startswith("[") and line.endswith("]"): + line = line.replace("[", "").replace("]", "") + force_matrix.append([float(x) for x in line.split()]) + + # analyse Positions part + if in_positions and line.startswith("[") and line.endswith("]"): + line = line.replace("[", "").replace("]", "") + position_matrix.append([float(x) for x in line.split()]) + + if line.startswith("Elements:"): + elements_string = line.strip().split(":", 1)[-1].strip() + elements_string = elements_string[1:-1] + + elements = [elem.strip().strip("'") for elem in elements_string.split(",")] + + if line.startswith("cell:"): + matrix_string = line[len("cell: Cell("):-1] + rows = matrix_string.split("], [") + cell_vect = [ + [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")] + for row in rows + ] + + for i in range(0,len(elements)): + atoms_P1.append(data_classes.Atom(element=elements[i], cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]])) + + crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1)) + + return crystal_all \ No newline at end of file diff --git a/basic_function/operation.py b/basic_function/operation.py index 1e157fb16bc218de88ab3a8fd85e1dd0b6f82dc2..a2ccb8f0a2926424e84e8f0a09d5c76398c2b0c3 100644 --- a/basic_function/operation.py +++ b/basic_function/operation.py @@ -1,469 +1,469 @@ -# -*- coding: utf-8 -*- -""" -A collection of functions for performing crystallographic and molecular operations, -such as symmetry application, supercell generation, and geometric analysis. -""" - -# --- Standard Library Imports --- -import copy -import fractions -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -# --- Third-Party Imports --- -import networkx as nx -import numpy as np -import numpy.typing as npt -from scipy.spatial import cKDTree as KDTree - -# --- Local Application Imports --- -from basic_function import chemical_knowledge, data_classes - -# Type aliases for clarity -NDArrayFloat = npt.NDArray[np.float64] -CellVectors = List[List[float]] -SymmetryOperations = List[str] - - -def is_number(s: str) -> bool: - """Checks if a string can be interpreted as a number (float or fraction). - - Args: - s: The input string. - - Returns: - True if the string represents a number, False otherwise. - """ - try: - float(s) - return True - except ValueError: - pass - - try: - # Check for fractional representations like "1/2" - float(fractions.Fraction(s)) - return True - except ValueError: - return False - - -def _parse_symmetry_operations( - sym_ops: SymmetryOperations, -) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]: - """Parses a list of symmetry operation strings into matrices. - - This is an internal helper function to avoid code duplication in public functions. - - Args: - sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']). - - Returns: - A tuple containing two lists: - - A list of 3x3 rotation/reflection matrices (M). - - A list of 1x3 translation vectors (C). - - Raises: - ValueError: If a symmetry operation string is malformed. 
- """ - rotation_matrices = [] - translation_vectors = [] - - for sym_op_str in sym_ops: - sym_op_parts = sym_op_str.lower().replace(" ", "").split(",") - if len(sym_op_parts) != 3: - raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.") - - matrix_m = np.zeros((3, 3)) - matrix_c = np.zeros((1, 3)) - - for i, part in enumerate(sym_op_parts): - # Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5' - tokens = re.findall(r"([+-]?[xyz0-9./]+)", part) - for token in tokens: - token = token.strip() - if not token: - continue - - if "x" in token: - matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0 - elif "y" in token: - matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0 - elif "z" in token: - matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0 - elif is_number(token): - matrix_c[0, i] += float(fractions.Fraction(token)) - else: - raise ValueError(f"Invalid fragment '{token}' in symmetry operation.") - - rotation_matrices.append(matrix_m) - translation_vectors.append(matrix_c) - - return rotation_matrices, translation_vectors - - -def space_group_transfer_for_single_atom( - frac_xyz: List[float], space_group_ops: SymmetryOperations -) -> List[List[float]]: - """Applies space group symmetry operations to a single atomic coordinate. - - Args: - frac_xyz: The fractional coordinates [x, y, z] of a single atom. - space_group_ops: A list of space group symmetry operation strings. - - Returns: - A list of all symmetrically equivalent fractional coordinates. - """ - rot_matrices, trans_vectors = _parse_symmetry_operations(space_group_ops) - - equivalent_positions = [] - atom_pos = np.array(frac_xyz) - - for rot, trans in zip(rot_matrices, trans_vectors): - new_pos = np.dot(atom_pos, rot.T) + trans.squeeze() - equivalent_positions.append(new_pos.tolist()) - - return equivalent_positions - - -def super_cell( - crystal: "data_classes.Crystal", - cell_range: Optional[List[List[int]]] = None, -) -> "data_classes.Crystal": - """Constructs a supercell from a unit cell. - - Args: - crystal: The input Crystal object. - cell_range: A list of ranges for each lattice vector, e.g., - [[-1, 1], [-1, 1], [-1, 1]] creates a 3x3x3 supercell. - If None, defaults to [[-1, 1], [-1, 1], [-1, 1]]. - - Returns: - A new Crystal object representing the supercell. - """ - if cell_range is None: - cell_range = [[-1, 1], [-1, 1], [-1, 1]] - - dims = [r[1] - r[0] + 1 for r in cell_range] - - new_lattice = [ - [dim * val for val in crystal.cell_vect[i]] for i, dim in enumerate(dims) - ] - - translation_vectors = [] - for h in range(cell_range[0][0], cell_range[0][1] + 1): - for k in range(cell_range[1][0], cell_range[1][1] + 1): - for l in range(cell_range[2][0], cell_range[2][1] + 1): - translation_vectors.append([h, k, l]) - - new_atoms = [] - for atom in crystal.atoms: - for trans_vec in translation_vectors: - new_frac_xyz = [ - (atom.frac_xyz[i] + trans_vec[i]) / dims[i] for i in range(3) - ] - new_atoms.append( - data_classes.Atom(element=atom.element, frac_xyz=new_frac_xyz) - ) - - if crystal.energy != "unknown": - total_cells = dims[0] * dims[1] * dims[2] - new_energy = crystal.energy * total_cells - else: - new_energy = "unknown" - - return data_classes.Crystal( - cell_vect=new_lattice, energy=new_energy, atoms=new_atoms - ) - - -def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule": - """Orients a molecule along its principal axes of inertia. - - The method uses the Moment of Inertia tensor to define a canonical orientation. 
- The molecule's coordinates are modified in-place. For more details, see: - http://sobereva.com/426 - - Args: - molecule: The Molecule object to be oriented. - - Returns: - The same Molecule object with its atoms reoriented. - """ - all_ele, all_cart = molecule.get_ele_and_cart() - - if len(all_cart) <= 1: - return molecule # No orientation needed for single atoms or empty molecules. - - masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele]) - relative_position = all_cart - molecule.get_center_of_mass() - - # Calculate the moment of inertia tensor - I_xx = np.sum(masses * (relative_position[:, 1] ** 2 + relative_position[:, 2] ** 2)) - I_yy = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 2] ** 2)) - I_zz = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 1] ** 2)) - I_xy = -np.sum(masses * relative_position[:, 0] * relative_position[:, 1]) - I_xz = -np.sum(masses * relative_position[:, 0] * relative_position[:, 2]) - I_yz = -np.sum(masses * relative_position[:, 1] * relative_position[:, 2]) - - I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]]) - - # Eigenvectors of the inertia tensor are the principal axes. - # np.linalg.eigh is used for symmetric matrices. - eigenvalues, eigenvectors = np.linalg.eigh(I_matrix) - principal_axes = eigenvectors.T - - # Project the relative positions onto the new axes system. - new_positions = np.dot(relative_position, principal_axes.T) - - molecule.put_ele_cart_back(all_ele, new_positions) - return molecule - - -def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat: - """Generates a 3x3 rotation matrix from a 3D vector `v`. - - This function uses a mapping from a 3D vector to a quaternion, which is then - used to construct the rotation matrix. This method avoids gimbal lock. A - left-handed coordinate system is assumed. - - Args: - v: A 3-element NumPy array used to generate the quaternion. - - Returns: - A 3x3 rotation matrix. - """ - # Ensure v elements are within valid ranges if necessary, though the - # formulas handle most inputs gracefully. - v0_sqrt = np.sqrt(max(v[0], 0)) - v0_1_sqrt = np.sqrt(max(1.0 - v[0], 0)) - - angle1 = 2.0 * np.pi * v[1] - angle2 = 2.0 * np.pi * v[2] - - # Quaternion components (x, y, z, w) - qx = v0_1_sqrt * np.sin(angle1) - qy = v0_1_sqrt * np.cos(angle1) - qz = v0_sqrt * np.sin(angle2) - qw = v0_sqrt * np.cos(angle2) - - return np.array([ - [1 - 2*qy**2 - 2*qz**2, 2*qx*qy + 2*qw*qz, 2*qx*qz - 2*qw*qy], - [2*qx*qy - 2*qw*qz, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz + 2*qw*qx], - [2*qx*qz + 2*qw*qy, 2*qy*qz - 2*qw*qx, 1 - 2*qx**2 - 2*qy**2] - ]) - - -def f2c_matrix( - cell_params: Tuple[List[float], List[float]] -) -> Optional[NDArrayFloat]: - """Calculates the fractional-to-Cartesian transformation matrix. - - Args: - cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], - where lengths are in Angstroms and angles are in degrees. - - Returns: - The 3x3 transformation matrix, or None if cell parameters are invalid. 
- """ - lengths, angles = cell_params - a, b, c = lengths - alpha, beta, gamma = np.deg2rad(angles) - - cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) - sin_g = np.sin(gamma) - - # Volume calculation term - volume_term_sq = ( - 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g - ) - if volume_term_sq < 0: - return None - - volume = a * b * c * np.sqrt(volume_term_sq) - - matrix = np.zeros((3, 3)) - matrix[0, 0] = a - matrix[0, 1] = b * cos_g - matrix[0, 2] = c * cos_b - matrix[1, 1] = b * sin_g - matrix[1, 2] = c * (cos_a - cos_b * cos_g) / sin_g - matrix[2, 2] = volume / (a * b * sin_g) - - return matrix.T - - -def c2f_matrix( - cell_params: Tuple[List[float], List[float]] -) -> Optional[NDArrayFloat]: - """Calculates the Cartesian-to-fractional transformation matrix. - - This is the inverse of the matrix generated by `f2c_matrix`. - - Args: - cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], - where lengths are in Angstroms and angles are in degrees. - - Returns: - The 3x3 transformation matrix, or None if cell parameters are invalid. - """ - f2c = f2c_matrix(cell_params) - if f2c is None: - return None - - try: - return np.linalg.inv(f2c) - except np.linalg.LinAlgError: - return None - - -def apply_SYMM( - frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations -) -> NDArrayFloat: - """Applies symmetry operations to a single set of fractional coordinates. - - Args: - frac_xyz: A NumPy array of fractional coordinates [x, y, z]. - symm_ops: A list of symmetry operation strings. - - Returns: - A NumPy array of all symmetrically equivalent fractional coordinates. - """ - rot_matrices, trans_vectors = _parse_symmetry_operations(symm_ops) - - equivalent_positions = [ - np.dot(frac_xyz, rot.T) + trans.squeeze() - for rot, trans in zip(rot_matrices, trans_vectors) - ] - - return np.array(equivalent_positions) - - -def apply_SYMM_with_element( - elements: Union[str, List[str]], - frac_xyzs: NDArrayFloat, - symm_ops: SymmetryOperations, -) -> Tuple[NDArrayFloat, NDArrayFloat]: - """Applies symmetry operations, returning new elements and coordinates. - - Args: - elements: The element symbol(s) corresponding to the coordinates. - frac_xyzs: A NumPy array of fractional coordinates. - symm_ops: A list of symmetry operation strings. - - Returns: - A tuple containing: - - A NumPy array of element symbols for each new position. - - A NumPy array of all symmetrically equivalent fractional coordinates. - """ - equivalent_positions = apply_SYMM(frac_xyzs, symm_ops) - num_ops = len(equivalent_positions) - - replicated_elements = np.tile(np.array(elements).squeeze(), (num_ops, 1)) - - return replicated_elements, equivalent_positions - - -def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float: - """Calculates the length of the longest space diagonal of a unit cell. - - The longest diagonal connects the origin (0,0,0) to the opposite - corner (1,1,1) of the unit cell. - - Args: - cell_vect: The three lattice vectors of the cell. - - Returns: - The length of the longest diagonal in Angstroms. - """ - cell_vect_np = np.array(cell_vect) - diagonal_vector = np.sum(cell_vect_np, axis=0) - return float(np.linalg.norm(diagonal_vector)) - - -def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]: - """Calculates inter-planar distances for primary crystallographic planes. - - This computes the distances for the (100), (010), and (001) families of planes. - - Args: - cell_vect: The three lattice vectors [a, b, c] of the cell. 
- - Returns: - A list of three distances [d_a, d_b, d_c], where d_a is the distance - between planes parallel to the b-c plane, and so on. - """ - distances = [] - vectors = [np.array(v) for v in cell_vect] - - # Permutations to calculate distance for each primary plane - # (a to b-c plane, b to a-c plane, c to a-b plane) - indices = [(0, 1, 2), (1, 0, 2), (2, 0, 1)] - - for i, j, k in indices: - point_p = vectors[i] - plane_v1 = vectors[j] - plane_v2 = vectors[k] - - # Normal vector to the plane defined by plane_v1 and plane_v2 - normal_vector = np.cross(plane_v1, plane_v2) - - # Distance from point P to the plane is |N · P| / ||N|| - distance = abs(np.dot(normal_vector, point_p)) / np.linalg.norm(normal_vector) - distances.append(distance) - - return distances - - -def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool: - """Detects if a crystal structure forms a connected framework via VdW radii. - - The method involves: - 1. Expanding the crystal to a P1 symmetry supercell. - 2. Building a 3x3x3 supercell to ensure periodic connections are considered. - 3. Constructing a graph where atoms are nodes and an edge exists if their - distance is within a scaled sum of their van der Waals radii. - 4. Checking if the largest connected component in the graph is large enough - to be considered a single, percolating framework. - - Args: - crystal: The Crystal object to analyze. - tolerance: A tolerance factor to scale the VdW radii sum. - - Returns: - True if the structure is a connected framework, False otherwise. - """ - crystal_temp = copy.deepcopy(crystal) - crystal_temp.make_p1() - crystal_temp.move_atom_into_cell() - - # Create a 3x3x3 supercell to check for connectivity across boundaries - crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]]) - - all_ele, all_carts = crystal_supercell.get_ele_and_cart() - - vdw_radii_map = chemical_knowledge.element_vdw_radii - vdw_max = max(vdw_radii_map[el] for el in set(all_ele)) - distance_threshold = vdw_max * tolerance * 2 - - # KDTree for efficient nearest-neighbor search - tree = KDTree(all_carts) - pairs = tree.query_pairs(r=distance_threshold) - - # Build a graph to find connected components - graph = nx.Graph() - graph.add_nodes_from(range(len(all_carts))) - graph.add_edges_from(list(pairs)) - - if not graph.nodes: - return False - - # Find the largest connected component - largest_cc = max(nx.connected_components(graph), key=len) - - # A heuristic to check for a percolating framework. A connected framework - # should connect most atoms. The threshold '9' is empirical but robustly - # distinguishes between isolated molecules and a fully connected lattice. - # In a 3x3x3 supercell (27 unit cells), a connected framework should involve - # significantly more atoms than in a few unit cells. +# -*- coding: utf-8 -*- +""" +A collection of functions for performing crystallographic and molecular operations, +such as symmetry application, supercell generation, and geometric analysis. 
+""" + +# --- Standard Library Imports --- +import copy +import fractions +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +# --- Third-Party Imports --- +import networkx as nx +import numpy as np +import numpy.typing as npt +from scipy.spatial import cKDTree as KDTree + +# --- Local Application Imports --- +from basic_function import chemical_knowledge, data_classes + +# Type aliases for clarity +NDArrayFloat = npt.NDArray[np.float64] +CellVectors = List[List[float]] +SymmetryOperations = List[str] + + +def is_number(s: str) -> bool: + """Checks if a string can be interpreted as a number (float or fraction). + + Args: + s: The input string. + + Returns: + True if the string represents a number, False otherwise. + """ + try: + float(s) + return True + except ValueError: + pass + + try: + # Check for fractional representations like "1/2" + float(fractions.Fraction(s)) + return True + except ValueError: + return False + + +def _parse_symmetry_operations( + sym_ops: SymmetryOperations, +) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]: + """Parses a list of symmetry operation strings into matrices. + + This is an internal helper function to avoid code duplication in public functions. + + Args: + sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']). + + Returns: + A tuple containing two lists: + - A list of 3x3 rotation/reflection matrices (M). + - A list of 1x3 translation vectors (C). + + Raises: + ValueError: If a symmetry operation string is malformed. + """ + rotation_matrices = [] + translation_vectors = [] + + for sym_op_str in sym_ops: + sym_op_parts = sym_op_str.lower().replace(" ", "").split(",") + if len(sym_op_parts) != 3: + raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.") + + matrix_m = np.zeros((3, 3)) + matrix_c = np.zeros((1, 3)) + + for i, part in enumerate(sym_op_parts): + # Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5' + tokens = re.findall(r"([+-]?[xyz0-9./]+)", part) + for token in tokens: + token = token.strip() + if not token: + continue + + if "x" in token: + matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0 + elif "y" in token: + matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0 + elif "z" in token: + matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0 + elif is_number(token): + matrix_c[0, i] += float(fractions.Fraction(token)) + else: + raise ValueError(f"Invalid fragment '{token}' in symmetry operation.") + + rotation_matrices.append(matrix_m) + translation_vectors.append(matrix_c) + + return rotation_matrices, translation_vectors + + +def space_group_transfer_for_single_atom( + frac_xyz: List[float], space_group_ops: SymmetryOperations +) -> List[List[float]]: + """Applies space group symmetry operations to a single atomic coordinate. + + Args: + frac_xyz: The fractional coordinates [x, y, z] of a single atom. + space_group_ops: A list of space group symmetry operation strings. + + Returns: + A list of all symmetrically equivalent fractional coordinates. + """ + rot_matrices, trans_vectors = _parse_symmetry_operations(space_group_ops) + + equivalent_positions = [] + atom_pos = np.array(frac_xyz) + + for rot, trans in zip(rot_matrices, trans_vectors): + new_pos = np.dot(atom_pos, rot.T) + trans.squeeze() + equivalent_positions.append(new_pos.tolist()) + + return equivalent_positions + + +def super_cell( + crystal: "data_classes.Crystal", + cell_range: Optional[List[List[int]]] = None, +) -> "data_classes.Crystal": + """Constructs a supercell from a unit cell. 
+ + Args: + crystal: The input Crystal object. + cell_range: A list of ranges for each lattice vector, e.g., + [[-1, 1], [-1, 1], [-1, 1]] creates a 3x3x3 supercell. + If None, defaults to [[-1, 1], [-1, 1], [-1, 1]]. + + Returns: + A new Crystal object representing the supercell. + """ + if cell_range is None: + cell_range = [[-1, 1], [-1, 1], [-1, 1]] + + dims = [r[1] - r[0] + 1 for r in cell_range] + + new_lattice = [ + [dim * val for val in crystal.cell_vect[i]] for i, dim in enumerate(dims) + ] + + translation_vectors = [] + for h in range(cell_range[0][0], cell_range[0][1] + 1): + for k in range(cell_range[1][0], cell_range[1][1] + 1): + for l in range(cell_range[2][0], cell_range[2][1] + 1): + translation_vectors.append([h, k, l]) + + new_atoms = [] + for atom in crystal.atoms: + for trans_vec in translation_vectors: + new_frac_xyz = [ + (atom.frac_xyz[i] + trans_vec[i]) / dims[i] for i in range(3) + ] + new_atoms.append( + data_classes.Atom(element=atom.element, frac_xyz=new_frac_xyz) + ) + + if crystal.energy != "unknown": + total_cells = dims[0] * dims[1] * dims[2] + new_energy = crystal.energy * total_cells + else: + new_energy = "unknown" + + return data_classes.Crystal( + cell_vect=new_lattice, energy=new_energy, atoms=new_atoms + ) + + +def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule": + """Orients a molecule along its principal axes of inertia. + + The method uses the Moment of Inertia tensor to define a canonical orientation. + The molecule's coordinates are modified in-place. For more details, see: + http://sobereva.com/426 + + Args: + molecule: The Molecule object to be oriented. + + Returns: + The same Molecule object with its atoms reoriented. + """ + all_ele, all_cart = molecule.get_ele_and_cart() + + if len(all_cart) <= 1: + return molecule # No orientation needed for single atoms or empty molecules. + + masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele]) + relative_position = all_cart - molecule.get_center_of_mass() + + # Calculate the moment of inertia tensor + I_xx = np.sum(masses * (relative_position[:, 1] ** 2 + relative_position[:, 2] ** 2)) + I_yy = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 2] ** 2)) + I_zz = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 1] ** 2)) + I_xy = -np.sum(masses * relative_position[:, 0] * relative_position[:, 1]) + I_xz = -np.sum(masses * relative_position[:, 0] * relative_position[:, 2]) + I_yz = -np.sum(masses * relative_position[:, 1] * relative_position[:, 2]) + + I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]]) + + # Eigenvectors of the inertia tensor are the principal axes. + # np.linalg.eigh is used for symmetric matrices. + eigenvalues, eigenvectors = np.linalg.eigh(I_matrix) + principal_axes = eigenvectors.T + + # Project the relative positions onto the new axes system. + new_positions = np.dot(relative_position, principal_axes.T) + + molecule.put_ele_cart_back(all_ele, new_positions) + return molecule + + +def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat: + """Generates a 3x3 rotation matrix from a 3D vector `v`. + + This function uses a mapping from a 3D vector to a quaternion, which is then + used to construct the rotation matrix. This method avoids gimbal lock. A + left-handed coordinate system is assumed. + + Args: + v: A 3-element NumPy array used to generate the quaternion. + + Returns: + A 3x3 rotation matrix. 
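
The index bookkeeping in `super_cell` is easy to verify in isolation: the translation images enumerate `dims[0]*dims[1]*dims[2]` cells, and each image's fractional coordinates are rescaled by `dims` into the enlarged cell. A small sketch with the default 3x3x3 range (values assumed):

```python
import numpy as np

# cell_range [[-1, 1], [-1, 1], [-1, 1]] -> dims (3, 3, 3), i.e. 27 image cells.
cell_range = [[-1, 1], [-1, 1], [-1, 1]]
dims = [hi - lo + 1 for lo, hi in cell_range]
shifts = [(h, k, l)
          for h in range(cell_range[0][0], cell_range[0][1] + 1)
          for k in range(cell_range[1][0], cell_range[1][1] + 1)
          for l in range(cell_range[2][0], cell_range[2][1] + 1)]
assert len(shifts) == np.prod(dims) == 27

# An atom at fractional 0.25 in the (-1)-image along a maps to
# (0.25 - 1) / 3 = -0.25 in the supercell basis.
print((0.25 + (-1)) / dims[0])  # -0.25
```
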
+ """ + # Ensure v elements are within valid ranges if necessary, though the + # formulas handle most inputs gracefully. + v0_sqrt = np.sqrt(max(v[0], 0)) + v0_1_sqrt = np.sqrt(max(1.0 - v[0], 0)) + + angle1 = 2.0 * np.pi * v[1] + angle2 = 2.0 * np.pi * v[2] + + # Quaternion components (x, y, z, w) + qx = v0_1_sqrt * np.sin(angle1) + qy = v0_1_sqrt * np.cos(angle1) + qz = v0_sqrt * np.sin(angle2) + qw = v0_sqrt * np.cos(angle2) + + return np.array([ + [1 - 2*qy**2 - 2*qz**2, 2*qx*qy + 2*qw*qz, 2*qx*qz - 2*qw*qy], + [2*qx*qy - 2*qw*qz, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz + 2*qw*qx], + [2*qx*qz + 2*qw*qy, 2*qy*qz - 2*qw*qx, 1 - 2*qx**2 - 2*qy**2] + ]) + + +def f2c_matrix( + cell_params: Tuple[List[float], List[float]] +) -> Optional[NDArrayFloat]: + """Calculates the fractional-to-Cartesian transformation matrix. + + Args: + cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + + Returns: + The 3x3 transformation matrix, or None if cell parameters are invalid. + """ + lengths, angles = cell_params + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + sin_g = np.sin(gamma) + + # Volume calculation term + volume_term_sq = ( + 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g + ) + if volume_term_sq < 0: + return None + + volume = a * b * c * np.sqrt(volume_term_sq) + + matrix = np.zeros((3, 3)) + matrix[0, 0] = a + matrix[0, 1] = b * cos_g + matrix[0, 2] = c * cos_b + matrix[1, 1] = b * sin_g + matrix[1, 2] = c * (cos_a - cos_b * cos_g) / sin_g + matrix[2, 2] = volume / (a * b * sin_g) + + return matrix.T + + +def c2f_matrix( + cell_params: Tuple[List[float], List[float]] +) -> Optional[NDArrayFloat]: + """Calculates the Cartesian-to-fractional transformation matrix. + + This is the inverse of the matrix generated by `f2c_matrix`. + + Args: + cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + + Returns: + The 3x3 transformation matrix, or None if cell parameters are invalid. + """ + f2c = f2c_matrix(cell_params) + if f2c is None: + return None + + try: + return np.linalg.inv(f2c) + except np.linalg.LinAlgError: + return None + + +def apply_SYMM( + frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations +) -> NDArrayFloat: + """Applies symmetry operations to a single set of fractional coordinates. + + Args: + frac_xyz: A NumPy array of fractional coordinates [x, y, z]. + symm_ops: A list of symmetry operation strings. + + Returns: + A NumPy array of all symmetrically equivalent fractional coordinates. + """ + rot_matrices, trans_vectors = _parse_symmetry_operations(symm_ops) + + equivalent_positions = [ + np.dot(frac_xyz, rot.T) + trans.squeeze() + for rot, trans in zip(rot_matrices, trans_vectors) + ] + + return np.array(equivalent_positions) + + +def apply_SYMM_with_element( + elements: Union[str, List[str]], + frac_xyzs: NDArrayFloat, + symm_ops: SymmetryOperations, +) -> Tuple[NDArrayFloat, NDArrayFloat]: + """Applies symmetry operations, returning new elements and coordinates. + + Args: + elements: The element symbol(s) corresponding to the coordinates. + frac_xyzs: A NumPy array of fractional coordinates. + symm_ops: A list of symmetry operation strings. + + Returns: + A tuple containing: + - A NumPy array of element symbols for each new position. + - A NumPy array of all symmetrically equivalent fractional coordinates. 
+ """ + equivalent_positions = apply_SYMM(frac_xyzs, symm_ops) + num_ops = len(equivalent_positions) + + replicated_elements = np.tile(np.array(elements).squeeze(), (num_ops, 1)) + + return replicated_elements, equivalent_positions + + +def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float: + """Calculates the length of the longest space diagonal of a unit cell. + + The longest diagonal connects the origin (0,0,0) to the opposite + corner (1,1,1) of the unit cell. + + Args: + cell_vect: The three lattice vectors of the cell. + + Returns: + The length of the longest diagonal in Angstroms. + """ + cell_vect_np = np.array(cell_vect) + diagonal_vector = np.sum(cell_vect_np, axis=0) + return float(np.linalg.norm(diagonal_vector)) + + +def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]: + """Calculates inter-planar distances for primary crystallographic planes. + + This computes the distances for the (100), (010), and (001) families of planes. + + Args: + cell_vect: The three lattice vectors [a, b, c] of the cell. + + Returns: + A list of three distances [d_a, d_b, d_c], where d_a is the distance + between planes parallel to the b-c plane, and so on. + """ + distances = [] + vectors = [np.array(v) for v in cell_vect] + + # Permutations to calculate distance for each primary plane + # (a to b-c plane, b to a-c plane, c to a-b plane) + indices = [(0, 1, 2), (1, 0, 2), (2, 0, 1)] + + for i, j, k in indices: + point_p = vectors[i] + plane_v1 = vectors[j] + plane_v2 = vectors[k] + + # Normal vector to the plane defined by plane_v1 and plane_v2 + normal_vector = np.cross(plane_v1, plane_v2) + + # Distance from point P to the plane is |N · P| / ||N|| + distance = abs(np.dot(normal_vector, point_p)) / np.linalg.norm(normal_vector) + distances.append(distance) + + return distances + + +def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool: + """Detects if a crystal structure forms a connected framework via VdW radii. + + The method involves: + 1. Expanding the crystal to a P1 symmetry supercell. + 2. Building a 3x3x3 supercell to ensure periodic connections are considered. + 3. Constructing a graph where atoms are nodes and an edge exists if their + distance is within a scaled sum of their van der Waals radii. + 4. Checking if the largest connected component in the graph is large enough + to be considered a single, percolating framework. + + Args: + crystal: The Crystal object to analyze. + tolerance: A tolerance factor to scale the VdW radii sum. + + Returns: + True if the structure is a connected framework, False otherwise. 
+ """ + crystal_temp = copy.deepcopy(crystal) + crystal_temp.make_p1() + crystal_temp.move_atom_into_cell() + + # Create a 3x3x3 supercell to check for connectivity across boundaries + crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]]) + + all_ele, all_carts = crystal_supercell.get_ele_and_cart() + + vdw_radii_map = chemical_knowledge.element_vdw_radii + vdw_max = max(vdw_radii_map[el] for el in set(all_ele)) + distance_threshold = vdw_max * tolerance * 2 + + # KDTree for efficient nearest-neighbor search + tree = KDTree(all_carts) + pairs = tree.query_pairs(r=distance_threshold) + + # Build a graph to find connected components + graph = nx.Graph() + graph.add_nodes_from(range(len(all_carts))) + graph.add_edges_from(list(pairs)) + + if not graph.nodes: + return False + + # Find the largest connected component + largest_cc = max(nx.connected_components(graph), key=len) + + # A heuristic to check for a percolating framework. A connected framework + # should connect most atoms. The threshold '9' is empirical but robustly + # distinguishes between isolated molecules and a fully connected lattice. + # In a 3x3x3 supercell (27 unit cells), a connected framework should involve + # significantly more atoms than in a few unit cells. return len(largest_cc) > 9 * len(crystal_temp.atoms) \ No newline at end of file diff --git a/basic_function/packaged_function.py b/basic_function/packaged_function.py index b131115cbd3a3ccb83bfff26526dd2d97430aa8c..11632b7f7b7669036c0151aacb2e085ebaf87beb 100644 --- a/basic_function/packaged_function.py +++ b/basic_function/packaged_function.py @@ -1,102 +1,102 @@ -from basic_function import format_parser -from basic_function import CSP_generator_normal -import os -import concurrent.futures -import sys - - - -def process_crystal(seed, sg, molecules,output_path,add_name): - aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg) - molecules_number = sum(len(molecule.atoms) for molecule in molecules) - new_crystal = aaa.generate(seed=seed) - sys.stdout.flush() - if new_crystal is not None: - cif_out = format_parser.write_cif_file(new_crystal) - with open(f"{output_path}/structures/{add_name}_{sg}_{seed}_z{len(molecules)}_{molecules_number}.cif", 'w') as target: - target.writelines(cif_out) - return True - return False - -def CSP_generater_parallel(molecules,output_path,need_structure = 100, space_group_list=[1],max_workers=8,add_name='',start_seed=1): - space_groups = space_group_list - accept_count = need_structure - - try: - os.makedirs("{}/structures".format(output_path)) - except: - print("Warning, these is already an structures folder in this path, skip mkdir") - for sg in space_groups: - accept = 0 - seed = start_seed - - with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: - futures = {} - while accept < accept_count: - # submit new task - while len(futures) < max_workers and accept + len(futures) < accept_count: - future = executor.submit(process_crystal, seed, sg, molecules, output_path,add_name) - futures[future] = seed - seed += 1 - - # check the finished task - done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED) - for future in done: - if future.result(): - accept += 1 - # remove it from list, no matter what result it is - del futures[future] - - # cancel all task if the number need is arrived. 
-                # cancel any remaining tasks once the required number of structures is reached.
-                if accept >= accept_count:
-                    for future in futures:
-                        future.cancel()
-                    break
-
-
-def CSP_generater_serial(molecules, output_path, need_structure=100, densely_pack_method=False, space_group_list=[1]):
-    """
-    :param molecules: a list [molecule1, molecule2, ...]
-    :param output_path: a str indicating the path of the output folder
-    :param need_structure: int, the number of structures to accept
-    :param space_group_list: a list of the space group numbers to search
-    """
-    try:
-        os.makedirs("{}/structures".format(output_path))
-    except FileExistsError:
-        print("Warning: a structures folder already exists in this path; skipping mkdir")
-    for sg in space_group_list:
-        aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
-        accept = 0
-        i = 1
-        while accept < need_structure:
+                if accept >= accept_count:
+                    for future in futures:
+                        future.cancel()
+                    break
+
+
+def CSP_generater_serial(molecules, output_path, need_structure=100, densely_pack_method=False, space_group_list=[1]):
+    """
+    :param molecules: a list [molecule1, molecule2, ...]
+    :param output_path: a str indicating the path of the output folder
+    :param need_structure: int, the number of structures to accept
+    :param space_group_list: a list of the space group numbers to search
+    """
+    try:
+        os.makedirs("{}/structures".format(output_path))
+    except FileExistsError:
+        print("Warning: a structures folder already exists in this path; skipping mkdir")
+    for sg in space_group_list:
+        aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
+        accept = 0
+        i = 1
+        while accept < need_structure:
-def cell_para_to_vect(
-    cell_para: CellParameters, check: bool = False
-) -> CellVectors:
-    """Converts cell parameters to lattice vectors.
-
-    The lattice vector `a` is aligned with the x-axis. The vector `b` lies in
-    the xy-plane.
-
-    Args:
-        cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]],
-            where lengths are in Angstroms and angles are in degrees.
-        check: If True, asserts the input shape is correct.
-
-    Returns:
-        A 3x3 list of lists representing the cell vectors [a, b, c].
-    """
-    if check:
-        shape_check = np.array(cell_para)
-        assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)."
-
-    lengths = cell_para[0]
-    angles_deg = cell_para[1]
-
-    a, b, c = lengths
-    alpha, beta, gamma = np.deg2rad(angles_deg)
-
-    cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
-    sin_g = np.sin(gamma)
-
-    # This term is related to the square of the cell volume.
-    # It ensures the cell parameters are physically valid.
-    volume_term_sq = (
-        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
-    )
-
-    # Ensure the argument for sqrt is non-negative
-    volume_term = np.sqrt(max(0, volume_term_sq))
-
-    cell_vect = np.zeros((3, 3))
-    cell_vect[0, 0] = a
-    cell_vect[1, 0] = b * cos_g
-    cell_vect[1, 1] = b * sin_g
-    cell_vect[2, 0] = c * cos_b
-    cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g
-    cell_vect[2, 2] = c * volume_term / sin_g
-
-    return cell_vect.tolist()
-
-
-def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters:
-    """Converts lattice vectors to cell parameters.
-
-    Args:
-        cell_vect: A 3x3 array-like object representing the lattice vectors.
-        check: If True, asserts the input shape is correct.
-
-    Returns:
-        A tuple containing [[a, b, c], [alpha, beta, gamma]].
-    """
-    cell_vect_np = np.array(cell_vect)
-    if check:
-        assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
-
-    vec_a, vec_b, vec_c = cell_vect_np
-
-    len_a = np.linalg.norm(vec_a)
-    len_b = np.linalg.norm(vec_b)
-    len_c = np.linalg.norm(vec_c)
-
-    lengths = [len_a, len_b, len_c]
-
-    # Calculate angles using the dot product formula; handle potential floating point inaccuracies.
- def _calculate_angle(v1, v2, norm1, norm2): - cosine_angle = np.dot(v1, v2) / (norm1 * norm2) - # Clip to handle values slightly outside [-1, 1] due to precision issues - return np.arccos(np.clip(cosine_angle, -1.0, 1.0)) - - alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c) - beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c) - gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b) - - angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist() - - return (lengths, angles_deg) - - -def atom_frac_to_cart_by_cell_vect( - atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False -) -> List[float]: - """Converts fractional coordinates to Cartesian coordinates using cell vectors. - - Args: - atom_frac: A 3-element list or array of fractional coordinates. - cell_vect: A 3x3 matrix of lattice vectors. - check: If True, asserts input shapes are correct. - - Returns: - A list of 3 Cartesian coordinates. - """ - atom_frac_np = np.array(atom_frac) - cell_vect_np = np.array(cell_vect) - - if check: - assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." - assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements." - - # The transformation is a linear combination of the basis vectors. - # atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c - # This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]] - atom_cart = np.dot(atom_frac_np, cell_vect_np) - return atom_cart.tolist() - - -def atom_frac_to_cart_by_cell_para( - atom_frac: Coordinates, cell_para: CellParameters, check: bool = False -) -> List[float]: - """Converts fractional coordinates to Cartesian using cell parameters. - - Args: - atom_frac: A 3-element list or array of fractional coordinates. - cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. - check: If True, performs validation checks in underlying functions. - - Returns: - A list of 3 Cartesian coordinates. - """ - cell_vect = cell_para_to_vect(cell_para, check=check) - return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check) - - -def atom_cart_to_frac_by_cell_vect( - atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False -) -> List[float]: - """Converts Cartesian coordinates to fractional coordinates using cell vectors. - - Args: - atom_cart: A 3-element list or array of Cartesian coordinates. - cell_vect: A 3x3 matrix of lattice vectors. - check: If True, asserts input shapes are correct. - - Returns: - A list of 3 fractional coordinates. - """ - atom_cart_np = np.array(atom_cart) - cell_vect_np = np.array(cell_vect) - - if check: - assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." - assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements." - - # The transformation is atom_frac = atom_cart @ inverse(cell_vect) - inv_cell_vect = np.linalg.inv(cell_vect_np) - atom_frac = np.dot(atom_cart_np, inv_cell_vect) - return atom_frac.tolist() - - -def atom_cart_to_frac_by_cell_para( - atom_cart: Coordinates, cell_para: CellParameters, check: bool = False -) -> List[float]: - """Converts Cartesian coordinates to fractional using cell parameters. - - Args: - atom_cart: A 3-element list or array of Cartesian coordinates. - cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. - check: If True, performs validation checks in underlying functions. - - Returns: - A list of 3 fractional coordinates. 
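
A compact standalone version of the angle recovery used here, clipping the arccos argument exactly as `_calculate_angle` does; the triclinic-ish vectors are assumed values:

```python
import numpy as np

cell = np.array([
    [5.0, 0.0, 0.0],
    [1.0, 5.9, 0.0],
    [0.7, 0.4, 6.8],
])

lengths = np.linalg.norm(cell, axis=1)

def angle_deg(u, v):
    cosine = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    # Clip to handle values slightly outside [-1, 1] from rounding.
    return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))

alpha = angle_deg(cell[1], cell[2])  # angle between b and c
beta = angle_deg(cell[0], cell[2])   # angle between a and c
gamma = angle_deg(cell[0], cell[1])  # angle between a and b
print(lengths, alpha, beta, gamma)
```
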
- """ - cell_vect = cell_para_to_vect(cell_para, check=check) - return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check) - - -def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float: - """Calculates the volume of the unit cell. - - Args: - cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or - a 3x3 matrix of cell vectors. - - Returns: - The volume of the cell in cubic Angstroms. - - Raises: - ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3). - """ - cell_info_np = np.array(cell_info) - - if cell_info_np.shape == (3, 3): - # Input is cell vectors, calculate volume using the scalar triple product. - return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2])))) - - elif cell_info_np.shape == (2, 3): - # Input is cell parameters. - lengths, angles_deg = cell_info_np - a, b, c = lengths - alpha, beta, gamma = np.deg2rad(angles_deg) - - cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) - - # Standard formula for volume from cell parameters - volume_sq = ( - a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g) - ) - return float(np.sqrt(max(0, volume_sq))) - - else: +# -*- coding: utf-8 -*- +""" +Provides functions for converting between different representations of a +crystallographic unit cell (cell parameters and lattice vectors) and for +transforming atomic coordinates between fractional and Cartesian systems. +""" + +# --- Standard Library Imports --- +from typing import List, Tuple, Union + +# --- Third-Party Imports --- +import numpy as np +import numpy.typing as npt + +# --- Type Aliases for Clarity --- +NDArrayFloat = npt.NDArray[np.float64] +CellParameters = Tuple[List[float], List[float]] +CellVectors = Union[List[List[float]], NDArrayFloat] +Coordinates = Union[List[float], NDArrayFloat] + + +def cell_para_to_vect( + cell_para: CellParameters, check: bool = False +) -> CellVectors: + """Converts cell parameters to lattice vectors. + + The lattice vector `a` is aligned with the x-axis. The vector `b` lies in + the xy-plane. + + Args: + cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + check: If True, asserts the input shape is correct. + + Returns: + A 3x3 list of lists representing the cell vectors [a, b, c]. + """ + if check: + shape_check = np.array(cell_para) + assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)." + + lengths = cell_para[0] + angles_deg = cell_para[1] + + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles_deg) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + sin_g = np.sin(gamma) + + # This term is related to the square of the cell volume. + # It ensures the cell parameters are physically valid. + volume_term_sq = ( + 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g + ) + + # Ensure the argument for sqrt is non-negative + volume_term = np.sqrt(max(0, volume_term_sq)) + + cell_vect = np.zeros((3, 3)) + cell_vect[0, 0] = a + cell_vect[1, 0] = b * cos_g + cell_vect[1, 1] = b * sin_g + cell_vect[2, 0] = c * cos_b + cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g + cell_vect[2, 2] = c * volume_term / sin_g + + return cell_vect.tolist() + + +def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters: + """Converts lattice vectors to cell parameters. + + Args: + cell_vect: A 3x3 array-like object representing the lattice vectors. 
+ check: If True, asserts the input shape is correct. + + Returns: + A tuple containing [[a, b, c], [alpha, beta, gamma]]. + """ + cell_vect_np = np.array(cell_vect) + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + + vec_a, vec_b, vec_c = cell_vect_np + + len_a = np.linalg.norm(vec_a) + len_b = np.linalg.norm(vec_b) + len_c = np.linalg.norm(vec_c) + + lengths = [len_a, len_b, len_c] + + # Calculate angles using the dot product formula; handle potential floating point inaccuracies. + def _calculate_angle(v1, v2, norm1, norm2): + cosine_angle = np.dot(v1, v2) / (norm1 * norm2) + # Clip to handle values slightly outside [-1, 1] due to precision issues + return np.arccos(np.clip(cosine_angle, -1.0, 1.0)) + + alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c) + beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c) + gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b) + + angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist() + + return (lengths, angles_deg) + + +def atom_frac_to_cart_by_cell_vect( + atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False +) -> List[float]: + """Converts fractional coordinates to Cartesian coordinates using cell vectors. + + Args: + atom_frac: A 3-element list or array of fractional coordinates. + cell_vect: A 3x3 matrix of lattice vectors. + check: If True, asserts input shapes are correct. + + Returns: + A list of 3 Cartesian coordinates. + """ + atom_frac_np = np.array(atom_frac) + cell_vect_np = np.array(cell_vect) + + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements." + + # The transformation is a linear combination of the basis vectors. + # atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c + # This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]] + atom_cart = np.dot(atom_frac_np, cell_vect_np) + return atom_cart.tolist() + + +def atom_frac_to_cart_by_cell_para( + atom_frac: Coordinates, cell_para: CellParameters, check: bool = False +) -> List[float]: + """Converts fractional coordinates to Cartesian using cell parameters. + + Args: + atom_frac: A 3-element list or array of fractional coordinates. + cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. + check: If True, performs validation checks in underlying functions. + + Returns: + A list of 3 Cartesian coordinates. + """ + cell_vect = cell_para_to_vect(cell_para, check=check) + return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check) + + +def atom_cart_to_frac_by_cell_vect( + atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False +) -> List[float]: + """Converts Cartesian coordinates to fractional coordinates using cell vectors. + + Args: + atom_cart: A 3-element list or array of Cartesian coordinates. + cell_vect: A 3x3 matrix of lattice vectors. + check: If True, asserts input shapes are correct. + + Returns: + A list of 3 fractional coordinates. + """ + atom_cart_np = np.array(atom_cart) + cell_vect_np = np.array(cell_vect) + + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements." 
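
The inverse taken below exists precisely when the three lattice vectors are linearly independent, i.e. when the scalar triple product (the cell volume, computed by `calculate_volume` further down) is non-zero. A quick check of that identity, with assumed vectors:

```python
import numpy as np

cell = np.array([
    [5.0, 0.0, 0.0],
    [1.0, 5.9, 0.0],
    [0.7, 0.4, 6.8],
])

# |a . (b x c)| equals |det(cell)|; a (near-)zero value means the vectors
# are coplanar and the Cartesian -> fractional inverse does not exist.
vol = abs(cell[0] @ np.cross(cell[1], cell[2]))
assert np.isclose(vol, abs(np.linalg.det(cell)))
```
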
+ + # The transformation is atom_frac = atom_cart @ inverse(cell_vect) + inv_cell_vect = np.linalg.inv(cell_vect_np) + atom_frac = np.dot(atom_cart_np, inv_cell_vect) + return atom_frac.tolist() + + +def atom_cart_to_frac_by_cell_para( + atom_cart: Coordinates, cell_para: CellParameters, check: bool = False +) -> List[float]: + """Converts Cartesian coordinates to fractional using cell parameters. + + Args: + atom_cart: A 3-element list or array of Cartesian coordinates. + cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. + check: If True, performs validation checks in underlying functions. + + Returns: + A list of 3 fractional coordinates. + """ + cell_vect = cell_para_to_vect(cell_para, check=check) + return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check) + + +def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float: + """Calculates the volume of the unit cell. + + Args: + cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or + a 3x3 matrix of cell vectors. + + Returns: + The volume of the cell in cubic Angstroms. + + Raises: + ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3). + """ + cell_info_np = np.array(cell_info) + + if cell_info_np.shape == (3, 3): + # Input is cell vectors, calculate volume using the scalar triple product. + return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2])))) + + elif cell_info_np.shape == (2, 3): + # Input is cell parameters. + lengths, angles_deg = cell_info_np + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles_deg) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + + # Standard formula for volume from cell parameters + volume_sq = ( + a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g) + ) + return float(np.sqrt(max(0, volume_sq))) + + else: raise ValueError(f"Cannot understand input shape {cell_info_np.shape} for `cell_info`.") \ No newline at end of file diff --git a/csp.sh b/csp.sh index 42b7820342ff167d4882fbd49016b3bc9b59f6fe..e1ca13190b54d76ca4b18785247d74c900657c2d 100644 --- a/csp.sh +++ b/csp.sh @@ -1,36 +1,37 @@ -#!/bin/bash -TOP_DIR=$(pwd) -TAR_DIR="${TOP_DIR}/test" - -mkdir -p "${TAR_DIR}" -cd ${TAR_DIR} - -# generate structures -python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \ - --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16\ - --num_generation 100 --generate_conformers 20 --use_conformers 4 > generate.log 2>&1 - -# opt structures using mace, --batch_size 0 means auto batch size only for mace -mkdir -p "${TAR_DIR}/mace_opt" -cd "${TAR_DIR}/mace_opt" -python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ - --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \ - --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ - --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \ - --use_ordered_files true --model mace > opt.log 2>&1 - -# opt structures using 7net -# mkdir -p "${TAR_DIR}/7net_opt" -# cd "${TAR_DIR}/7net_opt" -# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ -# --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \ -# --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ -# --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \ -# --use_ordered_files true --model sevennet > opt.log 2>&1 - 
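
`csp.sh` drives the stages with hard-coded paths; where bash is inconvenient, the generation step can be launched from Python instead. A hedged sketch that mirrors the flags used above (this is plain `subprocess` plumbing, not a supported entry point):

```python
import subprocess
from pathlib import Path

top = Path.cwd()          # assumes the repo root, like TOP_DIR above
tar = top / "test"
tar.mkdir(parents=True, exist_ok=True)

generate = [
    "python", str(top / "main.py"), "--path", str(tar),
    "--smiles", "OC(=O)c1cc(O)c(O)c(O)c1.O",
    "--molecule_num_in_cell", "1,1", "--space_group_list", "13,14",
    "--add_name", "KONTIQ", "--max_workers", "16",
    "--num_generation", "100", "--generate_conformers", "20",
    "--use_conformers", "4", "--mode", "all",
]
with open(tar / "generate.log", "w") as log:
    subprocess.run(generate, cwd=tar, stdout=log,
                   stderr=subprocess.STDOUT, check=True)
```
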
-# Postprocess the opt structures
-python "${TOP_DIR}/post_process/clean_table.py"
-## Make sure you have installed csd-python-api in the current env before executing the following commands
-# conda activate ccdc
-# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs"
-# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80
+#!/bin/bash
+TOP_DIR=$(pwd)
+TAR_DIR="${TOP_DIR}/test"
+
+mkdir -p "${TAR_DIR}"
+cd ${TAR_DIR}
+
+# conformer search and structure generation
+# change --mode to conformer_only or structure_only to run the stages separately.
+python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \
+    --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16 \
+    --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode all > generate.log 2>&1
+
+# optimize structures with MACE; --batch_size 0 selects automatic batch sizing (MACE only)
+mkdir -p "${TAR_DIR}/mace_opt"
+cd "${TAR_DIR}/mace_opt"
+python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
+    --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \
+    --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
+    --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \
+    --use_ordered_files true --model mace > opt.log 2>&1
+
+# optimize structures with SevenNet (7net); uncomment to use
+# mkdir -p "${TAR_DIR}/7net_opt"
+# cd "${TAR_DIR}/7net_opt"
+# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
+#     --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \
+#     --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
+#     --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \
+#     --use_ordered_files true --model sevennet > opt.log 2>&1
+
+# Postprocess the opt structures
+python "${TOP_DIR}/post_process/clean_table.py"
+## Make sure you have installed csd-python-api in the current env before executing the following commands
+# conda activate ccdc
+# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs"
+# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__init__.py b/mace-bench/3rdparty/SevenNet/sevenn/__init__.py
index 85a1516613da6de8d162f0607379512b6be8dc4c..6145e442c54f2249f48cfb7303611a7241a48b6e 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/__init__.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/__init__.py
@@ -1,13 +1,13 @@
-from importlib.metadata import version
-
-from packaging.version import Version
-
-__version__ = version('sevenn')
-
-from e3nn import __version__ as e3nn_ver
-
-if Version(e3nn_ver) < Version('0.5.0'):
-    raise ValueError(
-        'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
-        'convention.'
-    )
+from importlib.metadata import version
+
+from packaging.version import Version
+
+__version__ = version('sevenn')
+
+from e3nn import __version__ as e3nn_ver
+
+if Version(e3nn_ver) < Version('0.5.0'):
+    raise ValueError(
+        'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
+        'convention.'
+ ) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 16f58640b1ac6fead56c3a802b5fe32431897b71..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc deleted file mode 100644 index aec61878f46aec9a92dea1495af02a0a6455b5b0..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_keys.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_keys.cpython-310.pyc deleted file mode 100644 index 41d9e6107852328bcffc771d55e11af4df329f43..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_keys.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc deleted file mode 100644 index 68284e4d537453a5df64c1a00e65d12b10aaa87c..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/calculator.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/calculator.cpython-310.pyc deleted file mode 100644 index 472b61a2aadcb889e60b29c60f5b4e31d8878788..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/calculator.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc deleted file mode 100644 index 54ee9d8573bd5121535b081ada352ebc5ccef1d5..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc deleted file mode 100644 index f40f0573b12f208c6d90acf01ab4d630bc456f6d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc deleted file mode 100644 index d944acb195fee14da3596908c9500e1487c23c51..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/_const.py b/mace-bench/3rdparty/SevenNet/sevenn/_const.py index 528a77b8b6923aded3e406621a879193fde1aec4..284b64e49c4e7c25daca517df44ae221065a9084 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/_const.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/_const.py @@ -1,310 +1,310 @@ -import os -from enum import Enum -from typing import Dict - -import torch - -import sevenn._keys as KEY -from sevenn.nn.activation 
import ShiftedSoftPlus - -NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118 - -IMPLEMENTED_RADIAL_BASIS = ['bessel'] -IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR'] -# TODO: support None. This became difficult because of parallel model -IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear'] -IMPLEMENTED_INTERACTION_TYPE = ['nequip'] - -IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] -IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] - -SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] -SUPPORTING_ERROR_TYPES = [ - 'TotalEnergy', - 'Energy', - 'Force', - 'Stress', - 'Stress_GPa', - 'TotalLoss', -] - -IMPLEMENTED_MODEL = ['E3_equivariant_model'] - -# string input to real torch function -ACTIVATION = { - 'relu': torch.nn.functional.relu, - 'silu': torch.nn.functional.silu, - 'tanh': torch.tanh, - 'abs': torch.abs, - 'ssp': ShiftedSoftPlus, - 'sigmoid': torch.sigmoid, - 'elu': torch.nn.functional.elu, -} -ACTIVATION_FOR_EVEN = { - 'ssp': ShiftedSoftPlus, - 'silu': torch.nn.functional.silu, -} -ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs} -ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD} - -_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials') -SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth' -SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth' -SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth' -SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth' -SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth' -SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth' - -_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download' -CHECKPOINT_DOWNLOAD_LINKS = { - SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth', - SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth', -} -# to avoid torch script to compile torch_geometry.data -AtomGraphDataType = Dict[str, torch.Tensor] - - -class LossType(Enum): # only used for train_v1, do not use it afterwards - ENERGY = 'energy' # eV or eV/atom - FORCE = 'force' # eV/A - STRESS = 'stress' # kB - - -def error_record_condition(x): - if type(x) is not list: - return False - for v in x: - if type(v) is not list or len(v) != 2: - return False - if v[0] not in SUPPORTING_ERROR_TYPES: - return False - if v[0] == 'TotalLoss': - continue - if v[1] not in SUPPORTING_METRICS: - return False - return True - - -DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = { - KEY.CUTOFF: 4.5, - KEY.NODE_FEATURE_MULTIPLICITY: 32, - KEY.IRREPS_MANUAL: False, - KEY.LMAX: 1, - KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax - KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax - KEY.IS_PARITY: True, - KEY.NUM_CONVOLUTION: 3, - KEY.RADIAL_BASIS: { - KEY.RADIAL_BASIS_NAME: 'bessel', - }, - KEY.CUTOFF_FUNCTION: { - KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', - }, - KEY.ACTIVATION_RADIAL: 'silu', - KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'}, - KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'}, - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64], - # KEY.AVG_NUM_NEIGH: True, # deprecated - # KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated - KEY.CONV_DENOMINATOR: 'avg_num_neigh', - KEY.TRAIN_DENOMINTAOR: False, - KEY.TRAIN_SHIFT_SCALE: False, - # KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True - KEY.USE_BIAS_IN_LINEAR: False, - KEY.USE_MODAL_NODE_EMBEDDING: False, - KEY.USE_MODAL_SELF_INTER_INTRO: 
False, - KEY.USE_MODAL_SELF_INTER_OUTRO: False, - KEY.USE_MODAL_OUTPUT_BLOCK: False, - KEY.READOUT_AS_FCN: False, - # Applied af readout as fcn is True - KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30], - KEY.READOUT_FCN_ACTIVATION: 'relu', - KEY.SELF_CONNECTION_TYPE: 'nequip', - KEY.INTERACTION_TYPE: 'nequip', - KEY._NORMALIZE_SPH: True, - KEY.CUEQUIVARIANCE_CONFIG: {}, -} - - -# Basically, "If provided, it should be type of ..." -MODEL_CONFIG_CONDITION = { - KEY.NODE_FEATURE_MULTIPLICITY: int, - KEY.LMAX: int, - KEY.LMAX_EDGE: int, - KEY.LMAX_NODE: int, - KEY.IS_PARITY: bool, - KEY.RADIAL_BASIS: { - KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS, - }, - KEY.CUTOFF_FUNCTION: { - KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION, - }, - KEY.CUTOFF: float, - KEY.NUM_CONVOLUTION: int, - KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) - or x - in [ - 'avg_num_neigh', - 'sqrt_avg_num_neigh', - ], - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list, - KEY.TRAIN_SHIFT_SCALE: bool, - KEY.TRAIN_DENOMINTAOR: bool, - KEY.USE_BIAS_IN_LINEAR: bool, - KEY.USE_MODAL_NODE_EMBEDDING: bool, - KEY.USE_MODAL_SELF_INTER_INTRO: bool, - KEY.USE_MODAL_SELF_INTER_OUTRO: bool, - KEY.USE_MODAL_OUTPUT_BLOCK: bool, - KEY.READOUT_AS_FCN: bool, - KEY.READOUT_FCN_HIDDEN_NEURONS: list, - KEY.READOUT_FCN_ACTIVATION: str, - KEY.ACTIVATION_RADIAL: str, - KEY.SELF_CONNECTION_TYPE: lambda x: ( - x in IMPLEMENTED_SELF_CONNECTION_TYPE - or ( - isinstance(x, list) - and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x) - ) - ), - KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE, - KEY._NORMALIZE_SPH: bool, - KEY.CUEQUIVARIANCE_CONFIG: dict, -} - - -def model_defaults(config): - defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG - - if KEY.READOUT_AS_FCN not in config: - config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN] - if config[KEY.READOUT_AS_FCN] is False: - defaults.pop(KEY.READOUT_FCN_ACTIVATION, None) - defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None) - - return defaults - - -DEFAULT_DATA_CONFIG = { - KEY.DTYPE: 'single', - KEY.DATA_FORMAT: 'ase', - KEY.DATA_FORMAT_ARGS: {}, - KEY.SAVE_DATASET: False, - KEY.SAVE_BY_LABEL: False, - KEY.SAVE_BY_TRAIN_VALID: False, - KEY.RATIO: 0.0, - KEY.BATCH_SIZE: 6, - KEY.PREPROCESS_NUM_CORES: 1, - KEY.COMPUTE_STATISTICS: True, - KEY.DATASET_TYPE: 'graph', - # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False, - KEY.USE_MODAL_WISE_SHIFT: False, - KEY.USE_MODAL_WISE_SCALE: False, - KEY.SHIFT: 'per_atom_energy_mean', - KEY.SCALE: 'force_rms', - # KEY.DATA_SHUFFLE: True, - # KEY.DATA_WEIGHT: False, - # KEY.DATA_MODALITY: False, -} - -DATA_CONFIG_CONDITION = { - KEY.DTYPE: str, - KEY.DATA_FORMAT: str, - KEY.DATA_FORMAT_ARGS: dict, - KEY.SAVE_DATASET: str, - KEY.SAVE_BY_LABEL: bool, - KEY.SAVE_BY_TRAIN_VALID: bool, - KEY.RATIO: float, - KEY.BATCH_SIZE: int, - KEY.PREPROCESS_NUM_CORES: int, - KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'], - # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool, - KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT, - KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, - KEY.USE_MODAL_WISE_SHIFT: bool, - KEY.USE_MODAL_WISE_SCALE: bool, - # KEY.DATA_SHUFFLE: bool, - KEY.COMPUTE_STATISTICS: bool, - # KEY.DATA_WEIGHT: bool, - # KEY.DATA_MODALITY: bool, -} - - -def data_defaults(config): - defaults = DEFAULT_DATA_CONFIG - if KEY.LOAD_VALIDSET in config: - defaults.pop(KEY.RATIO, None) - return defaults - - -DEFAULT_TRAINING_CONFIG = { - KEY.RANDOM_SEED: 1, - KEY.EPOCH: 300, - 
KEY.LOSS: 'mse', - KEY.LOSS_PARAM: {}, - KEY.OPTIMIZER: 'adam', - KEY.OPTIM_PARAM: {}, - KEY.SCHEDULER: 'exponentiallr', - KEY.SCHEDULER_PARAM: {}, - KEY.FORCE_WEIGHT: 0.1, - KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default - KEY.PER_EPOCH: 5, - # KEY.USE_TESTSET: False, - KEY.CONTINUE: { - KEY.CHECKPOINT: False, - KEY.RESET_OPTIMIZER: False, - KEY.RESET_SCHEDULER: False, - KEY.RESET_EPOCH: False, - KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True, - KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True, - }, - # KEY.DEFAULT_MODAL: 'common', - KEY.CSV_LOG: 'log.csv', - KEY.NUM_WORKERS: 0, - KEY.IS_TRAIN_STRESS: True, - KEY.TRAIN_SHUFFLE: True, - KEY.ERROR_RECORD: [ - ['Energy', 'RMSE'], - ['Force', 'RMSE'], - ['Stress', 'RMSE'], - ['TotalLoss', 'None'], - ], - KEY.BEST_METRIC: 'TotalLoss', - KEY.USE_WEIGHT: False, - KEY.USE_MODALITY: False, -} - - -TRAINING_CONFIG_CONDITION = { - KEY.RANDOM_SEED: int, - KEY.EPOCH: int, - KEY.FORCE_WEIGHT: float, - KEY.STRESS_WEIGHT: float, - KEY.USE_TESTSET: None, # Not used - KEY.NUM_WORKERS: int, - KEY.PER_EPOCH: int, - KEY.CONTINUE: { - KEY.CHECKPOINT: str, - KEY.RESET_OPTIMIZER: bool, - KEY.RESET_SCHEDULER: bool, - KEY.RESET_EPOCH: bool, - KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool, - KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool, - }, - KEY.DEFAULT_MODAL: str, - KEY.IS_TRAIN_STRESS: bool, - KEY.TRAIN_SHUFFLE: bool, - KEY.ERROR_RECORD: error_record_condition, - KEY.BEST_METRIC: str, - KEY.CSV_LOG: str, - KEY.USE_MODALITY: bool, - KEY.USE_WEIGHT: bool, -} - - -def train_defaults(config): - defaults = DEFAULT_TRAINING_CONFIG - if KEY.IS_TRAIN_STRESS not in config: - config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] - if not config[KEY.IS_TRAIN_STRESS]: - defaults.pop(KEY.STRESS_WEIGHT, None) - return defaults +import os +from enum import Enum +from typing import Dict + +import torch + +import sevenn._keys as KEY +from sevenn.nn.activation import ShiftedSoftPlus + +NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118 + +IMPLEMENTED_RADIAL_BASIS = ['bessel'] +IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR'] +# TODO: support None. 
This became difficult because of parallel model +IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear'] +IMPLEMENTED_INTERACTION_TYPE = ['nequip'] + +IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] +IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] + +SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] +SUPPORTING_ERROR_TYPES = [ + 'TotalEnergy', + 'Energy', + 'Force', + 'Stress', + 'Stress_GPa', + 'TotalLoss', +] + +IMPLEMENTED_MODEL = ['E3_equivariant_model'] + +# string input to real torch function +ACTIVATION = { + 'relu': torch.nn.functional.relu, + 'silu': torch.nn.functional.silu, + 'tanh': torch.tanh, + 'abs': torch.abs, + 'ssp': ShiftedSoftPlus, + 'sigmoid': torch.sigmoid, + 'elu': torch.nn.functional.elu, +} +ACTIVATION_FOR_EVEN = { + 'ssp': ShiftedSoftPlus, + 'silu': torch.nn.functional.silu, +} +ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs} +ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD} + +_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials') +SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth' +SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth' +SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth' +SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth' +SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth' +SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth' + +_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download' +CHECKPOINT_DOWNLOAD_LINKS = { + SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth', + SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth', +} +# to avoid torch script to compile torch_geometry.data +AtomGraphDataType = Dict[str, torch.Tensor] + + +class LossType(Enum): # only used for train_v1, do not use it afterwards + ENERGY = 'energy' # eV or eV/atom + FORCE = 'force' # eV/A + STRESS = 'stress' # kB + + +def error_record_condition(x): + if type(x) is not list: + return False + for v in x: + if type(v) is not list or len(v) != 2: + return False + if v[0] not in SUPPORTING_ERROR_TYPES: + return False + if v[0] == 'TotalLoss': + continue + if v[1] not in SUPPORTING_METRICS: + return False + return True + + +DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = { + KEY.CUTOFF: 4.5, + KEY.NODE_FEATURE_MULTIPLICITY: 32, + KEY.IRREPS_MANUAL: False, + KEY.LMAX: 1, + KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax + KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax + KEY.IS_PARITY: True, + KEY.NUM_CONVOLUTION: 3, + KEY.RADIAL_BASIS: { + KEY.RADIAL_BASIS_NAME: 'bessel', + }, + KEY.CUTOFF_FUNCTION: { + KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', + }, + KEY.ACTIVATION_RADIAL: 'silu', + KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'}, + KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'}, + KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64], + # KEY.AVG_NUM_NEIGH: True, # deprecated + # KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated + KEY.CONV_DENOMINATOR: 'avg_num_neigh', + KEY.TRAIN_DENOMINTAOR: False, + KEY.TRAIN_SHIFT_SCALE: False, + # KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True + KEY.USE_BIAS_IN_LINEAR: False, + KEY.USE_MODAL_NODE_EMBEDDING: False, + KEY.USE_MODAL_SELF_INTER_INTRO: False, + KEY.USE_MODAL_SELF_INTER_OUTRO: False, + KEY.USE_MODAL_OUTPUT_BLOCK: False, + KEY.READOUT_AS_FCN: False, + # Applied af readout as fcn is True + 
KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30], + KEY.READOUT_FCN_ACTIVATION: 'relu', + KEY.SELF_CONNECTION_TYPE: 'nequip', + KEY.INTERACTION_TYPE: 'nequip', + KEY._NORMALIZE_SPH: True, + KEY.CUEQUIVARIANCE_CONFIG: {}, +} + + +# Basically, "If provided, it should be type of ..." +MODEL_CONFIG_CONDITION = { + KEY.NODE_FEATURE_MULTIPLICITY: int, + KEY.LMAX: int, + KEY.LMAX_EDGE: int, + KEY.LMAX_NODE: int, + KEY.IS_PARITY: bool, + KEY.RADIAL_BASIS: { + KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS, + }, + KEY.CUTOFF_FUNCTION: { + KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION, + }, + KEY.CUTOFF: float, + KEY.NUM_CONVOLUTION: int, + KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) + or x + in [ + 'avg_num_neigh', + 'sqrt_avg_num_neigh', + ], + KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list, + KEY.TRAIN_SHIFT_SCALE: bool, + KEY.TRAIN_DENOMINTAOR: bool, + KEY.USE_BIAS_IN_LINEAR: bool, + KEY.USE_MODAL_NODE_EMBEDDING: bool, + KEY.USE_MODAL_SELF_INTER_INTRO: bool, + KEY.USE_MODAL_SELF_INTER_OUTRO: bool, + KEY.USE_MODAL_OUTPUT_BLOCK: bool, + KEY.READOUT_AS_FCN: bool, + KEY.READOUT_FCN_HIDDEN_NEURONS: list, + KEY.READOUT_FCN_ACTIVATION: str, + KEY.ACTIVATION_RADIAL: str, + KEY.SELF_CONNECTION_TYPE: lambda x: ( + x in IMPLEMENTED_SELF_CONNECTION_TYPE + or ( + isinstance(x, list) + and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x) + ) + ), + KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE, + KEY._NORMALIZE_SPH: bool, + KEY.CUEQUIVARIANCE_CONFIG: dict, +} + + +def model_defaults(config): + defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG + + if KEY.READOUT_AS_FCN not in config: + config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN] + if config[KEY.READOUT_AS_FCN] is False: + defaults.pop(KEY.READOUT_FCN_ACTIVATION, None) + defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None) + + return defaults + + +DEFAULT_DATA_CONFIG = { + KEY.DTYPE: 'single', + KEY.DATA_FORMAT: 'ase', + KEY.DATA_FORMAT_ARGS: {}, + KEY.SAVE_DATASET: False, + KEY.SAVE_BY_LABEL: False, + KEY.SAVE_BY_TRAIN_VALID: False, + KEY.RATIO: 0.0, + KEY.BATCH_SIZE: 6, + KEY.PREPROCESS_NUM_CORES: 1, + KEY.COMPUTE_STATISTICS: True, + KEY.DATASET_TYPE: 'graph', + # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False, + KEY.USE_MODAL_WISE_SHIFT: False, + KEY.USE_MODAL_WISE_SCALE: False, + KEY.SHIFT: 'per_atom_energy_mean', + KEY.SCALE: 'force_rms', + # KEY.DATA_SHUFFLE: True, + # KEY.DATA_WEIGHT: False, + # KEY.DATA_MODALITY: False, +} + +DATA_CONFIG_CONDITION = { + KEY.DTYPE: str, + KEY.DATA_FORMAT: str, + KEY.DATA_FORMAT_ARGS: dict, + KEY.SAVE_DATASET: str, + KEY.SAVE_BY_LABEL: bool, + KEY.SAVE_BY_TRAIN_VALID: bool, + KEY.RATIO: float, + KEY.BATCH_SIZE: int, + KEY.PREPROCESS_NUM_CORES: int, + KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'], + # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool, + KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT, + KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, + KEY.USE_MODAL_WISE_SHIFT: bool, + KEY.USE_MODAL_WISE_SCALE: bool, + # KEY.DATA_SHUFFLE: bool, + KEY.COMPUTE_STATISTICS: bool, + # KEY.DATA_WEIGHT: bool, + # KEY.DATA_MODALITY: bool, +} + + +def data_defaults(config): + defaults = DEFAULT_DATA_CONFIG + if KEY.LOAD_VALIDSET in config: + defaults.pop(KEY.RATIO, None) + return defaults + + +DEFAULT_TRAINING_CONFIG = { + KEY.RANDOM_SEED: 1, + KEY.EPOCH: 300, + KEY.LOSS: 'mse', + KEY.LOSS_PARAM: {}, + KEY.OPTIMIZER: 'adam', + KEY.OPTIM_PARAM: {}, + KEY.SCHEDULER: 'exponentiallr', + KEY.SCHEDULER_PARAM: {}, + 
KEY.FORCE_WEIGHT: 0.1, + KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default + KEY.PER_EPOCH: 5, + # KEY.USE_TESTSET: False, + KEY.CONTINUE: { + KEY.CHECKPOINT: False, + KEY.RESET_OPTIMIZER: False, + KEY.RESET_SCHEDULER: False, + KEY.RESET_EPOCH: False, + KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True, + KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True, + }, + # KEY.DEFAULT_MODAL: 'common', + KEY.CSV_LOG: 'log.csv', + KEY.NUM_WORKERS: 0, + KEY.IS_TRAIN_STRESS: True, + KEY.TRAIN_SHUFFLE: True, + KEY.ERROR_RECORD: [ + ['Energy', 'RMSE'], + ['Force', 'RMSE'], + ['Stress', 'RMSE'], + ['TotalLoss', 'None'], + ], + KEY.BEST_METRIC: 'TotalLoss', + KEY.USE_WEIGHT: False, + KEY.USE_MODALITY: False, +} + + +TRAINING_CONFIG_CONDITION = { + KEY.RANDOM_SEED: int, + KEY.EPOCH: int, + KEY.FORCE_WEIGHT: float, + KEY.STRESS_WEIGHT: float, + KEY.USE_TESTSET: None, # Not used + KEY.NUM_WORKERS: int, + KEY.PER_EPOCH: int, + KEY.CONTINUE: { + KEY.CHECKPOINT: str, + KEY.RESET_OPTIMIZER: bool, + KEY.RESET_SCHEDULER: bool, + KEY.RESET_EPOCH: bool, + KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool, + KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool, + }, + KEY.DEFAULT_MODAL: str, + KEY.IS_TRAIN_STRESS: bool, + KEY.TRAIN_SHUFFLE: bool, + KEY.ERROR_RECORD: error_record_condition, + KEY.BEST_METRIC: str, + KEY.CSV_LOG: str, + KEY.USE_MODALITY: bool, + KEY.USE_WEIGHT: bool, +} + + +def train_defaults(config): + defaults = DEFAULT_TRAINING_CONFIG + if KEY.IS_TRAIN_STRESS not in config: + config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] + if not config[KEY.IS_TRAIN_STRESS]: + defaults.pop(KEY.STRESS_WEIGHT, None) + return defaults diff --git a/mace-bench/3rdparty/SevenNet/sevenn/_keys.py b/mace-bench/3rdparty/SevenNet/sevenn/_keys.py index f91b6de15a27fdd50bf975acad3f00510d5c6fe3..1ff5a614484c81d9418f3c84cd0f7b9144384b55 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/_keys.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/_keys.py @@ -1,226 +1,226 @@ -""" -How to add new feature? - -1. Add new key to this file. -2. Add new key to _const.py -2.1. if the type of input is consistent, - write adequate condition and default to _const.py. -2.2. if the type of input is not consistent, - you must add your own input validation code to - parse_input.py -""" - -from typing import Final - -# see -# https://github.com/pytorch/pytorch/issues/52312 -# for FYI - -# ~~ keys ~~ # -# PyG : primitive key of torch_geometric.data.Data type - -# ==================================================# -# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ # -# ==================================================# -# some raw properties of graph -ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N) -POS: Final[str] = 'pos' # (N, 3) PyG -CELL: Final[str] = 'cell_lattice_vectors' # (3, 3) -CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3) -CELL_VOLUME: Final[str] = 'cell_volume' - -EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3) -EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1) - -# some primary data of graph -EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG -ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes -NODE_FEATURE: Final[str] = 'x' # (N, ?) 
PyG -NODE_FEATURE_GHOST: Final[str] = 'x_ghost' -NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot -MODAL_ATTR: Final[str] = ( - 'modal_attr' # (1, N_modalities) for handling multi-modal -) -MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal -EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics) -EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding) - -# inputs of loss function -ENERGY: Final[str] = 'total_energy' # (1) -FORCE: Final[str] = 'force_of_atoms' # (N, 3) -STRESS: Final[str] = 'stress' # (6) - -# This is for training, per atom scale. -SCALED_ENERGY: Final[str] = 'scaled_total_energy' - -# general outputs of models -SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy' -ATOMIC_ENERGY: Final[str] = 'atomic_energy' -PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' - -PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' -PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' - -PRED_FORCE: Final[str] = 'inferred_force' -SCALED_FORCE: Final[str] = 'scaled_force' - -PRED_STRESS: Final[str] = 'inferred_stress' -SCALED_STRESS: Final[str] = 'scaled_stress' - -# very general data property for AtomGraphData -NUM_ATOMS: Final[str] = 'num_atoms' # int -NUM_GHOSTS: Final[str] = 'num_ghosts' -NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu -USER_LABEL: Final[str] = 'user_label' -DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data -DATA_MODALITY: Final[str] = ( - 'data_modality' # modality of given data. e.g. PBE and SCAN -) -BATCH: Final[str] = 'batch' - -TAG = 'tag' # replace USER_LABEL - -# etc -SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp' -BATCH_SIZE: Final[str] = 'batch_size' -INFO: Final[str] = 'data_info' - -# something special -LABEL_NONE: Final[str] = 'No_label' - -# ==================================================# -# ~~~~~~ KEY for train/data configuration ~~~~~~~~ # -# ==================================================# -PREPROCESS_NUM_CORES = 'preprocess_num_cores' -SAVE_DATASET = 'save_dataset_path' -SAVE_BY_LABEL = 'save_by_label' -SAVE_BY_TRAIN_VALID = 'save_by_train_valid' -DATA_FORMAT = 'data_format' -DATA_FORMAT_ARGS = 'data_format_args' -STRUCTURE_LIST = 'structure_list' -LOAD_DATASET = 'load_dataset_path' # not used in v2 -LOAD_TRAINSET = 'load_trainset_path' -LOAD_VALIDSET = 'load_validset_path' -LOAD_TESTSET = 'load_testset_path' -FORMAT_OUTPUTS = 'format_outputs_for_ase' -COMPUTE_STATISTICS = 'compute_statistics' -DATASET_TYPE = 'dataset_type' - -RANDOM_SEED = 'random_seed' -RATIO = 'data_divide_ratio' -USE_TESTSET = 'use_testset' -EPOCH = 'epoch' -LOSS = 'loss' -LOSS_PARAM = 'loss_param' -OPTIMIZER = 'optimizer' -OPTIM_PARAM = 'optim_param' -SCHEDULER = 'scheduler' -SCHEDULER_PARAM = 'scheduler_param' -FORCE_WEIGHT = 'force_loss_weight' -STRESS_WEIGHT = 'stress_loss_weight' -DEVICE = 'device' -DTYPE = 'dtype' - -TRAIN_SHUFFLE = 'train_shuffle' - -IS_TRAIN_STRESS = 'is_train_stress' - -CONTINUE = 'continue' -CHECKPOINT = 'checkpoint' -RESET_OPTIMIZER = 'reset_optimizer' -RESET_SCHEDULER = 'reset_scheduler' -RESET_EPOCH = 'reset_epoch' -USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint' -USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( - 'use_statistic_values_for_cp_modal_only' -) - -CSV_LOG = 'csv_log' - -ERROR_RECORD = 'error_record' -BEST_METRIC = 'best_metric' - -NUM_WORKERS = 'num_workers' # not work - -RANK = 'rank' -LOCAL_RANK = 'local_rank' -WORLD_SIZE = 'world_size' -IS_DDP = 'is_ddp' -DDP_BACKEND = 'ddp_backend' -PER_EPOCH = 
'per_epoch' - -USE_WEIGHT = 'use_weight' -USE_MODALITY = 'use_modality' -DEFAULT_MODAL = 'default_modal' - - -# ==================================================# -# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ # -# ==================================================# -# ~~ global model configuration ~~ # -# note that these names are directly used for input.yaml for user input -MODEL_TYPE = '_model_type' -CUTOFF = 'cutoff' -CHEMICAL_SPECIES = 'chemical_species' -MODAL_LIST = 'modal_list' -CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number' -NUM_SPECIES = '_number_of_species' -NUM_MODALITIES = '_number_of_modalities' -TYPE_MAP = '_type_map' -MODAL_MAP = '_modal_map' - -# ~~ E3 equivariant model build configuration keys ~~ # -# see model_build default_config for type -IRREPS_MANUAL = 'irreps_manual' -NODE_FEATURE_MULTIPLICITY = 'channel' - -RADIAL_BASIS = 'radial_basis' -BESSEL_BASIS_NUM = 'bessel_basis_num' - -CUTOFF_FUNCTION = 'cutoff_function' -POLY_CUT_P = 'poly_cut_p_value' - -LMAX = 'lmax' -LMAX_EDGE = 'lmax_edge' -LMAX_NODE = 'lmax_node' -IS_PARITY = 'is_parity' -CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons' -NUM_CONVOLUTION = 'num_convolution_layer' -ACTIVATION_SCARLAR = 'act_scalar' -ACTIVATION_GATE = 'act_gate' -ACTIVATION_RADIAL = 'act_radial' - -SELF_CONNECTION_TYPE = 'self_connection_type' - -RADIAL_BASIS_NAME = 'radial_basis_name' -CUTOFF_FUNCTION_NAME = 'cutoff_function_name' - -USE_BIAS_IN_LINEAR = 'use_bias_in_linear' - -USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding' -USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro' -USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro' -USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block' - -READOUT_AS_FCN = 'readout_as_fcn' -READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons' -READOUT_FCN_ACTIVATION = 'readout_fcn_activation' - -AVG_NUM_NEIGH = 'avg_num_neigh' -CONV_DENOMINATOR = 'conv_denominator' -SHIFT = 'shift' -SCALE = 'scale' - -USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale' -USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift' -USE_MODAL_WISE_SCALE = 'use_modal_wise_scale' - -TRAIN_SHIFT_SCALE = 'train_shift_scale' -TRAIN_DENOMINTAOR = 'train_denominator' -INTERACTION_TYPE = 'interaction_type' -TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated - -CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' - -_NORMALIZE_SPH = '_normalize_sph' -OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' +""" +How to add new feature? + +1. Add new key to this file. +2. Add new key to _const.py +2.1. if the type of input is consistent, + write adequate condition and default to _const.py. +2.2. 
if the type of input is not consistent, + you must add your own input validation code to + parse_input.py +""" + +from typing import Final + +# see +# https://github.com/pytorch/pytorch/issues/52312 +# for FYI + +# ~~ keys ~~ # +# PyG : primitive key of torch_geometric.data.Data type + +# ==================================================# +# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ # +# ==================================================# +# some raw properties of graph +ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N) +POS: Final[str] = 'pos' # (N, 3) PyG +CELL: Final[str] = 'cell_lattice_vectors' # (3, 3) +CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3) +CELL_VOLUME: Final[str] = 'cell_volume' + +EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3) +EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1) + +# some primary data of graph +EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG +ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes +NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG +NODE_FEATURE_GHOST: Final[str] = 'x_ghost' +NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot +MODAL_ATTR: Final[str] = ( + 'modal_attr' # (1, N_modalities) for handling multi-modal +) +MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal +EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics) +EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding) + +# inputs of loss function +ENERGY: Final[str] = 'total_energy' # (1) +FORCE: Final[str] = 'force_of_atoms' # (N, 3) +STRESS: Final[str] = 'stress' # (6) + +# This is for training, per atom scale. +SCALED_ENERGY: Final[str] = 'scaled_total_energy' + +# general outputs of models +SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy' +ATOMIC_ENERGY: Final[str] = 'atomic_energy' +PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' + +PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' +PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' + +PRED_FORCE: Final[str] = 'inferred_force' +SCALED_FORCE: Final[str] = 'scaled_force' + +PRED_STRESS: Final[str] = 'inferred_stress' +SCALED_STRESS: Final[str] = 'scaled_stress' + +# very general data property for AtomGraphData +NUM_ATOMS: Final[str] = 'num_atoms' # int +NUM_GHOSTS: Final[str] = 'num_ghosts' +NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu +USER_LABEL: Final[str] = 'user_label' +DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data +DATA_MODALITY: Final[str] = ( + 'data_modality' # modality of given data. e.g. 
PBE and SCAN +) +BATCH: Final[str] = 'batch' + +TAG = 'tag' # replace USER_LABEL + +# etc +SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp' +BATCH_SIZE: Final[str] = 'batch_size' +INFO: Final[str] = 'data_info' + +# something special +LABEL_NONE: Final[str] = 'No_label' + +# ==================================================# +# ~~~~~~ KEY for train/data configuration ~~~~~~~~ # +# ==================================================# +PREPROCESS_NUM_CORES = 'preprocess_num_cores' +SAVE_DATASET = 'save_dataset_path' +SAVE_BY_LABEL = 'save_by_label' +SAVE_BY_TRAIN_VALID = 'save_by_train_valid' +DATA_FORMAT = 'data_format' +DATA_FORMAT_ARGS = 'data_format_args' +STRUCTURE_LIST = 'structure_list' +LOAD_DATASET = 'load_dataset_path' # not used in v2 +LOAD_TRAINSET = 'load_trainset_path' +LOAD_VALIDSET = 'load_validset_path' +LOAD_TESTSET = 'load_testset_path' +FORMAT_OUTPUTS = 'format_outputs_for_ase' +COMPUTE_STATISTICS = 'compute_statistics' +DATASET_TYPE = 'dataset_type' + +RANDOM_SEED = 'random_seed' +RATIO = 'data_divide_ratio' +USE_TESTSET = 'use_testset' +EPOCH = 'epoch' +LOSS = 'loss' +LOSS_PARAM = 'loss_param' +OPTIMIZER = 'optimizer' +OPTIM_PARAM = 'optim_param' +SCHEDULER = 'scheduler' +SCHEDULER_PARAM = 'scheduler_param' +FORCE_WEIGHT = 'force_loss_weight' +STRESS_WEIGHT = 'stress_loss_weight' +DEVICE = 'device' +DTYPE = 'dtype' + +TRAIN_SHUFFLE = 'train_shuffle' + +IS_TRAIN_STRESS = 'is_train_stress' + +CONTINUE = 'continue' +CHECKPOINT = 'checkpoint' +RESET_OPTIMIZER = 'reset_optimizer' +RESET_SCHEDULER = 'reset_scheduler' +RESET_EPOCH = 'reset_epoch' +USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint' +USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( + 'use_statistic_values_for_cp_modal_only' +) + +CSV_LOG = 'csv_log' + +ERROR_RECORD = 'error_record' +BEST_METRIC = 'best_metric' + +NUM_WORKERS = 'num_workers' # not work + +RANK = 'rank' +LOCAL_RANK = 'local_rank' +WORLD_SIZE = 'world_size' +IS_DDP = 'is_ddp' +DDP_BACKEND = 'ddp_backend' +PER_EPOCH = 'per_epoch' + +USE_WEIGHT = 'use_weight' +USE_MODALITY = 'use_modality' +DEFAULT_MODAL = 'default_modal' + + +# ==================================================# +# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ # +# ==================================================# +# ~~ global model configuration ~~ # +# note that these names are directly used for input.yaml for user input +MODEL_TYPE = '_model_type' +CUTOFF = 'cutoff' +CHEMICAL_SPECIES = 'chemical_species' +MODAL_LIST = 'modal_list' +CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number' +NUM_SPECIES = '_number_of_species' +NUM_MODALITIES = '_number_of_modalities' +TYPE_MAP = '_type_map' +MODAL_MAP = '_modal_map' + +# ~~ E3 equivariant model build configuration keys ~~ # +# see model_build default_config for type +IRREPS_MANUAL = 'irreps_manual' +NODE_FEATURE_MULTIPLICITY = 'channel' + +RADIAL_BASIS = 'radial_basis' +BESSEL_BASIS_NUM = 'bessel_basis_num' + +CUTOFF_FUNCTION = 'cutoff_function' +POLY_CUT_P = 'poly_cut_p_value' + +LMAX = 'lmax' +LMAX_EDGE = 'lmax_edge' +LMAX_NODE = 'lmax_node' +IS_PARITY = 'is_parity' +CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons' +NUM_CONVOLUTION = 'num_convolution_layer' +ACTIVATION_SCARLAR = 'act_scalar' +ACTIVATION_GATE = 'act_gate' +ACTIVATION_RADIAL = 'act_radial' + +SELF_CONNECTION_TYPE = 'self_connection_type' + +RADIAL_BASIS_NAME = 'radial_basis_name' +CUTOFF_FUNCTION_NAME = 'cutoff_function_name' + +USE_BIAS_IN_LINEAR = 'use_bias_in_linear' + +USE_MODAL_NODE_EMBEDDING = 
'use_modal_node_embedding' +USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro' +USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro' +USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block' + +READOUT_AS_FCN = 'readout_as_fcn' +READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons' +READOUT_FCN_ACTIVATION = 'readout_fcn_activation' + +AVG_NUM_NEIGH = 'avg_num_neigh' +CONV_DENOMINATOR = 'conv_denominator' +SHIFT = 'shift' +SCALE = 'scale' + +USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale' +USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift' +USE_MODAL_WISE_SCALE = 'use_modal_wise_scale' + +TRAIN_SHIFT_SCALE = 'train_shift_scale' +TRAIN_DENOMINTAOR = 'train_denominator' +INTERACTION_TYPE = 'interaction_type' +TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated + +CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' + +_NORMALIZE_SPH = '_normalize_sph' +OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' diff --git a/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py b/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py index ee5b7bb02c22781e346e464eab04f245f995cf73..e0de629a5e6e83352ba266d10099f2324658e35f 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py @@ -1,75 +1,75 @@ -from typing import Optional - -import torch -import torch_geometric.data - -import sevenn._keys as KEY -import sevenn.util - - -class AtomGraphData(torch_geometric.data.Data): - """ - Args: - x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes, - atomic_numbers]`. (default: :obj:`None`) - edge_index (LongTensor, optional): Graph connectivity in coordinate - format with shape :obj:`[2, num_edges]`. (default: :obj:`None`) - edge_attr (Tensor, optional): Edge feature matrix with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - y_energy: scalar # unit of eV (VASP raw) - y_force: [num_nodes, 3] # unit of eV/A (VASP raw) - y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw) - pos (Tensor, optional): Node position matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - **kwargs (optional): Additional attributes. - - x, y_force, pos should be aligned with each other. - """ - - def __init__( - self, - x: Optional[torch.Tensor] = None, - edge_index: Optional[torch.Tensor] = None, - pos: Optional[torch.Tensor] = None, - edge_attr: Optional[torch.Tensor] = None, - **kwargs - ): - super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos) - self[KEY.NODE_ATTR] = x # ? - for k, v in kwargs.items(): - self[k] = v - - def to_numpy_dict(self): - # This is not debugged yet! 
- dct = { - k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v - for k, v in self.items() - } - return dct - - def fit_dimension(self): - per_atom_keys = [ - KEY.ATOMIC_NUMBERS, - KEY.ATOMIC_ENERGY, - KEY.POS, - KEY.FORCE, - KEY.PRED_FORCE, - ] - natoms = self.num_atoms.item() - for k, v in self.items(): - if not isinstance(v, torch.Tensor): - continue - if natoms == 1 and k in per_atom_keys: - self[k] = v.squeeze().unsqueeze(0) - else: - self[k] = v.squeeze() - return self - - @staticmethod - def from_numpy_dict(dct): - for k, v in dct.items(): - if k == KEY.CELL_SHIFT: - dct[k] = torch.Tensor(v) # this is special - else: - dct[k] = sevenn.util.dtype_correct(v) - return AtomGraphData(**dct) +from typing import Optional + +import torch +import torch_geometric.data + +import sevenn._keys as KEY +import sevenn.util + + +class AtomGraphData(torch_geometric.data.Data): + """ + Args: + x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes, + atomic_numbers]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in coordinate + format with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y_energy: scalar # unit of eV (VASP raw) + y_force: [num_nodes, 3] # unit of eV/A (VASP raw) + y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. + + x, y_force, pos should be aligned with each other. + """ + + def __init__( + self, + x: Optional[torch.Tensor] = None, + edge_index: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + edge_attr: Optional[torch.Tensor] = None, + **kwargs + ): + super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos) + self[KEY.NODE_ATTR] = x # ? + for k, v in kwargs.items(): + self[k] = v + + def to_numpy_dict(self): + # This is not debugged yet! 
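+        # Descriptive note: each torch.Tensor value is detached from the
+        # autograd graph and moved to CPU before the numpy conversion;
+        # non-tensor values are passed through unchanged.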
+ dct = { + k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v + for k, v in self.items() + } + return dct + + def fit_dimension(self): + per_atom_keys = [ + KEY.ATOMIC_NUMBERS, + KEY.ATOMIC_ENERGY, + KEY.POS, + KEY.FORCE, + KEY.PRED_FORCE, + ] + natoms = self.num_atoms.item() + for k, v in self.items(): + if not isinstance(v, torch.Tensor): + continue + if natoms == 1 and k in per_atom_keys: + self[k] = v.squeeze().unsqueeze(0) + else: + self[k] = v.squeeze() + return self + + @staticmethod + def from_numpy_dict(dct): + for k, v in dct.items(): + if k == KEY.CELL_SHIFT: + dct[k] = torch.Tensor(v) # this is special + else: + dct[k] = sevenn.util.dtype_correct(v) + return AtomGraphData(**dct) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/calculator.py b/mace-bench/3rdparty/SevenNet/sevenn/calculator.py index 7a7f386e98850efad4009fb192414dc3c5adf302..e237886bd9bac00bab902219c8bba16f6c5991b4 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/calculator.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/calculator.py @@ -1,846 +1,846 @@ -import ctypes -import os -import pathlib -import warnings -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch -import torch.jit -import torch.jit._script -from ase.calculators.calculator import Calculator, all_changes -from ase.calculators.mixing import SumCalculator -from ase.data import chemical_symbols - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.atom_graph_data import AtomGraphData -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.train.dataload import unlabeled_atoms_to_graph -import logging - -torch_script_type = torch.jit._script.RecursiveScriptModule - - -class SevenNetCalculator(Calculator): - """Supporting properties: - 'free_energy', 'energy', 'forces', 'stress', 'energies' - free_energy equals energy. 'energies' stores atomic energy. - - Multi-GPU acceleration is not supported with ASE calculator. - You should use LAMMPS for the acceleration. - """ - - def __init__( - self, - model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', - file_type: str = 'checkpoint', - device: Union[torch.device, str] = 'auto', - modal: Optional[str] = None, - enable_cueq: bool = False, - sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info - **kwargs, - ): - """Initialize SevenNetCalculator. - - Parameters - ---------- - model: str | Path | AtomGraphSequential, default='7net-0' - Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or - path to the checkpoint, deployed model or the model itself - file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' - device: str | torch.device, default='auto' - if not given, use CUDA if available - modal: str | None, default=None - modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, - it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) - case insensitive - enable_cueq: bool, default=False - if True, use cuEquivariant to accelerate inference. 
- sevennet_config: dict | None, default=None - Not used, but can be used to carry meta information of this calculator - """ - print("&&& Initializing SevenNetCalculator") - super().__init__(**kwargs) - self.sevennet_config = None - - if isinstance(model, pathlib.PurePath): - model = str(model) - - allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] - file_type = file_type.lower() - if file_type not in allowed_file_types: - raise ValueError(f'file_type not in {allowed_file_types}') - - if enable_cueq and file_type in ['model_instance', 'torchscript']: - warnings.warn( - 'file_type should be checkpoint to enable cueq. cueq set to False' - ) - enable_cueq = False - - if isinstance(device, str): # TODO: do we really need this? - if device == 'auto': - self.device = torch.device( - 'cuda' if torch.cuda.is_available() else 'cpu' - ) - else: - self.device = torch.device(device) - else: - self.device = device - - if file_type == 'checkpoint' and isinstance(model, str): - cp = util.load_checkpoint(model) - - backend = 'e3nn' if not enable_cueq else 'cueq' - model_loaded = cp.build_model(backend) - model_loaded.set_is_batch_data(False) - - self.type_map = cp.config[KEY.TYPE_MAP] - self.cutoff = cp.config[KEY.CUTOFF] - self.sevennet_config = cp.config - - elif file_type == 'torchscript' and isinstance(model, str): - if modal: - raise NotImplementedError() - extra_dict = { - 'chemical_symbols_to_index': b'', - 'cutoff': b'', - 'num_species': b'', - 'model_type': b'', - 'version': b'', - 'dtype': b'', - 'time': b'', - } - model_loaded = torch.jit.load( - model, _extra_files=extra_dict, map_location=self.device - ) - chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8') - sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)} - self.type_map = { - sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) - } - self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) - - elif isinstance(model, AtomGraphSequential): - if model.type_map is None: - raise ValueError( - 'Model must have the type_map to be used with calculator' - ) - if model.cutoff == 0.0: - raise ValueError('Model cutoff seems not initialized') - model.eval_type_map = torch.tensor(True) # ? 
- model.set_is_batch_data(False) - model_loaded = model - self.type_map = model.type_map - self.cutoff = model.cutoff - else: - raise ValueError('Unexpected input combinations') - - if self.sevennet_config is None and sevennet_config is not None: - self.sevennet_config = sevennet_config - - self.model = model_loaded - - self.modal = None - if isinstance(self.model, AtomGraphSequential): - modal_map = self.model.modal_map - if modal_map: - modal_ava = list(modal_map.keys()) - if not modal: - raise ValueError(f'modal argument missing (avail: {modal_ava})') - elif modal not in modal_ava: - raise ValueError(f'unknown modal {modal} (not in {modal_ava})') - self.modal = modal - elif not self.model.modal_map and modal: - warnings.warn(f'modal={modal} is ignored as model has no modal_map') - - self.model.to(self.device) - self.model.eval() - self.implemented_properties = [ - 'free_energy', - 'energy', - 'forces', - 'stress', - 'energies', - ] - - def set_atoms(self, atoms): - # called by ase, when atoms.calc = calc - zs = tuple(set(atoms.get_atomic_numbers())) - for z in zs: - if z not in self.type_map: - sp = list(self.type_map.keys()) - raise ValueError( - f'Model do not know atomic number: {z}, (knows: {sp})' - ) - - def output_to_results(self, output): - energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item() - num_atoms = output['num_atoms'].item() - atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten() - forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :] - stress = np.array( - (-output[KEY.PRED_STRESS]) - .detach() - .cpu() - .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation - ) - # Store results - return { - 'free_energy': energy, - 'energy': energy, - 'energies': atomic_energies, - 'forces': forces, - 'stress': stress, - 'num_edges': output[KEY.EDGE_IDX].shape[1], - } - - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - # call parent class to set necessary atom attributes - Calculator.calculate(self, atoms, properties, system_changes) - if atoms is None: - raise ValueError('No atoms to evaluate') - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - data.to(self.device) # type: ignore - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - data = data.to_dict() - del data['data_info'] - - import logging - logging.debug(f"data: {data}") - # logging.debug(f"data[pos]: {data['pos']}") - # logging.debug(f"data[x]: {data['x']}") - logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}") - logging.debug(f"data[cell_volume]: {data['cell_volume']}") - output = self.model(data) - # logging.info(f"input: {data}") - # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") - # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") - # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") - self.results = self.output_to_results(output) - # logging.debug(f"results['energy'] = {self.results['energy']}") - # logging.debug(f"results['forces'] = {self.results['forces']}") - # logging.debug(f"results['stress'] = {self.results['stress']}") - - def predict_one(self, atoms): - if atoms is None: 
- raise ValueError('No atoms to evaluate') - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - data.to(self.device) # type: ignore - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - data = data.to_dict() - del data['data_info'] - - return self.model(data) - - - - def predict(self, atoms_list, properties=None): - - # if len(atoms_list) == 1: - # output = self.predict_one(atoms_list[0]) - # predictions = {} - # predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0) - # predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0) - # voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0) - # stress_list = [] - # for i in range(voigt.shape[0]): - # stress_list.append(self._stress2tensor(voigt[i,:])) - # predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3) - # return predictions - - - if not atoms_list: - raise ValueError("Empty atoms_list provided") - - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - # Convert atoms to graph data - graph_list = [] - for atoms in atoms_list: - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - - graph_list.append(data) - - # Process graphs based on model type - # was_batch_mode = True - if isinstance(self.model, AtomGraphSequential): - # was_batch_mode = self.model.is_batch_data - self.model.set_is_batch_data(True) - self.model.eval() - - # Batch the data if there are multiple atoms - from torch_geometric.loader.dataloader import Collater - batched_data = Collater(graph_list)(graph_list) - batched_data = batched_data.to(self.device) - - import logging - logging.debug(f"batched_data: {batched_data}") - # logging.debug(f"batched_data[pos]: {batched_data['pos']}") - # logging.debug(f"batched_data[x]: {batched_data['x']}") - logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}") - logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}") - # Run model on batched data - if isinstance(self.model, torch_script_type): - batched_dict = batched_data.to_dict() - if 'data_info' in batched_dict: - del batched_dict['data_info'] - output = self.model(batched_dict) - else: - output = self.model(batched_data) - - # Convert to list of individual outputs using util.to_atom_graph_list - # logging.info(f"input: {batched_data}") - # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") - # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") - # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") - - predictions = {} - predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach() - predictions['forces'] = 
output[KEY.PRED_FORCE].to(torch.float64).detach() - voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach() - stress_list = [] - for i in range(voigt.shape[0]): - stress_list.append(self._stress2tensor(voigt[i,:])) - predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3).detach() - - # logging.debug(f"predictions['energy'] = {predictions['energy']}") - # logging.debug(f"predictions['forces'] = {predictions['forces']}") - # logging.debug(f"predictions['stress'] = {predictions['stress']}") - return predictions - - def _stress2tensor(self, stress): - tensor = torch.tensor( - [ - [stress[0], stress[5], stress[4]], - [stress[5], stress[1], stress[3]], - [stress[4], stress[3], stress[2]], - ], - device=self.device - ) - return tensor - - -class SevenNetD3Calculator(SumCalculator): - def __init__( - self, - model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', - file_type: str = 'checkpoint', - device: Union[torch.device, str] = 'auto', - sevennet_config: Optional[Any] = None, # hold meta information - damping_type: str = 'damp_bj', - functional_name: str = 'pbe', - vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au - cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au - batch_size=10, - **kwargs, - ): - """Initialize SevenNetD3Calculator. CUDA required. - - Parameters - ---------- - model: str | Path | AtomGraphSequential - Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or - path to the checkpoint, deployed model or the model itself - file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' - device: str | torch.device, default='auto' - if not given, use CUDA if available - modal: str | None, default=None - modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, - it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) - enable_cueq: bool, default=False - if True, use cuEquivariant to accelerate inference. - damping_type: str, default='damp_bj' - Damping type of D3, one of 'damp_bj' | 'damp_zero' - functional_name: str, default='pbe' - Target functional name of D3 parameters. - vdw_cutoff: float, default=9000 - vdw cutoff of D3 calculator in au - cn_cutoff: float, default=1600 - cn cutoff of D3 calculator in au - """ - self.d3_calc = D3Calculator( - damping_type=damping_type, - functional_name=functional_name, - vdw_cutoff=vdw_cutoff, - cn_cutoff=cn_cutoff, - **kwargs, - ) - - self.sevennet_calc = SevenNetCalculator( - model=model, - file_type=file_type, - device=device, - sevennet_config=sevennet_config, - **kwargs, - ) - - super().__init__([self.sevennet_calc, self.d3_calc]) - - self.device = device - self.d3_calcs = [] - for _ in range(batch_size): - self.d3_calcs.append( - D3Calculator( - damping_type=damping_type, - functional_name=functional_name, - vdw_cutoff=vdw_cutoff, - cn_cutoff=cn_cutoff, - **kwargs, - ) - ) - - - def predict(self, atoms_list): - """Predict the energy and forces for a list of atoms. 
- """ - # Call the predict method of the first calculator (SevenNetCalculator) - predictions = self.sevennet_calc.predict(atoms_list) - - energy_list = [] - forces_list = [] - stress_list = [] - predictions3d = {} - for i, atoms in enumerate(atoms_list): - prediction = self.d3_calcs[i].predict_one(atoms) - energy_list.append(torch.tensor(prediction['energy'])) - forces_list.append(torch.from_numpy(prediction['forces']).to(self.device)) - stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress']))) - - # Convert lists to tensors - predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device) - predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3) - predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3) - - predictions['energy'] += predictions3d['energy'].detach() - predictions['forces'] += predictions3d['forces'].detach() - predictions['stress'] += predictions3d['stress'].detach() - - return predictions - - def _stress2tensor(self, stress): - tensor = torch.tensor( - [ - # [stress[0], stress[3], stress[4]], - # [stress[3], stress[1], stress[5]], - # [stress[4], stress[5], stress[2]], - [stress[0], stress[5], stress[4]], - [stress[5], stress[1], stress[3]], - [stress[4], stress[3], stress[2]], - ], - device=self.device - ) - return tensor - - - -def _load(name: str) -> ctypes.CDLL: - from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load - - # Load the library from the candidate locations - - package_dir = os.path.dirname(os.path.abspath(__file__)) - try: - return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}')) - except OSError: - pass - - cache_dir = _get_build_directory(name, verbose=False) - try: - return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}')) - except OSError: - pass - - # Compile the library if it is not found - - if os.access(package_dir, os.W_OK): - compile_dir = package_dir - else: - print('Warning: package directory is not writable. Using cache directory.') - compile_dir = cache_dir - - if 'TORCH_CUDA_ARCH_LIST' not in os.environ: - print('Warning: TORCH_CUDA_ARCH_LIST is not set.') - print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90') - os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0' - - load( - name=name, - sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')], - extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'], - build_directory=compile_dir, - verbose=True, - is_python_module=False, - ) - - return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}')) - - -class PairD3(ctypes.Structure): - pass # Opaque structure; only used as a pointer - - -class D3Calculator(Calculator): - """ASE calculator for accelerated D3 van der Waals (vdW) correction. - - Example: - from ase.calculators.mixing import SumCalculator - calc_1 = SevenNetCalculator() - calc_2 = D3Calculator() - return SumCalculator([calc_1, calc_2]) - - This calculator interfaces with the `libpaird3.so` library, - which is compiled by nvcc during the package installation. - If you encounter any errors, please verify - the installation process and the compilation options in `setup.py`. - Note: Multi-GPU parallel MD is not supported in this mode. - Note: Cffi could be used, but it was avoided to reduce dependencies. 
- """ - - # Here, free_energy = energy - implemented_properties = ['free_energy', 'energy', 'forces', 'stress'] - - def __init__( - self, - damping_type: str = 'damp_bj', # damp_bj, damp_zero - functional_name: str = 'pbe', # check the source code - vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au - cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au - **kwargs, - ): - super().__init__(**kwargs) - - if not torch.cuda.is_available(): - raise NotImplementedError('CPU + D3 is not implemented yet') - - self.rthr = vdw_cutoff - self.cnthr = cn_cutoff - self.damp_name = damping_type.lower() - self.func_name = functional_name.lower() - - if self.damp_name not in ['damp_bj', 'damp_zero']: - raise ValueError('Error: Invalid damping type.') - - self._lib = _load('pair_d3') - - self._lib.pair_init.restype = ctypes.POINTER(PairD3) - self.pair = self._lib.pair_init() - - self._lib.pair_set_atom.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_int, # int natoms - ctypes.c_int, # int ntypes - ctypes.POINTER(ctypes.c_int), # int* types - ctypes.POINTER(ctypes.c_double), # double* x - ] - self._lib.pair_set_atom.restype = None - - self._lib.pair_set_domain.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_int, # int xperiodic - ctypes.c_int, # int yperiodic - ctypes.c_int, # int zperiodic - ctypes.POINTER(ctypes.c_double), # double* boxlo - ctypes.POINTER(ctypes.c_double), # double* boxhi - ctypes.c_double, # double xy - ctypes.c_double, # double xz - ctypes.c_double, # double yz - ] - self._lib.pair_set_domain.restype = None - - self._lib.pair_run_settings.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_double, # double rthr - ctypes.c_double, # double cnthr - ctypes.c_char_p, # const char* damp_name - ctypes.c_char_p, # const char* func_name - ] - self._lib.pair_run_settings.restype = None - - self._lib.pair_run_coeff.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.POINTER(ctypes.c_int), # int* atomic_numbers - ] - self._lib.pair_run_coeff.restype = None - - self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_run_compute.restype = None - - self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_energy.restype = ctypes.c_double - - self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double) - - self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6) - - self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_fin.restype = None - - def _idx_to_numbers(self, Z_of_atoms): - unique_numbers = list(dict.fromkeys(Z_of_atoms)) - return unique_numbers - - def _idx_to_types(self, Z_of_atoms): - unique_numbers = list(dict.fromkeys(Z_of_atoms)) - mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)} - atom_types = [mapping[num] for num in Z_of_atoms] - return atom_types - - def _convert_domain_ase2lammps(self, cell): - qtrans, ltrans = np.linalg.qr(cell.T, mode='complete') - lammps_cell = ltrans.T - signs = np.sign(np.diag(lammps_cell)) - lammps_cell = lammps_cell * signs - qtrans = qtrans * signs - lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)] - rotator = qtrans.T - return lammps_cell, rotator - - def _stress2tensor(self, stress): - tensor = np.array( - [ - [stress[0], stress[3], stress[4]], - [stress[3], stress[1], stress[5]], - [stress[4], stress[5], stress[2]], - ] - ) - return 
tensor - - def _tensor2stress(self, tensor): - stress = -np.array( - [ - tensor[0, 0], - tensor[1, 1], - tensor[2, 2], - tensor[1, 2], - tensor[0, 2], - tensor[0, 1], - ] - ) - return stress - - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - Calculator.calculate(self, atoms, properties, system_changes) - if atoms is None: - raise ValueError('No atoms to evaluate') - - if atoms.get_cell().sum() == 0: - print( - 'Warning: D3Calculator requires a cell.\n' - 'Warning: An orthogonal cell large enough is generated.' - ) - positions = atoms.get_positions() - min_pos = positions.min(axis=0) - max_pos = positions.max(axis=0) - max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 - - cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin - cell = np.eye(3) * cell_lengths - - atoms.set_cell(cell) - atoms.set_pbc([True, True, True]) # for minus positions - - cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) - - Z_of_atoms = atoms.get_atomic_numbers() - natoms = len(atoms) - ntypes = len(set(Z_of_atoms)) - types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) - - positions = atoms.get_positions() @ rotator.T - x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) - - atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) - - boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) - boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) - xy = cell[3] - xz = cell[4] - yz = cell[5] - xperiodic, yperiodic, zperiodic = atoms.get_pbc() - - lib = self._lib - assert lib is not None - lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) - - xperiodic = xperiodic.astype(int) - yperiodic = yperiodic.astype(int) - zperiodic = zperiodic.astype(int) - lib.pair_set_domain( - self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz - ) - - lib.pair_run_settings( - self.pair, - self.rthr, - self.cnthr, - self.damp_name.encode('utf-8'), - self.func_name.encode('utf-8'), - ) - - lib.pair_run_coeff(self.pair, atomic_numbers) - lib.pair_run_compute(self.pair) - - result_E = lib.pair_get_energy(self.pair) - - result_F_ptr = lib.pair_get_force(self.pair) - result_F_size = natoms * 3 - result_F = np.ctypeslib.as_array( - result_F_ptr, shape=(result_F_size,) - ).reshape((natoms, 3)) - result_F = np.array(result_F) - result_F = result_F @ rotator - - result_S = lib.pair_get_stress(self.pair) - result_S = np.array(result_S.contents) - result_S = ( - self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) - / atoms.get_volume() - ) - - self.results = { - 'free_energy': result_E, - 'energy': result_E, - 'forces': result_F, - 'stress': result_S, - } - - def predict_one(self, atoms): - atoms = atoms.copy() - if atoms is None: - raise ValueError('No atoms to evaluate') - - if atoms.get_cell().sum() == 0: - print( - 'Warning: D3Calculator requires a cell.\n' - 'Warning: An orthogonal cell large enough is generated.' 
- ) - positions = atoms.get_positions() - min_pos = positions.min(axis=0) - max_pos = positions.max(axis=0) - max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 - - cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin - cell = np.eye(3) * cell_lengths - - atoms.set_cell(cell) - atoms.set_pbc([True, True, True]) # for minus positions - - cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) - - Z_of_atoms = atoms.get_atomic_numbers() - natoms = len(atoms) - ntypes = len(set(Z_of_atoms)) - types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) - - positions = atoms.get_positions() @ rotator.T - x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) - - atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) - - boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) - boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) - xy = cell[3] - xz = cell[4] - yz = cell[5] - xperiodic, yperiodic, zperiodic = atoms.get_pbc() - - lib = self._lib - assert lib is not None - lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) - - xperiodic = xperiodic.astype(int) - yperiodic = yperiodic.astype(int) - zperiodic = zperiodic.astype(int) - lib.pair_set_domain( - self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz - ) - - lib.pair_run_settings( - self.pair, - self.rthr, - self.cnthr, - self.damp_name.encode('utf-8'), - self.func_name.encode('utf-8'), - ) - - lib.pair_run_coeff(self.pair, atomic_numbers) - lib.pair_run_compute(self.pair) - - result_E = lib.pair_get_energy(self.pair) - - result_F_ptr = lib.pair_get_force(self.pair) - result_F_size = natoms * 3 - result_F = np.ctypeslib.as_array( - result_F_ptr, shape=(result_F_size,) - ).reshape((natoms, 3)) - result_F = np.array(result_F) - result_F = result_F @ rotator - - result_S = lib.pair_get_stress(self.pair) - result_S = np.array(result_S.contents) - result_S = ( - self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) - / atoms.get_volume() - ) - - prediction = { - 'free_energy': float(result_E), - 'energy': float(result_E), - 'forces': result_F.copy(), - 'stress': result_S.copy(), - } - - return prediction - - - def __del__(self): - if self._lib is not None: - self._lib.pair_fin(self.pair) - self._lib = None - self.pair = None - +import ctypes +import os +import pathlib +import warnings +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.jit +import torch.jit._script +from ase.calculators.calculator import Calculator, all_changes +from ase.calculators.mixing import SumCalculator +from ase.data import chemical_symbols + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.atom_graph_data import AtomGraphData +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.train.dataload import unlabeled_atoms_to_graph +import logging + +torch_script_type = torch.jit._script.RecursiveScriptModule + + +class SevenNetCalculator(Calculator): + """Supporting properties: + 'free_energy', 'energy', 'forces', 'stress', 'energies' + free_energy equals energy. 'energies' stores atomic energy. + + Multi-GPU acceleration is not supported with ASE calculator. + You should use LAMMPS for the acceleration. 
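+
+    A minimal usage sketch (illustrative only; assumes the default '7net-0'
+    checkpoint can be resolved on this machine):
+
+        from ase.build import bulk
+        calc = SevenNetCalculator(model='7net-0', device='auto')
+        atoms = bulk('Si', 'diamond', a=5.43)
+        atoms.calc = calc
+        energy = atoms.get_potential_energy()  # eV
+        forces = atoms.get_forces()            # eV/Angstrom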
+ """ + + def __init__( + self, + model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', + file_type: str = 'checkpoint', + device: Union[torch.device, str] = 'auto', + modal: Optional[str] = None, + enable_cueq: bool = False, + sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info + **kwargs, + ): + """Initialize SevenNetCalculator. + + Parameters + ---------- + model: str | Path | AtomGraphSequential, default='7net-0' + Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or + path to the checkpoint, deployed model or the model itself + file_type: str, default='checkpoint' + one of 'checkpoint' | 'torchscript' | 'model_instance' + device: str | torch.device, default='auto' + if not given, use CUDA if available + modal: str | None, default=None + modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, + it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) + case insensitive + enable_cueq: bool, default=False + if True, use cuEquivariant to accelerate inference. + sevennet_config: dict | None, default=None + Not used, but can be used to carry meta information of this calculator + """ + print("&&& Initializing SevenNetCalculator") + super().__init__(**kwargs) + self.sevennet_config = None + + if isinstance(model, pathlib.PurePath): + model = str(model) + + allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] + file_type = file_type.lower() + if file_type not in allowed_file_types: + raise ValueError(f'file_type not in {allowed_file_types}') + + if enable_cueq and file_type in ['model_instance', 'torchscript']: + warnings.warn( + 'file_type should be checkpoint to enable cueq. cueq set to False' + ) + enable_cueq = False + + if isinstance(device, str): # TODO: do we really need this? + if device == 'auto': + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu' + ) + else: + self.device = torch.device(device) + else: + self.device = device + + if file_type == 'checkpoint' and isinstance(model, str): + cp = util.load_checkpoint(model) + + backend = 'e3nn' if not enable_cueq else 'cueq' + model_loaded = cp.build_model(backend) + model_loaded.set_is_batch_data(False) + + self.type_map = cp.config[KEY.TYPE_MAP] + self.cutoff = cp.config[KEY.CUTOFF] + self.sevennet_config = cp.config + + elif file_type == 'torchscript' and isinstance(model, str): + if modal: + raise NotImplementedError() + extra_dict = { + 'chemical_symbols_to_index': b'', + 'cutoff': b'', + 'num_species': b'', + 'model_type': b'', + 'version': b'', + 'dtype': b'', + 'time': b'', + } + model_loaded = torch.jit.load( + model, _extra_files=extra_dict, map_location=self.device + ) + chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8') + sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)} + self.type_map = { + sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) + } + self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + + elif isinstance(model, AtomGraphSequential): + if model.type_map is None: + raise ValueError( + 'Model must have the type_map to be used with calculator' + ) + if model.cutoff == 0.0: + raise ValueError('Model cutoff seems not initialized') + model.eval_type_map = torch.tensor(True) # ? 
+ model.set_is_batch_data(False) + model_loaded = model + self.type_map = model.type_map + self.cutoff = model.cutoff + else: + raise ValueError('Unexpected input combinations') + + if self.sevennet_config is None and sevennet_config is not None: + self.sevennet_config = sevennet_config + + self.model = model_loaded + + self.modal = None + if isinstance(self.model, AtomGraphSequential): + modal_map = self.model.modal_map + if modal_map: + modal_ava = list(modal_map.keys()) + if not modal: + raise ValueError(f'modal argument missing (avail: {modal_ava})') + elif modal not in modal_ava: + raise ValueError(f'unknown modal {modal} (not in {modal_ava})') + self.modal = modal + elif not self.model.modal_map and modal: + warnings.warn(f'modal={modal} is ignored as model has no modal_map') + + self.model.to(self.device) + self.model.eval() + self.implemented_properties = [ + 'free_energy', + 'energy', + 'forces', + 'stress', + 'energies', + ] + + def set_atoms(self, atoms): + # called by ase, when atoms.calc = calc + zs = tuple(set(atoms.get_atomic_numbers())) + for z in zs: + if z not in self.type_map: + sp = list(self.type_map.keys()) + raise ValueError( + f'Model do not know atomic number: {z}, (knows: {sp})' + ) + + def output_to_results(self, output): + energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item() + num_atoms = output['num_atoms'].item() + atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten() + forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :] + stress = np.array( + (-output[KEY.PRED_STRESS]) + .detach() + .cpu() + .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation + ) + # Store results + return { + 'free_energy': energy, + 'energy': energy, + 'energies': atomic_energies, + 'forces': forces, + 'stress': stress, + 'num_edges': output[KEY.EDGE_IDX].shape[1], + } + + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + # call parent class to set necessary atom attributes + Calculator.calculate(self, atoms, properties, system_changes) + if atoms is None: + raise ValueError('No atoms to evaluate') + data = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, self.cutoff) + ) + if self.modal: + data[KEY.DATA_MODALITY] = self.modal + + data.to(self.device) # type: ignore + + if isinstance(self.model, torch_script_type): + data[KEY.NODE_FEATURE] = torch.tensor( + [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], + dtype=torch.int64, + device=self.device, + ) + data[KEY.POS].requires_grad_(True) # backward compatibility + data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility + data = data.to_dict() + del data['data_info'] + + import logging + logging.debug(f"data: {data}") + # logging.debug(f"data[pos]: {data['pos']}") + # logging.debug(f"data[x]: {data['x']}") + logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}") + logging.debug(f"data[cell_volume]: {data['cell_volume']}") + output = self.model(data) + # logging.info(f"input: {data}") + # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") + # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") + # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") + self.results = self.output_to_results(output) + # logging.debug(f"results['energy'] = {self.results['energy']}") + # logging.debug(f"results['forces'] = {self.results['forces']}") + # logging.debug(f"results['stress'] = {self.results['stress']}") + + def predict_one(self, atoms): + if atoms is None: 
+    def predict_one(self, atoms):
+        if atoms is None:
+            raise ValueError('No atoms to evaluate')
+        data = AtomGraphData.from_numpy_dict(
+            unlabeled_atoms_to_graph(atoms, self.cutoff)
+        )
+        if self.modal:
+            data[KEY.DATA_MODALITY] = self.modal
+
+        data.to(self.device)  # type: ignore
+
+        if isinstance(self.model, torch_script_type):
+            data[KEY.NODE_FEATURE] = torch.tensor(
+                [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
+                dtype=torch.int64,
+                device=self.device,
+            )
+            data[KEY.POS].requires_grad_(True)  # backward compatibility
+            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
+            data = data.to_dict()
+            del data['data_info']
+
+        return self.model(data)
+
+    def predict(self, atoms_list, properties=None):
+        if not atoms_list:
+            raise ValueError('Empty atoms_list provided')
+
+        if not isinstance(atoms_list, list):
+            atoms_list = [atoms_list]
+
+        # Convert each Atoms object to graph data.
+        graph_list = []
+        for atoms in atoms_list:
+            data = AtomGraphData.from_numpy_dict(
+                unlabeled_atoms_to_graph(atoms, self.cutoff)
+            )
+            if self.modal:
+                data[KEY.DATA_MODALITY] = self.modal
+
+            if isinstance(self.model, torch_script_type):
+                data[KEY.NODE_FEATURE] = torch.tensor(
+                    [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+                data[KEY.POS].requires_grad_(True)  # backward compatibility
+                data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
+
+            graph_list.append(data)
+
+        # Switch the model to batch mode before the batched forward pass.
+        if isinstance(self.model, AtomGraphSequential):
+            self.model.set_is_batch_data(True)
+            self.model.eval()
+
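+        # Sketch of what the batching below does (standard torch_geometric
+        # semantics, nothing SevenNet-specific): Collater merges the per-structure
+        # graphs into one disjoint graph whose `batch` vector maps every node
+        # back to its structure, e.g. for two structures with 2 and 3 atoms:
+        #
+        #     Collater(graph_list)(graph_list).batch  # tensor([0, 0, 1, 1, 1])
+        #
+        # so a single forward pass evaluates all structures at once.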
+        # Batch the data: one forward pass over the collated graphs.
+        from torch_geometric.loader.dataloader import Collater
+        batched_data = Collater(graph_list)(graph_list)
+        batched_data = batched_data.to(self.device)
+
+        import logging
+        logging.debug(f"batched_data: {batched_data}")
+        logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}")
+        logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}")
+
+        # Run the model on the batched data.
+        if isinstance(self.model, torch_script_type):
+            batched_dict = batched_data.to_dict()
+            if 'data_info' in batched_dict:
+                del batched_dict['data_info']
+            output = self.model(batched_dict)
+        else:
+            output = self.model(batched_data)
+
+        predictions = {}
+        predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach()
+        predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach()
+        voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach()
+        stress_list = []
+        for i in range(voigt.shape[0]):
+            stress_list.append(self._stress2tensor(voigt[i, :]))
+        predictions['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3).detach()
+
+        return predictions
+
+    def _stress2tensor(self, stress):
+        # Voigt vector (xx, yy, zz, yz, xz, xy) -> symmetric 3x3 tensor
+        tensor = torch.tensor(
+            [
+                [stress[0], stress[5], stress[4]],
+                [stress[5], stress[1], stress[3]],
+                [stress[4], stress[3], stress[2]],
+            ],
+            device=self.device,
+        )
+        return tensor
+
+
+class SevenNetD3Calculator(SumCalculator):
+    def __init__(
+        self,
+        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
+        file_type: str = 'checkpoint',
+        device: Union[torch.device, str] = 'auto',
+        sevennet_config: Optional[Any] = None,  # holds meta information
+        damping_type: str = 'damp_bj',
+        functional_name: str = 'pbe',
+        vdw_cutoff: float = 9000,  # au^2 (1 au = 0.52917726 angstrom)
+        cn_cutoff: float = 1600,  # au^2 (1 au = 0.52917726 angstrom)
+        batch_size: int = 10,
+        **kwargs,
+    ):
+        """Initialize SevenNetD3Calculator. CUDA is required.
+
+        Parameters
+        ----------
+        model: str | Path | AtomGraphSequential
+            name of a pretrained model (7net-mf-ompa, 7net-omat, 7net-l3i5,
+            7net-0), a path to a checkpoint or deployed model, or the model itself
+        file_type: str, default='checkpoint'
+            one of 'checkpoint' | 'torchscript' | 'model_instance'
+        device: str | torch.device, default='auto'
+            if not given, use CUDA if available
+        modal: str | None, default=None
+            (forwarded via **kwargs) modal (fidelity) if the given model is
+            multi-modal; for 7net-mf-ompa, one of 'mpa' (MPtrj + sAlex) or
+            'omat24' (OMat24)
+        enable_cueq: bool, default=False
+            (forwarded via **kwargs) if True, use cuEquivariance to accelerate
+            inference
+        damping_type: str, default='damp_bj'
+            damping type of D3, one of 'damp_bj' | 'damp_zero'
+        functional_name: str, default='pbe'
+            target functional name of the D3 parameters
+        vdw_cutoff: float, default=9000
+            vdW cutoff of the D3 calculator in au^2
+        cn_cutoff: float, default=1600
+            coordination-number cutoff of the D3 calculator in au^2
+        batch_size: int, default=10
+            number of D3 calculators pre-built for batched `predict` calls
+        """
+        self.d3_calc = D3Calculator(
+            damping_type=damping_type,
+            functional_name=functional_name,
+            vdw_cutoff=vdw_cutoff,
+            cn_cutoff=cn_cutoff,
+            **kwargs,
+        )
+
+        self.sevennet_calc = SevenNetCalculator(
+            model=model,
+            file_type=file_type,
+            device=device,
+            sevennet_config=sevennet_config,
+            **kwargs,
+        )
+
+        super().__init__([self.sevennet_calc, self.d3_calc])
+
+        self.device = device
+        self.d3_calcs = []
+        for _ in range(batch_size):
+            self.d3_calcs.append(
+                D3Calculator(
+                    damping_type=damping_type,
+                    functional_name=functional_name,
+                    vdw_cutoff=vdw_cutoff,
+                    cn_cutoff=cn_cutoff,
+                    **kwargs,
+                )
+            )
+
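+    # A minimal usage sketch (assuming CUDA and a compiled pair_d3 extension;
+    # the parameter values are illustrative only):
+    #
+    #     calc = SevenNetD3Calculator(model='7net-0', device='cuda', batch_size=4)
+    #     atoms.calc = calc                    # ASE route, via SumCalculator
+    #     preds = calc.predict([a1, a2, a3])   # batched route, len <= batch_size
+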
+ """ + # Call the predict method of the first calculator (SevenNetCalculator) + predictions = self.sevennet_calc.predict(atoms_list) + + energy_list = [] + forces_list = [] + stress_list = [] + predictions3d = {} + for i, atoms in enumerate(atoms_list): + prediction = self.d3_calcs[i].predict_one(atoms) + energy_list.append(torch.tensor(prediction['energy'])) + forces_list.append(torch.from_numpy(prediction['forces']).to(self.device)) + stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress']))) + + # Convert lists to tensors + predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device) + predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3) + predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3) + + predictions['energy'] += predictions3d['energy'].detach() + predictions['forces'] += predictions3d['forces'].detach() + predictions['stress'] += predictions3d['stress'].detach() + + return predictions + + def _stress2tensor(self, stress): + tensor = torch.tensor( + [ + # [stress[0], stress[3], stress[4]], + # [stress[3], stress[1], stress[5]], + # [stress[4], stress[5], stress[2]], + [stress[0], stress[5], stress[4]], + [stress[5], stress[1], stress[3]], + [stress[4], stress[3], stress[2]], + ], + device=self.device + ) + return tensor + + + +def _load(name: str) -> ctypes.CDLL: + from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load + + # Load the library from the candidate locations + + package_dir = os.path.dirname(os.path.abspath(__file__)) + try: + return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}')) + except OSError: + pass + + cache_dir = _get_build_directory(name, verbose=False) + try: + return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}')) + except OSError: + pass + + # Compile the library if it is not found + + if os.access(package_dir, os.W_OK): + compile_dir = package_dir + else: + print('Warning: package directory is not writable. Using cache directory.') + compile_dir = cache_dir + + if 'TORCH_CUDA_ARCH_LIST' not in os.environ: + print('Warning: TORCH_CUDA_ARCH_LIST is not set.') + print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90') + os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0' + + load( + name=name, + sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')], + extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'], + build_directory=compile_dir, + verbose=True, + is_python_module=False, + ) + + return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}')) + + +class PairD3(ctypes.Structure): + pass # Opaque structure; only used as a pointer + + +class D3Calculator(Calculator): + """ASE calculator for accelerated D3 van der Waals (vdW) correction. + + Example: + from ase.calculators.mixing import SumCalculator + calc_1 = SevenNetCalculator() + calc_2 = D3Calculator() + return SumCalculator([calc_1, calc_2]) + + This calculator interfaces with the `libpaird3.so` library, + which is compiled by nvcc during the package installation. + If you encounter any errors, please verify + the installation process and the compilation options in `setup.py`. + Note: Multi-GPU parallel MD is not supported in this mode. + Note: Cffi could be used, but it was avoided to reduce dependencies. 
+ """ + + # Here, free_energy = energy + implemented_properties = ['free_energy', 'energy', 'forces', 'stress'] + + def __init__( + self, + damping_type: str = 'damp_bj', # damp_bj, damp_zero + functional_name: str = 'pbe', # check the source code + vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au + cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au + **kwargs, + ): + super().__init__(**kwargs) + + if not torch.cuda.is_available(): + raise NotImplementedError('CPU + D3 is not implemented yet') + + self.rthr = vdw_cutoff + self.cnthr = cn_cutoff + self.damp_name = damping_type.lower() + self.func_name = functional_name.lower() + + if self.damp_name not in ['damp_bj', 'damp_zero']: + raise ValueError('Error: Invalid damping type.') + + self._lib = _load('pair_d3') + + self._lib.pair_init.restype = ctypes.POINTER(PairD3) + self.pair = self._lib.pair_init() + + self._lib.pair_set_atom.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_int, # int natoms + ctypes.c_int, # int ntypes + ctypes.POINTER(ctypes.c_int), # int* types + ctypes.POINTER(ctypes.c_double), # double* x + ] + self._lib.pair_set_atom.restype = None + + self._lib.pair_set_domain.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_int, # int xperiodic + ctypes.c_int, # int yperiodic + ctypes.c_int, # int zperiodic + ctypes.POINTER(ctypes.c_double), # double* boxlo + ctypes.POINTER(ctypes.c_double), # double* boxhi + ctypes.c_double, # double xy + ctypes.c_double, # double xz + ctypes.c_double, # double yz + ] + self._lib.pair_set_domain.restype = None + + self._lib.pair_run_settings.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_double, # double rthr + ctypes.c_double, # double cnthr + ctypes.c_char_p, # const char* damp_name + ctypes.c_char_p, # const char* func_name + ] + self._lib.pair_run_settings.restype = None + + self._lib.pair_run_coeff.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.POINTER(ctypes.c_int), # int* atomic_numbers + ] + self._lib.pair_run_coeff.restype = None + + self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_run_compute.restype = None + + self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_energy.restype = ctypes.c_double + + self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double) + + self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6) + + self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_fin.restype = None + + def _idx_to_numbers(self, Z_of_atoms): + unique_numbers = list(dict.fromkeys(Z_of_atoms)) + return unique_numbers + + def _idx_to_types(self, Z_of_atoms): + unique_numbers = list(dict.fromkeys(Z_of_atoms)) + mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)} + atom_types = [mapping[num] for num in Z_of_atoms] + return atom_types + + def _convert_domain_ase2lammps(self, cell): + qtrans, ltrans = np.linalg.qr(cell.T, mode='complete') + lammps_cell = ltrans.T + signs = np.sign(np.diag(lammps_cell)) + lammps_cell = lammps_cell * signs + qtrans = qtrans * signs + lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)] + rotator = qtrans.T + return lammps_cell, rotator + + def _stress2tensor(self, stress): + tensor = np.array( + [ + [stress[0], stress[3], stress[4]], + [stress[3], stress[1], stress[5]], + [stress[4], stress[5], stress[2]], + ] + ) + return 
tensor
+
+    def _tensor2stress(self, tensor):
+        # 3x3 tensor -> ASE Voigt vector (xx, yy, zz, yz, xz, xy), sign-flipped
+        stress = -np.array(
+            [
+                tensor[0, 0],
+                tensor[1, 1],
+                tensor[2, 2],
+                tensor[1, 2],
+                tensor[0, 2],
+                tensor[0, 1],
+            ]
+        )
+        return stress
+
+    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
+        Calculator.calculate(self, atoms, properties, system_changes)
+        if atoms is None:
+            raise ValueError('No atoms to evaluate')
+
+        if atoms.get_cell().sum() == 0:
+            print(
+                'Warning: D3Calculator requires a cell.\n'
+                'Warning: An orthogonal cell large enough is generated.'
+            )
+            positions = atoms.get_positions()
+            min_pos = positions.min(axis=0)
+            max_pos = positions.max(axis=0)
+            max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
+
+            cell_lengths = max_pos - min_pos + max_cutoff + 1.0  # extra margin
+            cell = np.eye(3) * cell_lengths
+
+            atoms.set_cell(cell)
+            atoms.set_pbc([True, True, True])  # so negative positions are wrapped
+
+        cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())
+
+        Z_of_atoms = atoms.get_atomic_numbers()
+        natoms = len(atoms)
+        ntypes = len(set(Z_of_atoms))
+        types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
+
+        positions = atoms.get_positions() @ rotator.T
+        x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
+
+        atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
+
+        boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
+        boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
+        xy = cell[3]
+        xz = cell[4]
+        yz = cell[5]
+        xperiodic, yperiodic, zperiodic = atoms.get_pbc()
+
+        lib = self._lib
+        assert lib is not None
+        lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
+
+        xperiodic = xperiodic.astype(int)
+        yperiodic = yperiodic.astype(int)
+        zperiodic = zperiodic.astype(int)
+        lib.pair_set_domain(
+            self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz
+        )
+
+        lib.pair_run_settings(
+            self.pair,
+            self.rthr,
+            self.cnthr,
+            self.damp_name.encode('utf-8'),
+            self.func_name.encode('utf-8'),
+        )
+
+        lib.pair_run_coeff(self.pair, atomic_numbers)
+        lib.pair_run_compute(self.pair)
+
+        result_E = lib.pair_get_energy(self.pair)
+
+        result_F_ptr = lib.pair_get_force(self.pair)
+        result_F_size = natoms * 3
+        result_F = np.ctypeslib.as_array(
+            result_F_ptr, shape=(result_F_size,)
+        ).reshape((natoms, 3))
+        result_F = np.array(result_F)
+        result_F = result_F @ rotator
+
+        result_S = lib.pair_get_stress(self.pair)
+        result_S = np.array(result_S.contents)
+        result_S = (
+            self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
+            / atoms.get_volume()
+        )
+
+        self.results = {
+            'free_energy': result_E,
+            'energy': result_E,
+            'forces': result_F,
+            'stress': result_S,
+        }
+
+    def predict_one(self, atoms):
+        # Check for None before copying; copying None would raise AttributeError.
+        if atoms is None:
+            raise ValueError('No atoms to evaluate')
+        atoms = atoms.copy()
+
+        if atoms.get_cell().sum() == 0:
+            print(
+                'Warning: D3Calculator requires a cell.\n'
+                'Warning: An orthogonal cell large enough is generated.'
+ ) + positions = atoms.get_positions() + min_pos = positions.min(axis=0) + max_pos = positions.max(axis=0) + max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 + + cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin + cell = np.eye(3) * cell_lengths + + atoms.set_cell(cell) + atoms.set_pbc([True, True, True]) # for minus positions + + cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) + + Z_of_atoms = atoms.get_atomic_numbers() + natoms = len(atoms) + ntypes = len(set(Z_of_atoms)) + types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) + + positions = atoms.get_positions() @ rotator.T + x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) + + atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) + + boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) + boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) + xy = cell[3] + xz = cell[4] + yz = cell[5] + xperiodic, yperiodic, zperiodic = atoms.get_pbc() + + lib = self._lib + assert lib is not None + lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) + + xperiodic = xperiodic.astype(int) + yperiodic = yperiodic.astype(int) + zperiodic = zperiodic.astype(int) + lib.pair_set_domain( + self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz + ) + + lib.pair_run_settings( + self.pair, + self.rthr, + self.cnthr, + self.damp_name.encode('utf-8'), + self.func_name.encode('utf-8'), + ) + + lib.pair_run_coeff(self.pair, atomic_numbers) + lib.pair_run_compute(self.pair) + + result_E = lib.pair_get_energy(self.pair) + + result_F_ptr = lib.pair_get_force(self.pair) + result_F_size = natoms * 3 + result_F = np.ctypeslib.as_array( + result_F_ptr, shape=(result_F_size,) + ).reshape((natoms, 3)) + result_F = np.array(result_F) + result_F = result_F @ rotator + + result_S = lib.pair_get_stress(self.pair) + result_S = np.array(result_S.contents) + result_S = ( + self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) + / atoms.get_volume() + ) + + prediction = { + 'free_energy': float(result_E), + 'energy': float(result_E), + 'forces': result_F.copy(), + 'stress': result_S.copy(), + } + + return prediction + + + def __del__(self): + if self._lib is not None: + self._lib.pair_fin(self.pair) + self._lib = None + self.pair = None + diff --git a/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py b/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py index 859cde23e3c4b4153e7c6a57c5409b20a410ace8..9d8e2846e21ef48ed12e2e1c8a2b3f367ad18752 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py @@ -1,552 +1,552 @@ -import os -import pathlib -import uuid -import warnings -from copy import deepcopy -from datetime import datetime -from typing import Any, Dict, Optional, Union - -import pandas as pd -from packaging.version import Version -from torch import Tensor -from torch import load as torch_load - -import sevenn -import sevenn._const as consts -import sevenn._keys as KEY -import sevenn.scripts.backward_compatibility as compat -from sevenn import model_build -from sevenn.nn.scale import get_resolved_shift_scale -from sevenn.nn.sequential import AtomGraphSequential - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - import numpy as np - - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), 
atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -def copy_state_dict(state_dict) -> dict: - if isinstance(state_dict, dict): - return {key: copy_state_dict(value) for key, value in state_dict.items()} - elif isinstance(state_dict, list): - return [copy_state_dict(item) for item in state_dict] # type: ignore - elif isinstance(state_dict, Tensor): - return state_dict.clone() # type: ignore - else: - # For non-tensor values (e.g., scalars, None), return as-is - return state_dict - - -def _config_cp_routine(config): - cp_ver = Version(config.get('version', None)) - this_ver = Version(sevenn.__version__) - if cp_ver > this_ver: - warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source' - f'({this_ver}). This may cause unexpected behaviors') - - defaults = {**consts.model_defaults(config)} - config = compat.patch_old_config(config) # type: ignore - - scaler = model_build.init_shift_scale(config) - shift, scale = get_resolved_shift_scale( - scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None) - ) - config['shift'] = shift - config['scale'] = scale - - for k, v in defaults.items(): - if k in config: - continue - if os.getenv('SEVENN_DEBUG', False): - warnings.warn(f'{k} not in config, use default value {v}', UserWarning) - config[k] = v - - for k, v in config.items(): - if isinstance(v, Tensor): - config[k] = v.cpu() - return config - - -def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq): - """ - manually check keys and assert if something unexpected happens - """ - n_layer = src_config['num_convolution_layer'] - - linear_module_names = [ - 'onehot_to_feature_x', - 'reduce_input_to_hidden', - 'reduce_hidden_to_energy', - ] - convolution_module_names = [] - fc_tensor_product_module_names = [] - for i in range(n_layer): - linear_module_names.append(f'{i}_self_interaction_1') - linear_module_names.append(f'{i}_self_interaction_2') - if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear': - linear_module_names.append(f'{i}_self_connection_intro') - elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip': - fc_tensor_product_module_names.append(f'{i}_self_connection_intro') - convolution_module_names.append(f'{i}_convolution') - - # Rule: those keys can be safely ignored before state dict load, - # except for linear.bias. This should be aborted in advance to - # this function. Others are not parameters but constants. 
- cue_only_linear_followers = ['linear.f.tp.f_fx.module.c'] - e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask'] - ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers - - cue_only_conv_followers = [ - 'convolution.f.tp.f_fx.module.c', - 'convolution.f.tp.module.module.f.module.module._f.data', - ] - e3nn_only_conv_followers = [ - 'convolution._compiled_main_left_right._w3j', - 'convolution.weight', - 'convolution.output_mask', - ] - ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers - - cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c'] - e3nn_only_fc_followers = [ - 'fc_tensor_product.output_mask', - ] - ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers - - updated_keys = [] - for k, v in stct_src.items(): - module_name = k.split('.')[0] - flag = False - if module_name in linear_module_names: - for ignore in ignores_in_linear: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and k == '.'.join([module_name, 'linear.weight']): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from linear: {k}' - elif module_name in convolution_module_names: - for ignore in ignores_in_conv: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and ( - k.startswith(f'{module_name}.weight_nn') - or k == '.'.join([module_name, 'denominator']) - ): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from linear: {k}' - elif module_name in fc_tensor_product_module_names: - for ignore in ignores_in_fc: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from fc tensor product: {k}' - else: - # assert k in stct_dst - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - - return stct_dst - - -class SevenNetCheckpoint: - """ - Tool box for checkpoint processed from SevenNet. 
- """ - - def __init__(self, checkpoint_path: Union[pathlib.Path, str]): - self._checkpoint_path = os.path.abspath(checkpoint_path) - self._config = None - self._epoch = None - self._model_state_dict = None - self._optimizer_state_dict = None - self._scheduler_state_dict = None - self._hash = None - self._time = None - - self._loaded = False - - def __repr__(self) -> str: - cfg = self.config # just alias - if len(cfg) == 0: - return '' - dct = { - 'Sevennet version': cfg.get('version', 'Not found'), - 'When': self.time, - 'Hash': self.hash, - 'Cutoff': cfg.get('cutoff'), - 'Channel': cfg.get('channel'), - 'Lmax': cfg.get('lmax'), - 'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3', - 'Interaction layers': cfg.get('num_convolution_layer'), - 'Self connection type': cfg.get('self_connection_type', 'nequip'), - 'Last epoch': self.epoch, - 'Elements': len(cfg.get('chemical_species', [])), - } - if cfg.get('use_modality', False): - dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys())) - - df = pd.DataFrame.from_dict([dct]).T # type: ignore - df.columns = [''] - return df.to_string() - - @property - def checkpoint_path(self) -> str: - return str(self._checkpoint_path) - - @property - def config(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._config, dict) - return deepcopy(self._config) - - @property - def model_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._model_state_dict, dict) - return copy_state_dict(self._model_state_dict) - - @property - def optimizer_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._optimizer_state_dict, dict) - return copy_state_dict(self._optimizer_state_dict) - - @property - def scheduler_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._scheduler_state_dict, dict) - return copy_state_dict(self._scheduler_state_dict) - - @property - def epoch(self) -> Optional[int]: - if not self._loaded: - self._load() - return self._epoch - - @property - def time(self) -> str: - if not self._loaded: - self._load() - assert isinstance(self._time, str) - return self._time - - @property - def hash(self) -> str: - if not self._loaded: - self._load() - assert isinstance(self._hash, str) - return self._hash - - def _load(self) -> None: - assert not self._loaded - cp_path = self.checkpoint_path # just alias - - cp = torch_load(cp_path, weights_only=False, map_location='cpu') - self._config_original = cp.get('config', {}) - self._model_state_dict = cp.get('model_state_dict', {}) - self._optimizer_state_dict = cp.get('optimizer_state_dict', {}) - self._scheduler_state_dict = cp.get('scheduler_state_dict', {}) - self._epoch = cp.get('epoch', None) - self._time = cp.get('time', 'Not found') - self._hash = cp.get('hash', 'Not found') - - if len(self._config_original) == 0: - warnings.warn(f'config is not found from {cp_path}') - self._config = {} - else: - self._config = _config_cp_routine(self._config_original) - - if len(self._model_state_dict) == 0: - warnings.warn(f'model_state_dict is not found from {cp_path}') - - self._loaded = True - - def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential: - from .model_build import build_E3_equivariant_model - - use_cue = not backend or backend.lower() in ['cue', 'cueq'] - try: - cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use'] - except KeyError: - cp_using_cue = False - - if (not backend) or (use_cue == 
cp_using_cue): - # backend not given, or checkpoint backend is same as requested - model = build_E3_equivariant_model(self.config) - state_dict = compat.patch_state_dict_if_old( - self.model_state_dict, self.config, model - ) - else: - cfg_new = self.config - cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue} - model = build_E3_equivariant_model(cfg_new) - stct_src = compat.patch_state_dict_if_old( - self.model_state_dict, self.config, model - ) - state_dict = _convert_e3nn_and_cueq( - stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue - ) - - missing, not_used = model.load_state_dict(state_dict, strict=False) - if len(not_used) > 0: - warnings.warn(f'Some keys are not used: {not_used}', UserWarning) - - assert len(missing) == 0, f'Missing keys: {missing}' - return model - - def yaml_dict(self, mode: str) -> dict: - """ - Return dict for input.yaml from checkpoint config - Dataset paths and statistic values are removed intentionally - """ - if mode not in ['reproduce', 'continue', 'continue_modal']: - raise ValueError(f'Unknown mode: {mode}') - - ignore = [ - 'when', - KEY.DDP_BACKEND, - KEY.LOCAL_RANK, - KEY.IS_DDP, - KEY.DEVICE, - KEY.MODEL_TYPE, - KEY.SHIFT, - KEY.SCALE, - KEY.CONV_DENOMINATOR, - KEY.SAVE_DATASET, - KEY.SAVE_BY_LABEL, - KEY.SAVE_BY_TRAIN_VALID, - KEY.CONTINUE, - KEY.LOAD_DATASET, # old - ] - - cfg = self.config - len_atoms = len(cfg[KEY.TYPE_MAP]) - - world_size = cfg.pop(KEY.WORLD_SIZE, 1) - cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size - cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**' - - major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3] - if int(major) == 0 and int(minor) <= 9: - warnings.warn('checkpoint version too old, yaml may wrong') - - ret = {'model': {}, 'train': {}, 'data': {}} - for k, v in cfg.items(): - if k.startswith('_') or k in ignore or k.endswith('set_path'): - continue - if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG: - ret['model'][k] = v - elif k in consts.DEFAULT_TRAINING_CONFIG: - ret['train'][k] = v - elif k in consts.DEFAULT_DATA_CONFIG: - ret['data'][k] = v - - ret['model'][KEY.CHEMICAL_SPECIES] = ( - 'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto' - ) - ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**' - ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**' - - # TODO - ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**' - ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**' - - if mode.startswith('continue'): - ret['train'].update( - {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}} - ) - modal_names = None - if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False): - ret['train'][KEY.USE_MODALITY] = True - - # suggest defaults - ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False - ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True - ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True - ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True - - ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True - ret['data'][KEY.USE_MODAL_WISE_SCALE] = False - - modal_names = ['my_modal1', 'my_modal2'] - elif cfg.get(KEY.USE_MODALITY, False): - modal_names = list(cfg[KEY.MODAL_MAP].keys()) - - if modal_names: - ret['data'][KEY.LOAD_TRAINSET] = [ - {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]} - for mm in modal_names - ] - - return ret - - def append_modal( - self, - dst_config, - original_modal_name: str = 'origin', - working_dir: str = os.getcwd(), - ): - """ """ - import sevenn.train.modal_dataset as modal_dataset - from 
sevenn.model_build import init_shift_scale - from sevenn.scripts.convert_model_modality import _append_modal_weight - - src_config = self.config - src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False) - - # inherit element things first - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - dst_config.update({k: src_config[k] for k in chem_keys}) - - if dst_config[KEY.USE_MODAL_WISE_SHIFT] and ( - KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str) - ): - raise ValueError('To use modal wise shift, keyword shift is required') - if dst_config[KEY.USE_MODAL_WISE_SCALE] and ( - KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str) - ): - raise ValueError('To use modal wise scale, keyword scale is required') - - if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]: - dst_config[KEY.SHIFT] = src_config[KEY.SHIFT] - if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]: - dst_config[KEY.SCALE] = src_config[KEY.SCALE] - - # get statistics of given datasets of yaml - # dst_config updated - _ = modal_dataset.from_config(dst_config, working_dir=working_dir) - dst_modal_map = dst_config[KEY.MODAL_MAP] - - found_modal_names = list(dst_modal_map.keys()) - if len(found_modal_names) == 0: - raise ValueError('No modality is found from config') - - # Check difference btw given modals and new modal map - orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0}) - assert isinstance(orig_modal_map, dict) - new_modal_map = orig_modal_map.copy() - for modal_name in found_modal_names: - if modal_name in orig_modal_map: # duplicate, skipping - continue - new_modal_map[modal_name] = len(new_modal_map) # assign new - print(f'New modals: {list(new_modal_map.keys())}') - - if src_has_no_modal: - append_num = len(new_modal_map) - else: - append_num = len(new_modal_map) - len(orig_modal_map) - if append_num == 0: - raise ValueError('Nothing to append from checkpoint') - - dst_config[KEY.NUM_MODALITIES] = len(new_modal_map) - dst_config[KEY.MODAL_MAP] = new_modal_map - - # update dst_config's shift scales based on src_config - for ss_key, use_mw in ( - (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]), - (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]), - ): - if not use_mw: # not using mw ss, just assign - assert not isinstance(dst_config[ss_key], dict) - dst_config[ss_key] = src_config[ss_key] - elif src_has_no_modal: - assert isinstance(dst_config[ss_key], dict) - # mw ss, update by dict but use original_modal_name - dst_config[ss_key].update({original_modal_name: src_config[ss_key]}) - else: - assert isinstance(dst_config[ss_key], dict) - # mw ss, update by dict - dst_config[ss_key].update(src_config[ss_key]) - scaler = init_shift_scale(dst_config) - - # finally, prepare updated continuable state dict using above - orig_model = self.build_model() - orig_state_dict = orig_model.state_dict() - - new_state_dict = copy_state_dict(orig_state_dict) - for stct_key in orig_state_dict: - sp = stct_key.split('.') - k, follower = sp[0], '.'.join(sp[1:]) - if k == 'rescale_atomic_energy' and follower == 'shift': - new_state_dict[stct_key] = scaler.shift.clone() - elif k == 'rescale_atomic_energy' and follower == 'scale': - new_state_dict[stct_key] = scaler.scale.clone() - elif follower == 'linear.weight' and ( # append linear layer - ( - dst_config[KEY.USE_MODAL_NODE_EMBEDDING] - and k.endswith('onehot_to_feature_x') - ) - or ( - dst_config[KEY.USE_MODAL_SELF_INTER_INTRO] - 
and k.endswith('self_interaction_1') - ) - or ( - dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO] - and k.endswith('self_interaction_2') - ) - or ( - dst_config[KEY.USE_MODAL_OUTPUT_BLOCK] - and k == 'reduce_input_to_hidden' - ) - ): - orig_linear = getattr(orig_model._modules[k], 'linear') - # assert normalization element - new_state_dict[stct_key] = _append_modal_weight( - orig_state_dict, - k, - orig_linear.irreps_in, - orig_linear.irreps_out, - append_num, - ) - - dst_config['version'] = sevenn.__version__ - - return new_state_dict - - def get_checkpoint_dict(self) -> dict: - """ - Return duplicate of this checkpoint with new hash and time. - Convenient for creating variant of the checkpoint - """ - return { - 'config': self.config, - 'epoch': self.epoch, - 'model_state_dict': self.model_state_dict, - 'optimizer_state_dict': self.optimizer_state_dict, - 'scheduler_state_dict': self.scheduler_state_dict, - 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'hash': uuid.uuid4().hex, - } +import os +import pathlib +import uuid +import warnings +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional, Union + +import pandas as pd +from packaging.version import Version +from torch import Tensor +from torch import load as torch_load + +import sevenn +import sevenn._const as consts +import sevenn._keys as KEY +import sevenn.scripts.backward_compatibility as compat +from sevenn import model_build +from sevenn.nn.scale import get_resolved_shift_scale +from sevenn.nn.sequential import AtomGraphSequential + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + import numpy as np + + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def copy_state_dict(state_dict) -> dict: + if isinstance(state_dict, dict): + return {key: copy_state_dict(value) for key, value in state_dict.items()} + elif isinstance(state_dict, list): + return [copy_state_dict(item) for item in state_dict] # type: ignore + elif isinstance(state_dict, Tensor): + return state_dict.clone() # type: ignore + else: + # For non-tensor values (e.g., scalars, None), return as-is + return state_dict + + +def _config_cp_routine(config): + cp_ver = Version(config.get('version', None)) + this_ver = Version(sevenn.__version__) + if cp_ver > this_ver: + warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source' + f'({this_ver}). 
This may cause unexpected behaviors') + + defaults = {**consts.model_defaults(config)} + config = compat.patch_old_config(config) # type: ignore + + scaler = model_build.init_shift_scale(config) + shift, scale = get_resolved_shift_scale( + scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None) + ) + config['shift'] = shift + config['scale'] = scale + + for k, v in defaults.items(): + if k in config: + continue + if os.getenv('SEVENN_DEBUG', False): + warnings.warn(f'{k} not in config, use default value {v}', UserWarning) + config[k] = v + + for k, v in config.items(): + if isinstance(v, Tensor): + config[k] = v.cpu() + return config + + +def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq): + """ + manually check keys and assert if something unexpected happens + """ + n_layer = src_config['num_convolution_layer'] + + linear_module_names = [ + 'onehot_to_feature_x', + 'reduce_input_to_hidden', + 'reduce_hidden_to_energy', + ] + convolution_module_names = [] + fc_tensor_product_module_names = [] + for i in range(n_layer): + linear_module_names.append(f'{i}_self_interaction_1') + linear_module_names.append(f'{i}_self_interaction_2') + if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear': + linear_module_names.append(f'{i}_self_connection_intro') + elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip': + fc_tensor_product_module_names.append(f'{i}_self_connection_intro') + convolution_module_names.append(f'{i}_convolution') + + # Rule: those keys can be safely ignored before state dict load, + # except for linear.bias. This should be aborted in advance to + # this function. Others are not parameters but constants. + cue_only_linear_followers = ['linear.f.tp.f_fx.module.c'] + e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask'] + ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers + + cue_only_conv_followers = [ + 'convolution.f.tp.f_fx.module.c', + 'convolution.f.tp.module.module.f.module.module._f.data', + ] + e3nn_only_conv_followers = [ + 'convolution._compiled_main_left_right._w3j', + 'convolution.weight', + 'convolution.output_mask', + ] + ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers + + cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c'] + e3nn_only_fc_followers = [ + 'fc_tensor_product.output_mask', + ] + ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers + + updated_keys = [] + for k, v in stct_src.items(): + module_name = k.split('.')[0] + flag = False + if module_name in linear_module_names: + for ignore in ignores_in_linear: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and k == '.'.join([module_name, 'linear.weight']): + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from linear: {k}' + elif module_name in convolution_module_names: + for ignore in ignores_in_conv: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and ( + k.startswith(f'{module_name}.weight_nn') + or k == '.'.join([module_name, 'denominator']) + ): + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from linear: {k}' + elif module_name in fc_tensor_product_module_names: + for ignore in ignores_in_fc: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']): + updated_keys.append(k) + stct_dst[k] = 
v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from fc tensor product: {k}' + else: + # assert k in stct_dst + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + + return stct_dst + + +class SevenNetCheckpoint: + """ + Tool box for checkpoint processed from SevenNet. + """ + + def __init__(self, checkpoint_path: Union[pathlib.Path, str]): + self._checkpoint_path = os.path.abspath(checkpoint_path) + self._config = None + self._epoch = None + self._model_state_dict = None + self._optimizer_state_dict = None + self._scheduler_state_dict = None + self._hash = None + self._time = None + + self._loaded = False + + def __repr__(self) -> str: + cfg = self.config # just alias + if len(cfg) == 0: + return '' + dct = { + 'Sevennet version': cfg.get('version', 'Not found'), + 'When': self.time, + 'Hash': self.hash, + 'Cutoff': cfg.get('cutoff'), + 'Channel': cfg.get('channel'), + 'Lmax': cfg.get('lmax'), + 'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3', + 'Interaction layers': cfg.get('num_convolution_layer'), + 'Self connection type': cfg.get('self_connection_type', 'nequip'), + 'Last epoch': self.epoch, + 'Elements': len(cfg.get('chemical_species', [])), + } + if cfg.get('use_modality', False): + dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys())) + + df = pd.DataFrame.from_dict([dct]).T # type: ignore + df.columns = [''] + return df.to_string() + + @property + def checkpoint_path(self) -> str: + return str(self._checkpoint_path) + + @property + def config(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._config, dict) + return deepcopy(self._config) + + @property + def model_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._model_state_dict, dict) + return copy_state_dict(self._model_state_dict) + + @property + def optimizer_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._optimizer_state_dict, dict) + return copy_state_dict(self._optimizer_state_dict) + + @property + def scheduler_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._scheduler_state_dict, dict) + return copy_state_dict(self._scheduler_state_dict) + + @property + def epoch(self) -> Optional[int]: + if not self._loaded: + self._load() + return self._epoch + + @property + def time(self) -> str: + if not self._loaded: + self._load() + assert isinstance(self._time, str) + return self._time + + @property + def hash(self) -> str: + if not self._loaded: + self._load() + assert isinstance(self._hash, str) + return self._hash + + def _load(self) -> None: + assert not self._loaded + cp_path = self.checkpoint_path # just alias + + cp = torch_load(cp_path, weights_only=False, map_location='cpu') + self._config_original = cp.get('config', {}) + self._model_state_dict = cp.get('model_state_dict', {}) + self._optimizer_state_dict = cp.get('optimizer_state_dict', {}) + self._scheduler_state_dict = cp.get('scheduler_state_dict', {}) + self._epoch = cp.get('epoch', None) + self._time = cp.get('time', 'Not found') + self._hash = cp.get('hash', 'Not found') + + if len(self._config_original) == 0: + warnings.warn(f'config is not found from {cp_path}') + self._config = {} + else: + self._config = _config_cp_routine(self._config_original) + + if len(self._model_state_dict) == 0: + warnings.warn(f'model_state_dict is not found from {cp_path}') + + self._loaded = True + + def 
build_model(self, backend: Optional[str] = None) -> AtomGraphSequential: + from .model_build import build_E3_equivariant_model + + use_cue = not backend or backend.lower() in ['cue', 'cueq'] + try: + cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use'] + except KeyError: + cp_using_cue = False + + if (not backend) or (use_cue == cp_using_cue): + # backend not given, or checkpoint backend is same as requested + model = build_E3_equivariant_model(self.config) + state_dict = compat.patch_state_dict_if_old( + self.model_state_dict, self.config, model + ) + else: + cfg_new = self.config + cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue} + model = build_E3_equivariant_model(cfg_new) + stct_src = compat.patch_state_dict_if_old( + self.model_state_dict, self.config, model + ) + state_dict = _convert_e3nn_and_cueq( + stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue + ) + + missing, not_used = model.load_state_dict(state_dict, strict=False) + if len(not_used) > 0: + warnings.warn(f'Some keys are not used: {not_used}', UserWarning) + + assert len(missing) == 0, f'Missing keys: {missing}' + return model + + def yaml_dict(self, mode: str) -> dict: + """ + Return dict for input.yaml from checkpoint config + Dataset paths and statistic values are removed intentionally + """ + if mode not in ['reproduce', 'continue', 'continue_modal']: + raise ValueError(f'Unknown mode: {mode}') + + ignore = [ + 'when', + KEY.DDP_BACKEND, + KEY.LOCAL_RANK, + KEY.IS_DDP, + KEY.DEVICE, + KEY.MODEL_TYPE, + KEY.SHIFT, + KEY.SCALE, + KEY.CONV_DENOMINATOR, + KEY.SAVE_DATASET, + KEY.SAVE_BY_LABEL, + KEY.SAVE_BY_TRAIN_VALID, + KEY.CONTINUE, + KEY.LOAD_DATASET, # old + ] + + cfg = self.config + len_atoms = len(cfg[KEY.TYPE_MAP]) + + world_size = cfg.pop(KEY.WORLD_SIZE, 1) + cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size + cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**' + + major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3] + if int(major) == 0 and int(minor) <= 9: + warnings.warn('checkpoint version too old, yaml may wrong') + + ret = {'model': {}, 'train': {}, 'data': {}} + for k, v in cfg.items(): + if k.startswith('_') or k in ignore or k.endswith('set_path'): + continue + if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG: + ret['model'][k] = v + elif k in consts.DEFAULT_TRAINING_CONFIG: + ret['train'][k] = v + elif k in consts.DEFAULT_DATA_CONFIG: + ret['data'][k] = v + + ret['model'][KEY.CHEMICAL_SPECIES] = ( + 'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto' + ) + ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**' + ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**' + + # TODO + ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**' + ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**' + + if mode.startswith('continue'): + ret['train'].update( + {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}} + ) + modal_names = None + if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False): + ret['train'][KEY.USE_MODALITY] = True + + # suggest defaults + ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False + ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True + ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True + ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True + + ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True + ret['data'][KEY.USE_MODAL_WISE_SCALE] = False + + modal_names = ['my_modal1', 'my_modal2'] + elif cfg.get(KEY.USE_MODALITY, False): + modal_names = list(cfg[KEY.MODAL_MAP].keys()) + + if modal_names: + 
ret['data'][KEY.LOAD_TRAINSET] = [ + {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]} + for mm in modal_names + ] + + return ret + + def append_modal( + self, + dst_config, + original_modal_name: str = 'origin', + working_dir: str = os.getcwd(), + ): + """ """ + import sevenn.train.modal_dataset as modal_dataset + from sevenn.model_build import init_shift_scale + from sevenn.scripts.convert_model_modality import _append_modal_weight + + src_config = self.config + src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False) + + # inherit element things first + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + dst_config.update({k: src_config[k] for k in chem_keys}) + + if dst_config[KEY.USE_MODAL_WISE_SHIFT] and ( + KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str) + ): + raise ValueError('To use modal wise shift, keyword shift is required') + if dst_config[KEY.USE_MODAL_WISE_SCALE] and ( + KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str) + ): + raise ValueError('To use modal wise scale, keyword scale is required') + + if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]: + dst_config[KEY.SHIFT] = src_config[KEY.SHIFT] + if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]: + dst_config[KEY.SCALE] = src_config[KEY.SCALE] + + # get statistics of given datasets of yaml + # dst_config updated + _ = modal_dataset.from_config(dst_config, working_dir=working_dir) + dst_modal_map = dst_config[KEY.MODAL_MAP] + + found_modal_names = list(dst_modal_map.keys()) + if len(found_modal_names) == 0: + raise ValueError('No modality is found from config') + + # Check difference btw given modals and new modal map + orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0}) + assert isinstance(orig_modal_map, dict) + new_modal_map = orig_modal_map.copy() + for modal_name in found_modal_names: + if modal_name in orig_modal_map: # duplicate, skipping + continue + new_modal_map[modal_name] = len(new_modal_map) # assign new + print(f'New modals: {list(new_modal_map.keys())}') + + if src_has_no_modal: + append_num = len(new_modal_map) + else: + append_num = len(new_modal_map) - len(orig_modal_map) + if append_num == 0: + raise ValueError('Nothing to append from checkpoint') + + dst_config[KEY.NUM_MODALITIES] = len(new_modal_map) + dst_config[KEY.MODAL_MAP] = new_modal_map + + # update dst_config's shift scales based on src_config + for ss_key, use_mw in ( + (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]), + (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]), + ): + if not use_mw: # not using mw ss, just assign + assert not isinstance(dst_config[ss_key], dict) + dst_config[ss_key] = src_config[ss_key] + elif src_has_no_modal: + assert isinstance(dst_config[ss_key], dict) + # mw ss, update by dict but use original_modal_name + dst_config[ss_key].update({original_modal_name: src_config[ss_key]}) + else: + assert isinstance(dst_config[ss_key], dict) + # mw ss, update by dict + dst_config[ss_key].update(src_config[ss_key]) + scaler = init_shift_scale(dst_config) + + # finally, prepare updated continuable state dict using above + orig_model = self.build_model() + orig_state_dict = orig_model.state_dict() + + new_state_dict = copy_state_dict(orig_state_dict) + for stct_key in orig_state_dict: + sp = stct_key.split('.') + k, follower = sp[0], '.'.join(sp[1:]) + if k == 'rescale_atomic_energy' and follower == 'shift': + new_state_dict[stct_key] = 
scaler.shift.clone() + elif k == 'rescale_atomic_energy' and follower == 'scale': + new_state_dict[stct_key] = scaler.scale.clone() + elif follower == 'linear.weight' and ( # append linear layer + ( + dst_config[KEY.USE_MODAL_NODE_EMBEDDING] + and k.endswith('onehot_to_feature_x') + ) + or ( + dst_config[KEY.USE_MODAL_SELF_INTER_INTRO] + and k.endswith('self_interaction_1') + ) + or ( + dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO] + and k.endswith('self_interaction_2') + ) + or ( + dst_config[KEY.USE_MODAL_OUTPUT_BLOCK] + and k == 'reduce_input_to_hidden' + ) + ): + orig_linear = getattr(orig_model._modules[k], 'linear') + # assert normalization element + new_state_dict[stct_key] = _append_modal_weight( + orig_state_dict, + k, + orig_linear.irreps_in, + orig_linear.irreps_out, + append_num, + ) + + dst_config['version'] = sevenn.__version__ + + return new_state_dict + + def get_checkpoint_dict(self) -> dict: + """ + Return duplicate of this checkpoint with new hash and time. + Convenient for creating variant of the checkpoint + """ + return { + 'config': self.config, + 'epoch': self.epoch, + 'model_state_dict': self.model_state_dict, + 'optimizer_state_dict': self.optimizer_state_dict, + 'scheduler_state_dict': self.scheduler_state_dict, + 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), + 'hash': uuid.uuid4().hex, + } diff --git a/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py b/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py index 48f4a20c801e1b84a23d9ac165ec9ff1211c2f34..5999aea48eb6e6d4e45505f294df1a19aabefe4c 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py @@ -1,430 +1,430 @@ -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist - -import sevenn._keys as KEY -from sevenn.train.loss import LossDefinition - -from .atom_graph_data import AtomGraphData -from .train.optim import loss_dict - -_ERROR_TYPES = { - 'TotalEnergy': { - 'name': 'Energy', - 'ref_key': KEY.ENERGY, - 'pred_key': KEY.PRED_TOTAL_ENERGY, - 'unit': 'eV', - 'vdim': 1, - }, - 'Energy': { # by default per-atom for energy - 'name': 'Energy', - 'ref_key': KEY.ENERGY, - 'pred_key': KEY.PRED_TOTAL_ENERGY, - 'unit': 'eV/atom', - 'per_atom': True, - 'vdim': 1, - }, - 'Force': { - 'name': 'Force', - 'ref_key': KEY.FORCE, - 'pred_key': KEY.PRED_FORCE, - 'unit': 'eV/Å', - 'vdim': 3, - }, - 'Stress': { - 'name': 'Stress', - 'ref_key': KEY.STRESS, - 'pred_key': KEY.PRED_STRESS, - 'unit': 'kbar', - 'coeff': 1602.1766208, - 'vdim': 6, - }, - 'Stress_GPa': { - 'name': 'Stress', - 'ref_key': KEY.STRESS, - 'pred_key': KEY.PRED_STRESS, - 'unit': 'GPa', - 'coeff': 160.21766208, - 'vdim': 6, - }, - 'TotalLoss': { - 'name': 'TotalLoss', - 'unit': None, - }, -} - - -def get_err_type(name: str) -> Dict[str, Any]: - return deepcopy(_ERROR_TYPES[name]) - - -def _get_loss_function_from_name(loss_functions, name): - for loss_def, w in loss_functions: - if loss_def.name.lower() == name.lower(): - return loss_def, w - return None, None - - -class AverageNumber: - def __init__(self): - self._sum = 0.0 - self._count = 0 - - def update(self, values: torch.Tensor): - self._sum += values.sum().item() - self._count += values.numel() - - def _ddp_reduce(self, device): - _sum = torch.tensor(self._sum, device=device) - _count = torch.tensor(self._count, device=device) - dist.all_reduce(_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(_count, op=dist.ReduceOp.SUM) - self._sum = _sum.item() 
- self._count = _count.item() - - def get(self): - if self._count == 0: - return torch.nan - return self._sum / self._count - - -class ErrorMetric: - """ - Base class for error metrics We always average error by # of structures, - and designed to collect errors in the middle of iteration (by AverageNumber) - """ - - def __init__( - self, - name: str, - ref_key: str, - pred_key: str, - coeff: float = 1.0, - unit: Optional[str] = None, - per_atom: bool = False, - ignore_unlabeled: bool = True, - **kwargs, - ): - self.name = name - self.unit = unit - self.coeff = coeff - self.ref_key = ref_key - self.pred_key = pred_key - self.per_atom = per_atom - self.ignore_unlabeled = ignore_unlabeled - self.value = AverageNumber() - - def update(self, output: AtomGraphData): - raise NotImplementedError - - def _retrieve(self, output: AtomGraphData): - y_ref = output[self.ref_key] * self.coeff - y_pred = output[self.pred_key] * self.coeff - if self.per_atom: - assert y_ref.dim() == 1 and y_pred.dim() == 1 - natoms = output[KEY.NUM_ATOMS] - y_ref = y_ref / natoms - y_pred = y_pred / natoms - if self.ignore_unlabeled: - unlabelled_idx = torch.isnan(y_ref) - y_ref = y_ref[~unlabelled_idx] - y_pred = y_pred[~unlabelled_idx] - return y_ref, y_pred - - def ddp_reduce(self, device): - self.value._ddp_reduce(device) - - def reset(self): - self.value = AverageNumber() - - def get(self): - return self.value.get() - - def key_str(self, with_unit=True): - if self.unit is None or not with_unit: - return self.name - else: - return f'{self.name} ({self.unit})' - - def __str__(self): - return f'{self.key_str()}: {self.value.get():.6f}' - - -class RMSError(ErrorMetric): - """ - Vector squared error - """ - - def __init__(self, vdim: int = 1, **kwargs): - super().__init__(**kwargs) - self.vdim = vdim - self._se = torch.nn.MSELoss(reduction='none') - - def _square_error(self, y_ref, y_pred, vdim: int): - return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - se = self._square_error(y_ref, y_pred, self.vdim) - self.value.update(se) - - def get(self): - return self.value.get() ** 0.5 - - -class ComponentRMSError(ErrorMetric): - """ - Ignore vector dim and just average over components - Results smaller error - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._se = torch.nn.MSELoss(reduction='none') - - def _square_error(self, y_ref, y_pred): - return self._se(y_ref, y_pred) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - y_ref = y_ref.view(-1) - y_pred = y_pred.view(-1) - se = self._square_error(y_ref, y_pred) - self.value.update(se) - - def get(self): - return self.value.get() ** 0.5 - - -class MAError(ErrorMetric): - """ - Average over all component - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _square_error(self, y_ref, y_pred): - return torch.abs(y_ref - y_pred) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - y_ref = y_ref.reshape((-1,)) - y_pred = y_pred.reshape((-1,)) - se = self._square_error(y_ref, y_pred) - self.value.update(se) - - -class CustomError(ErrorMetric): - """ - Custom error metric - Args: - func: a function that takes y_ref and y_pred - and returns a list of errors - """ - - def __init__(self, func: Callable, **kwargs): - super().__init__(**kwargs) - self.func = func - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - se = 
self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([]) - self.value.update(se) - - -class LossError(ErrorMetric): - """ - Error metric that record loss - """ - - def __init__( - self, - name: str, - loss_def: LossDefinition, - **kwargs, - ): - super().__init__( - name, - ignore_unlabeld=loss_def.ignore_unlabeled, - **kwargs, - ) - self.loss_def = loss_def - - def update(self, output: AtomGraphData): - loss = self.loss_def.get_loss(output) # type: ignore - self.value.update(loss) # type: ignore - - -class CombinedError(ErrorMetric): - """ - Combine multiple error metrics with weights - corresponds to a weighted sum of errors (normally used in loss) - """ - - def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs): - super().__init__(**kwargs) - self.metrics = metrics - assert kwargs['unit'] is None - - def update(self, output: AtomGraphData): - for metric, _ in self.metrics: - metric.update(output) - - def reset(self): - for metric, _ in self.metrics: - metric.reset() - - def ddp_reduce(self, device): # override - for metric, _ in self.metrics: - metric.value._ddp_reduce(device) - - def get(self): - val = 0.0 - for metric, weight in self.metrics: - val += metric.get() * weight - return val - - -class ErrorRecorder: - """ - record errors of a model - """ - - METRIC_DICT = { - 'RMSE': RMSError, - 'ComponentRMSE': ComponentRMSError, - 'MAE': MAError, - 'Loss': LossError, - } - - def __init__(self, metrics: List[ErrorMetric]): - self.history = [] - self.metrics = metrics - - def _update(self, output: AtomGraphData): - for metric in self.metrics: - metric.update(output) - - def update(self, output: AtomGraphData, no_grad=True): - if no_grad: - with torch.no_grad(): - self._update(output) - else: - self._update(output) - - def get_metric_dict(self, with_unit=True): - return {metric.key_str(with_unit): metric.get() for metric in self.metrics} - - def get_current(self): - dct = {} - for metric in self.metrics: - dct[metric.name] = { - 'value': metric.get(), - 'unit': metric.unit, - 'ref_key': metric.ref_key, - 'pred_key': metric.pred_key, - } - return dct - - def get_dct(self, prefix=''): - dct = {} - if prefix.endswith('_') is False and prefix != '': - prefix = prefix + '_' - for metric in self.metrics: - dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}' - return dct - - def get_key_str(self, name: str): - for metric in self.metrics: - if name == metric.name: - return metric.key_str() - return None - - def epoch_forward(self): - self.history.append(self.get_current()) - pretty = self.get_metric_dict(with_unit=True) - for metric in self.metrics: - metric.reset() - return pretty # for print - - @staticmethod - def init_total_loss_metric( - config, - criteria: Optional[Callable] = None, - loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None, - ): - if criteria is None and loss_functions is None: - raise ValueError('both criteria and loss functions not given') - - is_stress = config[KEY.IS_TRAIN_STRESS] - metrics = [] - if criteria is not None: - energy_metric = CustomError(criteria, **get_err_type('Energy')) - metrics.append((energy_metric, 1)) - force_metric = CustomError(criteria, **get_err_type('Force')) - metrics.append((force_metric, config[KEY.FORCE_WEIGHT])) - if is_stress: - stress_metric = CustomError(criteria, **get_err_type('Stress')) - metrics.append((stress_metric, config[KEY.STRESS_WEIGHT])) - else: # TODO: this is hard-coded - for efs in ['Energy', 'Force', 'Stress']: - if efs == 'Stress' and not is_stress: - continue - lf, w = 
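`CombinedError.get` above reduces to a weighted sum over its sub-metrics, which is how `TotalLoss` mirrors the training objective: the energy weight is fixed at 1 while force and stress weights come from `KEY.FORCE_WEIGHT` / `KEY.STRESS_WEIGHT`. A toy computation — the values and weights below are made up for illustration:

```python
# (metric value, weight) pairs as CombinedError would combine them.
pairs = [(0.012, 1.0),     # Energy, weight fixed at 1
         (0.081, 0.1),     # Force, KEY.FORCE_WEIGHT (hypothetical value)
         (0.950, 1e-6)]    # Stress, KEY.STRESS_WEIGHT (hypothetical value)

total = sum(value * weight for value, weight in pairs)
print(f'TotalLoss: {total:.6f}')   # 0.020101
```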
_get_loss_function_from_name(loss_functions, efs) - if lf is None: - raise ValueError(f'{efs} not found from loss_functions') - metric = LossError(loss_def=lf, **get_err_type(efs)) - metrics.append((metric, w)) - - total_loss_metric = CombinedError( - metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None - ) - return total_loss_metric - - @staticmethod - def from_config(config: dict, loss_functions=None): - loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()] - loss_param = config.get(KEY.LOSS_PARAM, {}) - criteria = loss_cls(**loss_param) if loss_functions is None else None - - err_config = config.get(KEY.ERROR_RECORD, False) - if not err_config: - raise ValueError( - 'No error_record config found. Consider util.get_error_recorder' - ) - err_config_n = [] - if not config.get(KEY.IS_TRAIN_STRESS, True): - for err_type, metric_name in err_config: - if 'Stress' in err_type: - continue - err_config_n.append((err_type, metric_name)) - err_config = err_config_n - - err_metrics = [] - for err_type, metric_name in err_config: - metric_kwargs = get_err_type(err_type) - if err_type == 'TotalLoss': # special case - err_metrics.append( - ErrorRecorder.init_total_loss_metric( - config, criteria, loss_functions - ) - ) - continue - metric_cls = ErrorRecorder.METRIC_DICT[metric_name] - assert isinstance(metric_kwargs['name'], str) - if metric_name == 'Loss': - if loss_functions is not None: - metric_cls = LossError - metric_kwargs['loss_def'], _ = _get_loss_function_from_name( - loss_functions, metric_kwargs['name'] - ) - else: - metric_cls = CustomError - metric_kwargs['func'] = criteria - metric_kwargs.pop('unit', None) - metric_kwargs['name'] += f'_{metric_name}' - err_metrics.append(metric_cls(**metric_kwargs)) - return ErrorRecorder(err_metrics) +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + +import sevenn._keys as KEY +from sevenn.train.loss import LossDefinition + +from .atom_graph_data import AtomGraphData +from .train.optim import loss_dict + +_ERROR_TYPES = { + 'TotalEnergy': { + 'name': 'Energy', + 'ref_key': KEY.ENERGY, + 'pred_key': KEY.PRED_TOTAL_ENERGY, + 'unit': 'eV', + 'vdim': 1, + }, + 'Energy': { # by default per-atom for energy + 'name': 'Energy', + 'ref_key': KEY.ENERGY, + 'pred_key': KEY.PRED_TOTAL_ENERGY, + 'unit': 'eV/atom', + 'per_atom': True, + 'vdim': 1, + }, + 'Force': { + 'name': 'Force', + 'ref_key': KEY.FORCE, + 'pred_key': KEY.PRED_FORCE, + 'unit': 'eV/Å', + 'vdim': 3, + }, + 'Stress': { + 'name': 'Stress', + 'ref_key': KEY.STRESS, + 'pred_key': KEY.PRED_STRESS, + 'unit': 'kbar', + 'coeff': 1602.1766208, + 'vdim': 6, + }, + 'Stress_GPa': { + 'name': 'Stress', + 'ref_key': KEY.STRESS, + 'pred_key': KEY.PRED_STRESS, + 'unit': 'GPa', + 'coeff': 160.21766208, + 'vdim': 6, + }, + 'TotalLoss': { + 'name': 'TotalLoss', + 'unit': None, + }, +} + + +def get_err_type(name: str) -> Dict[str, Any]: + return deepcopy(_ERROR_TYPES[name]) + + +def _get_loss_function_from_name(loss_functions, name): + for loss_def, w in loss_functions: + if loss_def.name.lower() == name.lower(): + return loss_def, w + return None, None + + +class AverageNumber: + def __init__(self): + self._sum = 0.0 + self._count = 0 + + def update(self, values: torch.Tensor): + self._sum += values.sum().item() + self._count += values.numel() + + def _ddp_reduce(self, device): + _sum = torch.tensor(self._sum, device=device) + _count = torch.tensor(self._count, device=device) + dist.all_reduce(_sum, 
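`AverageNumber` above deliberately stores the pair (sum, count) rather than a running mean: all-reducing both scalars in `_ddp_reduce` yields the exact global mean even when ranks process different numbers of samples, whereas averaging per-rank means would not. A two-rank sketch with made-up numbers:

```python
# Two simulated ranks with unequal shard sizes (values hypothetical).
rank_sums   = [12.0, 2.0]
rank_counts = [8, 2]

exact = sum(rank_sums) / sum(rank_counts)                       # 1.4 — what _ddp_reduce computes
naive = sum(s / c for s, c in zip(rank_sums, rank_counts)) / 2  # 1.25 — biased toward small ranks
print(exact, naive)
```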
op=dist.ReduceOp.SUM) + dist.all_reduce(_count, op=dist.ReduceOp.SUM) + self._sum = _sum.item() + self._count = _count.item() + + def get(self): + if self._count == 0: + return torch.nan + return self._sum / self._count + + +class ErrorMetric: + """ + Base class for error metrics We always average error by # of structures, + and designed to collect errors in the middle of iteration (by AverageNumber) + """ + + def __init__( + self, + name: str, + ref_key: str, + pred_key: str, + coeff: float = 1.0, + unit: Optional[str] = None, + per_atom: bool = False, + ignore_unlabeled: bool = True, + **kwargs, + ): + self.name = name + self.unit = unit + self.coeff = coeff + self.ref_key = ref_key + self.pred_key = pred_key + self.per_atom = per_atom + self.ignore_unlabeled = ignore_unlabeled + self.value = AverageNumber() + + def update(self, output: AtomGraphData): + raise NotImplementedError + + def _retrieve(self, output: AtomGraphData): + y_ref = output[self.ref_key] * self.coeff + y_pred = output[self.pred_key] * self.coeff + if self.per_atom: + assert y_ref.dim() == 1 and y_pred.dim() == 1 + natoms = output[KEY.NUM_ATOMS] + y_ref = y_ref / natoms + y_pred = y_pred / natoms + if self.ignore_unlabeled: + unlabelled_idx = torch.isnan(y_ref) + y_ref = y_ref[~unlabelled_idx] + y_pred = y_pred[~unlabelled_idx] + return y_ref, y_pred + + def ddp_reduce(self, device): + self.value._ddp_reduce(device) + + def reset(self): + self.value = AverageNumber() + + def get(self): + return self.value.get() + + def key_str(self, with_unit=True): + if self.unit is None or not with_unit: + return self.name + else: + return f'{self.name} ({self.unit})' + + def __str__(self): + return f'{self.key_str()}: {self.value.get():.6f}' + + +class RMSError(ErrorMetric): + """ + Vector squared error + """ + + def __init__(self, vdim: int = 1, **kwargs): + super().__init__(**kwargs) + self.vdim = vdim + self._se = torch.nn.MSELoss(reduction='none') + + def _square_error(self, y_ref, y_pred, vdim: int): + return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + se = self._square_error(y_ref, y_pred, self.vdim) + self.value.update(se) + + def get(self): + return self.value.get() ** 0.5 + + +class ComponentRMSError(ErrorMetric): + """ + Ignore vector dim and just average over components + Results smaller error + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._se = torch.nn.MSELoss(reduction='none') + + def _square_error(self, y_ref, y_pred): + return self._se(y_ref, y_pred) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + y_ref = y_ref.view(-1) + y_pred = y_pred.view(-1) + se = self._square_error(y_ref, y_pred) + self.value.update(se) + + def get(self): + return self.value.get() ** 0.5 + + +class MAError(ErrorMetric): + """ + Average over all component + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _square_error(self, y_ref, y_pred): + return torch.abs(y_ref - y_pred) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + y_ref = y_ref.reshape((-1,)) + y_pred = y_pred.reshape((-1,)) + se = self._square_error(y_ref, y_pred) + self.value.update(se) + + +class CustomError(ErrorMetric): + """ + Custom error metric + Args: + func: a function that takes y_ref and y_pred + and returns a list of errors + """ + + def __init__(self, func: Callable, **kwargs): + super().__init__(**kwargs) + self.func = func + + def 
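`_retrieve` above performs two normalizations before any error is computed: with `per_atom` set it divides energies by `KEY.NUM_ATOMS` (eV to eV/atom), and with `ignore_unlabeled` it drops entries whose reference is NaN so unlabeled structures never contribute. A standalone torch sketch of the same steps, with toy tensors:

```python
import torch

# Per-structure total energies; the middle structure is unlabeled (NaN reference).
e_ref  = torch.tensor([-10.0, float('nan'), -31.5])
e_pred = torch.tensor([-9.8,  -20.1,        -31.0])
natoms = torch.tensor([2.0, 4.0, 6.0])

y_ref, y_pred = e_ref / natoms, e_pred / natoms   # eV -> eV/atom (per_atom branch)
mask = ~torch.isnan(y_ref)                        # ignore_unlabeled branch
err = (y_ref[mask] - y_pred[mask]).abs()
print(err)   # contributions from the two labeled structures only
```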
update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([]) + self.value.update(se) + + +class LossError(ErrorMetric): + """ + Error metric that record loss + """ + + def __init__( + self, + name: str, + loss_def: LossDefinition, + **kwargs, + ): + super().__init__( + name, + ignore_unlabeld=loss_def.ignore_unlabeled, + **kwargs, + ) + self.loss_def = loss_def + + def update(self, output: AtomGraphData): + loss = self.loss_def.get_loss(output) # type: ignore + self.value.update(loss) # type: ignore + + +class CombinedError(ErrorMetric): + """ + Combine multiple error metrics with weights + corresponds to a weighted sum of errors (normally used in loss) + """ + + def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs): + super().__init__(**kwargs) + self.metrics = metrics + assert kwargs['unit'] is None + + def update(self, output: AtomGraphData): + for metric, _ in self.metrics: + metric.update(output) + + def reset(self): + for metric, _ in self.metrics: + metric.reset() + + def ddp_reduce(self, device): # override + for metric, _ in self.metrics: + metric.value._ddp_reduce(device) + + def get(self): + val = 0.0 + for metric, weight in self.metrics: + val += metric.get() * weight + return val + + +class ErrorRecorder: + """ + record errors of a model + """ + + METRIC_DICT = { + 'RMSE': RMSError, + 'ComponentRMSE': ComponentRMSError, + 'MAE': MAError, + 'Loss': LossError, + } + + def __init__(self, metrics: List[ErrorMetric]): + self.history = [] + self.metrics = metrics + + def _update(self, output: AtomGraphData): + for metric in self.metrics: + metric.update(output) + + def update(self, output: AtomGraphData, no_grad=True): + if no_grad: + with torch.no_grad(): + self._update(output) + else: + self._update(output) + + def get_metric_dict(self, with_unit=True): + return {metric.key_str(with_unit): metric.get() for metric in self.metrics} + + def get_current(self): + dct = {} + for metric in self.metrics: + dct[metric.name] = { + 'value': metric.get(), + 'unit': metric.unit, + 'ref_key': metric.ref_key, + 'pred_key': metric.pred_key, + } + return dct + + def get_dct(self, prefix=''): + dct = {} + if prefix.endswith('_') is False and prefix != '': + prefix = prefix + '_' + for metric in self.metrics: + dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}' + return dct + + def get_key_str(self, name: str): + for metric in self.metrics: + if name == metric.name: + return metric.key_str() + return None + + def epoch_forward(self): + self.history.append(self.get_current()) + pretty = self.get_metric_dict(with_unit=True) + for metric in self.metrics: + metric.reset() + return pretty # for print + + @staticmethod + def init_total_loss_metric( + config, + criteria: Optional[Callable] = None, + loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None, + ): + if criteria is None and loss_functions is None: + raise ValueError('both criteria and loss functions not given') + + is_stress = config[KEY.IS_TRAIN_STRESS] + metrics = [] + if criteria is not None: + energy_metric = CustomError(criteria, **get_err_type('Energy')) + metrics.append((energy_metric, 1)) + force_metric = CustomError(criteria, **get_err_type('Force')) + metrics.append((force_metric, config[KEY.FORCE_WEIGHT])) + if is_stress: + stress_metric = CustomError(criteria, **get_err_type('Stress')) + metrics.append((stress_metric, config[KEY.STRESS_WEIGHT])) + else: # TODO: this is hard-coded + for efs in ['Energy', 'Force', 
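One behavioral detail worth flagging in `LossError` above: the keyword passed to `super().__init__` is spelled `ignore_unlabeld`, and because `ErrorMetric.__init__` absorbs unknown keys through `**kwargs`, the flag from the loss definition is silently discarded and the default (`True`) always wins. A minimal reproduction of that failure mode:

```python
class Base:
    def __init__(self, ignore_unlabeled=True, **kwargs):
        self.ignore_unlabeled = ignore_unlabeled

# Misspelled keyword (as in LossError above): silently absorbed by **kwargs,
# so the default True is kept regardless of what the loss definition says.
m = Base(ignore_unlabeld=False)
assert m.ignore_unlabeled is True    # the False never arrived

# Correct spelling actually forwards the flag:
m = Base(ignore_unlabeled=False)
assert m.ignore_unlabeled is False
```

If the intent is to honor `loss_def.ignore_unlabeled`, the keyword needs the correct spelling.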
'Stress']: + if efs == 'Stress' and not is_stress: + continue + lf, w = _get_loss_function_from_name(loss_functions, efs) + if lf is None: + raise ValueError(f'{efs} not found from loss_functions') + metric = LossError(loss_def=lf, **get_err_type(efs)) + metrics.append((metric, w)) + + total_loss_metric = CombinedError( + metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None + ) + return total_loss_metric + + @staticmethod + def from_config(config: dict, loss_functions=None): + loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()] + loss_param = config.get(KEY.LOSS_PARAM, {}) + criteria = loss_cls(**loss_param) if loss_functions is None else None + + err_config = config.get(KEY.ERROR_RECORD, False) + if not err_config: + raise ValueError( + 'No error_record config found. Consider util.get_error_recorder' + ) + err_config_n = [] + if not config.get(KEY.IS_TRAIN_STRESS, True): + for err_type, metric_name in err_config: + if 'Stress' in err_type: + continue + err_config_n.append((err_type, metric_name)) + err_config = err_config_n + + err_metrics = [] + for err_type, metric_name in err_config: + metric_kwargs = get_err_type(err_type) + if err_type == 'TotalLoss': # special case + err_metrics.append( + ErrorRecorder.init_total_loss_metric( + config, criteria, loss_functions + ) + ) + continue + metric_cls = ErrorRecorder.METRIC_DICT[metric_name] + assert isinstance(metric_kwargs['name'], str) + if metric_name == 'Loss': + if loss_functions is not None: + metric_cls = LossError + metric_kwargs['loss_def'], _ = _get_loss_function_from_name( + loss_functions, metric_kwargs['name'] + ) + else: + metric_cls = CustomError + metric_kwargs['func'] = criteria + metric_kwargs.pop('unit', None) + metric_kwargs['name'] += f'_{metric_name}' + err_metrics.append(metric_cls(**metric_kwargs)) + return ErrorRecorder(err_metrics) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/logger.py b/mace-bench/3rdparty/SevenNet/sevenn/logger.py index 1db11c0b98e7392b7c858822c80612e3792a61bc..1c66af8c8e5a47505d4ffbff4fdf1a63cdcd7b7c 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/logger.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/logger.py @@ -1,336 +1,336 @@ -import os -import time -import traceback -from datetime import datetime -from typing import Any, Dict, List, Optional - -from ase.data import atomic_numbers - -import sevenn._keys as KEY -from sevenn import __version__ - -CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()} - - -class Singleton(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - - -class Logger(metaclass=Singleton): - SCREEN_WIDTH = 120 # half size of my screen / changed due to stress output - - def __init__( - self, filename: Optional[str] = None, screen: bool = False, rank: int = 0 - ): - self.rank = rank - self._filename = filename - if rank == 0: - # if filename is not None: - # self.logfile = open(filename, 'a', buffering=1) - self.logfile = None - self.files = {} - self.screen = screen - else: - self.logfile = None - self.screen = False - self.timer_dct = {} - self.active = True - - def __enter__(self): - if self.rank != 0: - return self - if self.logfile is None and self._filename is not None: - try: - self.logfile = open( - self._filename, 'a', buffering=1, encoding='utf-8' - ) - except IOError as e: - print(f'Failed to re-open log file {self._filename}: {e}') - self.logfile = None - self.files = {} - return self - - def 
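`from_config` above walks `config[KEY.ERROR_RECORD]` as (err_type, metric_name) pairs, silently dropping Stress entries when stress training is off and special-casing TotalLoss. A plausible value for that entry, inferred from the loop — the exact YAML schema lives in `parse_input` and is not shown here:

```python
# Shape of config[KEY.ERROR_RECORD] as consumed by from_config above:
# (err_type, metric_name) pairs; metric_name must be a key of METRIC_DICT.
error_record = [
    ('Energy', 'RMSE'),
    ('Force', 'RMSE'),
    ('Stress', 'RMSE'),      # dropped automatically when stress training is off
    ('TotalLoss', 'Loss'),   # special-cased: built by init_total_loss_metric
]
```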
__exit__(self, exc_type, exc_value, traceback): - if self.rank != 0: - return self - try: - if self.logfile is not None: - self.logfile.close() - self.logfile = None - for f in self.files.values(): - f.close() - except IOError as e: - print(f'Failed to close log files: {e}') - finally: - self.logfile = None - self.files = {} - - def switch_file(self, new_filename: str): - if self.rank != 0: - return self - if self.logfile is not None: - raise ValueError('Current logfile is not yet closed') - self._filename = new_filename - return self - - def write(self, content: str): - if self.rank != 0: - return - # no newline! - if self.logfile is not None and self.active: - self.logfile.write(content) - if self.screen and self.active: - print(content, end='') - - def writeline(self, content: str): - content = content + '\n' - self.write(content) - - def init_csv(self, filename: str, header: list): - """ - Deprecated - """ - if self.rank == 0: - self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8') - self.files[filename].write(','.join(header) + '\n') - else: - pass - - def append_csv(self, filename: str, content: list, decimal: int = 6): - """ - Deprecated - """ - if self.rank == 0: - if filename not in self.files: - self.files[filename] = open(filename, 'a', buffering=1) - str_content = [] - for c in content: - if isinstance(c, float): - str_content.append(f'{c:.{decimal}f}') - else: - str_content.append(str(c)) - self.files[filename].write(','.join(str_content) + '\n') - else: - pass - - def natoms_write(self, natoms: Dict[str, Dict]): - content = '' - total_natom = {} - for label, natom in natoms.items(): - content += self.format_k_v(label, natom) - for specie, num in natom.items(): - try: - total_natom[specie] += num - except KeyError: - total_natom[specie] = num - content += self.format_k_v('Total, label wise', total_natom) - content += self.format_k_v('Total', sum(total_natom.values())) - self.write(content) - - def statistic_write(self, statistic: Dict[str, Dict]): - content = '' - for label, dct in statistic.items(): - if label.startswith('_'): - continue - if not isinstance(dct, dict): - continue - dct_new = {} - for k, v in dct.items(): - if k.startswith('_'): - continue - if isinstance(v, int): - dct_new[k] = v - else: - dct_new[k] = f'{v:.3f}' - content += self.format_k_v(label, dct_new) - self.write(content) - - # TODO : refactoring!!!, this is not loss, rmse - def epoch_write_specie_wise_loss(self, train_loss, valid_loss): - lb_pad = 21 - fs = 6 - pad = 21 - fs - ln = '-' * fs - total_atom_type = train_loss.keys() - content = '' - - for at in total_atom_type: - t_F = train_loss[at] - v_F = valid_loss[at] - at_sym = CHEM_SYMBOLS[at] - content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format( - label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs - ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format( - t_F=t_F, v_F=v_F, pad=pad, fs=fs - ) - content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format( - t_S=ln, v_S=ln, pad=pad, fs=fs - ) - content += '\n' - self.write(content) - - def write_full_table( - self, - dict_list: List[Dict], - row_labels: List[str], - decimal_places: int = 6, - pad: int = 2, - ): - """ - Assume data_list is list of dict with same keys - """ - assert len(dict_list) == len(row_labels) - label_len = max(map(len, row_labels)) - # Extract the column names and create a 2D array of values - col_names = list(dict_list[0].keys()) - - values = [list(d.values()) for d in dict_list] - - # Format the numbers with the given decimal places - 
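Every write path in `Logger` above is gated on `rank == 0`, so the identical training script can run under DDP and still produce a single log file. The pattern in isolation:

```python
# Rank-gated logging pattern used above (standalone sketch): every write is a
# no-op unless rank == 0, so the same code runs unchanged on all ranks.
class RankZeroLog:
    def __init__(self, rank):
        self.rank = rank

    def writeline(self, msg):
        if self.rank == 0:
            print(msg)

for rank in (0, 1, 2):
    RankZeroLog(rank).writeline(f'hello from rank {rank}')   # prints exactly once
```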
formatted_values = [ - [f'{value:.{decimal_places}f}' for value in row] for row in values - ] - - # Calculate padding lengths for each column (with extra padding) - max_col_lengths = [ - max(len(str(value)) for value in col) + pad - for col in zip(col_names, *formatted_values) - ] - - # Create header row and separator - header = ' ' * (label_len + pad) + ' '.join( - col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths) - ) - separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * ( - label_len + pad - ) - - # Print header and separator - self.writeline(header) - self.writeline(separator) - - # Print the data rows with row labels - for row_label, row in zip(row_labels, formatted_values): - data_row = ' '.join( - value.rjust(pad) for value, pad in zip(row, max_col_lengths) - ) - self.writeline(f'{row_label.ljust(label_len)}{data_row}') - - def format_k_v(self, key: Any, val: Any, write: bool = False): - """ - key and val should be str convertible - """ - MAX_KEY_SIZE = 20 - SEPARATOR = ', ' - EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3) - NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5 - key = str(key) - val = str(val) - content = f'{key:<{MAX_KEY_SIZE}}: {val}' - if len(content) > NEW_LINE_LEN: - content = f'{key:<{MAX_KEY_SIZE}}: ' - # septate val by separator - val_list = val.split(SEPARATOR) - current_len = len(content) - for val_compo in val_list: - current_len += len(val_compo) - if current_len > NEW_LINE_LEN: - newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}' - content += f'\\\n{newline_content}' - current_len = len(newline_content) - else: - content += f'{val_compo}{SEPARATOR}' - - if content.endswith(f'{SEPARATOR}'): - content = content[: -len(SEPARATOR)] - content += '\n' - - if write is False: - return content - else: - self.write(content) - return '' - - def greeting(self): - LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii' - with open(LOGO_ASCII_FILE, 'r') as logo_f: - logo_ascii = logo_f.read() - content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n' - content += f'version {__version__}, {time.ctime()}\n' - self.write(content) - self.write(logo_ascii) - - def bar(self): - content = '-' * Logger.SCREEN_WIDTH + '\n' - self.write(content) - - def print_config( - self, - model_config: Dict[str, Any], - data_config: Dict[str, Any], - train_config: Dict[str, Any], - ): - """ - print some important information from config - """ - content = 'successfully read yaml config!\n\n' + 'from model configuration\n' - for k, v in model_config.items(): - content += self.format_k_v(k, str(v)) - content += '\nfrom train configuration\n' - for k, v in train_config.items(): - content += self.format_k_v(k, str(v)) - content += '\nfrom data configuration\n' - for k, v in data_config.items(): - content += self.format_k_v(k, str(v)) - self.write(content) - - # TODO: This is not good make own exception - def error(self, e: Exception): - content = '' - if type(e) is ValueError: - content += 'Error occurred!\n' - content += str(e) + '\n' - else: - content += 'Unknown error occurred!\n' - content += traceback.format_exc() - self.write(content) - - def timer_start(self, name: str): - self.timer_dct[name] = datetime.now() - - def timer_end(self, name: str, message: str, remove: bool = True): - """ - print f"{message}: {elapsed}" - """ - elapsed = str(datetime.now() - self.timer_dct[name]) - # elapsed = elapsed.strftime('%H-%M-%S') - if remove: - del self.timer_dct[name] - self.write(f'{message}: {elapsed[:-4]}\n') - - # TODO: print it without config - # 
TODO: refactoring, readout part name :( - def print_model_info(self, model, config): - from functools import partial - - kv_write = partial(self.format_k_v, write=True) - self.writeline('Irreps of features') - kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')) - for i in range(config[KEY.NUM_CONVOLUTION]): - kv_write( - f'{i}th node', - model.get_irreps_in(f'{i}_self_interaction_1'), - ) - i = config[KEY.NUM_CONVOLUTION] - 1 - kv_write( - 'readout irreps', - model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'), - ) - - num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.writeline(f'# learnable parameters: {num_weights}\n') +import os +import time +import traceback +from datetime import datetime +from typing import Any, Dict, List, Optional + +from ase.data import atomic_numbers + +import sevenn._keys as KEY +from sevenn import __version__ + +CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()} + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Logger(metaclass=Singleton): + SCREEN_WIDTH = 120 # half size of my screen / changed due to stress output + + def __init__( + self, filename: Optional[str] = None, screen: bool = False, rank: int = 0 + ): + self.rank = rank + self._filename = filename + if rank == 0: + # if filename is not None: + # self.logfile = open(filename, 'a', buffering=1) + self.logfile = None + self.files = {} + self.screen = screen + else: + self.logfile = None + self.screen = False + self.timer_dct = {} + self.active = True + + def __enter__(self): + if self.rank != 0: + return self + if self.logfile is None and self._filename is not None: + try: + self.logfile = open( + self._filename, 'a', buffering=1, encoding='utf-8' + ) + except IOError as e: + print(f'Failed to re-open log file {self._filename}: {e}') + self.logfile = None + self.files = {} + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.rank != 0: + return self + try: + if self.logfile is not None: + self.logfile.close() + self.logfile = None + for f in self.files.values(): + f.close() + except IOError as e: + print(f'Failed to close log files: {e}') + finally: + self.logfile = None + self.files = {} + + def switch_file(self, new_filename: str): + if self.rank != 0: + return self + if self.logfile is not None: + raise ValueError('Current logfile is not yet closed') + self._filename = new_filename + return self + + def write(self, content: str): + if self.rank != 0: + return + # no newline! 
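The `Singleton` metaclass above caches one instance per class in `_instances`, which is what lets any module call `Logger()` with no arguments and reach the already-configured logger. A standalone demo of the metaclass, including the easy-to-miss consequence that constructor arguments of later calls are ignored:

```python
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class Logger(metaclass=Singleton):
    def __init__(self, tag=''):
        self.tag = tag

assert Logger('a') is Logger('b')   # second call returns the cached instance
assert Logger('b').tag == 'a'       # ...so its arguments never take effect
```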
+ if self.logfile is not None and self.active: + self.logfile.write(content) + if self.screen and self.active: + print(content, end='') + + def writeline(self, content: str): + content = content + '\n' + self.write(content) + + def init_csv(self, filename: str, header: list): + """ + Deprecated + """ + if self.rank == 0: + self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8') + self.files[filename].write(','.join(header) + '\n') + else: + pass + + def append_csv(self, filename: str, content: list, decimal: int = 6): + """ + Deprecated + """ + if self.rank == 0: + if filename not in self.files: + self.files[filename] = open(filename, 'a', buffering=1) + str_content = [] + for c in content: + if isinstance(c, float): + str_content.append(f'{c:.{decimal}f}') + else: + str_content.append(str(c)) + self.files[filename].write(','.join(str_content) + '\n') + else: + pass + + def natoms_write(self, natoms: Dict[str, Dict]): + content = '' + total_natom = {} + for label, natom in natoms.items(): + content += self.format_k_v(label, natom) + for specie, num in natom.items(): + try: + total_natom[specie] += num + except KeyError: + total_natom[specie] = num + content += self.format_k_v('Total, label wise', total_natom) + content += self.format_k_v('Total', sum(total_natom.values())) + self.write(content) + + def statistic_write(self, statistic: Dict[str, Dict]): + content = '' + for label, dct in statistic.items(): + if label.startswith('_'): + continue + if not isinstance(dct, dict): + continue + dct_new = {} + for k, v in dct.items(): + if k.startswith('_'): + continue + if isinstance(v, int): + dct_new[k] = v + else: + dct_new[k] = f'{v:.3f}' + content += self.format_k_v(label, dct_new) + self.write(content) + + # TODO : refactoring!!!, this is not loss, rmse + def epoch_write_specie_wise_loss(self, train_loss, valid_loss): + lb_pad = 21 + fs = 6 + pad = 21 - fs + ln = '-' * fs + total_atom_type = train_loss.keys() + content = '' + + for at in total_atom_type: + t_F = train_loss[at] + v_F = valid_loss[at] + at_sym = CHEM_SYMBOLS[at] + content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format( + label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs + ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format( + t_F=t_F, v_F=v_F, pad=pad, fs=fs + ) + content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format( + t_S=ln, v_S=ln, pad=pad, fs=fs + ) + content += '\n' + self.write(content) + + def write_full_table( + self, + dict_list: List[Dict], + row_labels: List[str], + decimal_places: int = 6, + pad: int = 2, + ): + """ + Assume data_list is list of dict with same keys + """ + assert len(dict_list) == len(row_labels) + label_len = max(map(len, row_labels)) + # Extract the column names and create a 2D array of values + col_names = list(dict_list[0].keys()) + + values = [list(d.values()) for d in dict_list] + + # Format the numbers with the given decimal places + formatted_values = [ + [f'{value:.{decimal_places}f}' for value in row] for row in values + ] + + # Calculate padding lengths for each column (with extra padding) + max_col_lengths = [ + max(len(str(value)) for value in col) + pad + for col in zip(col_names, *formatted_values) + ] + + # Create header row and separator + header = ' ' * (label_len + pad) + ' '.join( + col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths) + ) + separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * ( + label_len + pad + ) + + # Print header and separator + self.writeline(header) + 
self.writeline(separator) + + # Print the data rows with row labels + for row_label, row in zip(row_labels, formatted_values): + data_row = ' '.join( + value.rjust(pad) for value, pad in zip(row, max_col_lengths) + ) + self.writeline(f'{row_label.ljust(label_len)}{data_row}') + + def format_k_v(self, key: Any, val: Any, write: bool = False): + """ + key and val should be str convertible + """ + MAX_KEY_SIZE = 20 + SEPARATOR = ', ' + EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3) + NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5 + key = str(key) + val = str(val) + content = f'{key:<{MAX_KEY_SIZE}}: {val}' + if len(content) > NEW_LINE_LEN: + content = f'{key:<{MAX_KEY_SIZE}}: ' + # septate val by separator + val_list = val.split(SEPARATOR) + current_len = len(content) + for val_compo in val_list: + current_len += len(val_compo) + if current_len > NEW_LINE_LEN: + newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}' + content += f'\\\n{newline_content}' + current_len = len(newline_content) + else: + content += f'{val_compo}{SEPARATOR}' + + if content.endswith(f'{SEPARATOR}'): + content = content[: -len(SEPARATOR)] + content += '\n' + + if write is False: + return content + else: + self.write(content) + return '' + + def greeting(self): + LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii' + with open(LOGO_ASCII_FILE, 'r') as logo_f: + logo_ascii = logo_f.read() + content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n' + content += f'version {__version__}, {time.ctime()}\n' + self.write(content) + self.write(logo_ascii) + + def bar(self): + content = '-' * Logger.SCREEN_WIDTH + '\n' + self.write(content) + + def print_config( + self, + model_config: Dict[str, Any], + data_config: Dict[str, Any], + train_config: Dict[str, Any], + ): + """ + print some important information from config + """ + content = 'successfully read yaml config!\n\n' + 'from model configuration\n' + for k, v in model_config.items(): + content += self.format_k_v(k, str(v)) + content += '\nfrom train configuration\n' + for k, v in train_config.items(): + content += self.format_k_v(k, str(v)) + content += '\nfrom data configuration\n' + for k, v in data_config.items(): + content += self.format_k_v(k, str(v)) + self.write(content) + + # TODO: This is not good make own exception + def error(self, e: Exception): + content = '' + if type(e) is ValueError: + content += 'Error occurred!\n' + content += str(e) + '\n' + else: + content += 'Unknown error occurred!\n' + content += traceback.format_exc() + self.write(content) + + def timer_start(self, name: str): + self.timer_dct[name] = datetime.now() + + def timer_end(self, name: str, message: str, remove: bool = True): + """ + print f"{message}: {elapsed}" + """ + elapsed = str(datetime.now() - self.timer_dct[name]) + # elapsed = elapsed.strftime('%H-%M-%S') + if remove: + del self.timer_dct[name] + self.write(f'{message}: {elapsed[:-4]}\n') + + # TODO: print it without config + # TODO: refactoring, readout part name :( + def print_model_info(self, model, config): + from functools import partial + + kv_write = partial(self.format_k_v, write=True) + self.writeline('Irreps of features') + kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')) + for i in range(config[KEY.NUM_CONVOLUTION]): + kv_write( + f'{i}th node', + model.get_irreps_in(f'{i}_self_interaction_1'), + ) + i = config[KEY.NUM_CONVOLUTION] - 1 + kv_write( + 'readout irreps', + model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'), + ) + + num_weights = sum(p.numel() for p in 
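`timer_start`/`timer_end` above keep raw `datetime` objects in a dict keyed by name and print the stringified difference, trimming the last four characters to cut microsecond noise. The same mechanics standalone:

```python
import time
from datetime import datetime

timers = {}
timers['epoch'] = datetime.now()            # timer_start('epoch')
time.sleep(0.2)
elapsed = str(datetime.now() - timers.pop('epoch'))
print(f'Epoch time: {elapsed[:-4]}')        # e.g. "Epoch time: 0:00:00.20"
```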
model.parameters() if p.requires_grad) + self.writeline(f'# learnable parameters: {num_weights}\n') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py index a944add94f7a9d13d5656c4d28fbbb2138beb74f..d96a1bf8e316464def20b894bc8c689e158541b6 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py @@ -1,248 +1,248 @@ -import argparse -import os -import sys -import time - -from sevenn import __version__ - -description = 'train a model given the input.yaml' - -input_yaml_help = 'input.yaml for training' -mode_help = 'main training script to run. Default is train.' -working_dir_help = 'path to write output. Default is cwd.' -screen_help = 'print log to stdout' -distributed_help = 'set this flag if it is distributed training' -distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi' - -# Metainfo will be saved to checkpoint -global_config = { - 'version': __version__, - 'when': time.ctime(), - '_model_type': 'E3_equivariant_model', -} - - -def run(args): - """ - main function of sevenn - """ - import random - import sys - - import torch - import torch.distributed as dist - - import sevenn._keys as KEY - from sevenn.logger import Logger - from sevenn.parse_input import read_config_yaml - from sevenn.scripts.train import train, train_v2 - from sevenn.util import unique_filepath - - input_yaml = args.input_yaml - mode = args.mode - working_dir = args.working_dir - log = args.log - screen = args.screen - distributed = args.distributed - distributed_backend = args.distributed_backend - use_cue = args.enable_cueq - - if use_cue: - import sevenn.nn.cue_helper - - if not sevenn.nn.cue_helper.is_cue_available(): - raise ImportError('cuEquivariance not installed.') - - if working_dir is None: - working_dir = os.getcwd() - elif not os.path.isdir(working_dir): - os.makedirs(working_dir, exist_ok=True) - - world_size = 1 - if distributed: - if distributed_backend == 'nccl': - local_rank = int(os.environ['LOCAL_RANK']) - rank = int(os.environ['RANK']) - world_size = int(os.environ['WORLD_SIZE']) - elif distributed_backend == 'mpi': - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) - else: - raise ValueError(f'Unknown distributed backend: {distributed_backend}') - - dist.init_process_group( - backend=distributed_backend, world_size=world_size, rank=rank - ) - else: - local_rank, rank, world_size = 0, 0, 1 - - log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}') - with Logger(filename=log_fname, screen=screen, rank=rank) as logger: - logger.greeting() - - if distributed: - logger.writeline( - f'Distributed training enabled, total world size is {world_size}' - ) - - try: - model_config, train_config, data_config = read_config_yaml( - input_yaml, return_separately=True - ) - except Exception as e: - logger.writeline('Failed to parsing input.yaml') - logger.error(e) - sys.exit(1) - - train_config[KEY.IS_DDP] = distributed - train_config[KEY.DDP_BACKEND] = distributed_backend - train_config[KEY.LOCAL_RANK] = local_rank - train_config[KEY.RANK] = rank - train_config[KEY.WORLD_SIZE] = world_size - - if distributed: - torch.cuda.set_device(torch.device('cuda', local_rank)) - - if use_cue: - if KEY.CUEQUIVARIANCE_CONFIG not in model_config: - model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True} - else: - 
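`run()` above pulls its rendezvous information straight from environment variables: `LOCAL_RANK`/`RANK`/`WORLD_SIZE` for the nccl backend (what `torchrun` exports) and the `OMPI_COMM_WORLD_*` trio for mpi. A standalone sketch of the nccl branch; the defaults below are stand-ins so the snippet runs outside a launcher:

```python
import os

# Stand-in defaults for running outside torchrun (a real launch sets these).
os.environ.setdefault('LOCAL_RANK', '0')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

local_rank = int(os.environ['LOCAL_RANK'])
rank       = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
print(local_rank, rank, world_size)
```

Under a real launch this would presumably look like `torchrun --nproc_per_node=4 -m sevenn.main.sevenn input.yaml -d`, assuming the package is installed.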
model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True}) - - logger.print_config(model_config, data_config, train_config) - # don't have to distinguish configs inside program - global_config.update(model_config) - global_config.update(train_config) - global_config.update(data_config) - - # Not implemented - if global_config[KEY.DTYPE] == 'double': - raise Exception('double precision is not implemented yet') - # torch.set_default_dtype(torch.double) - - seed = global_config[KEY.RANDOM_SEED] - random.seed(seed) - torch.manual_seed(seed) - - # run train - if mode == 'train_v1': - train(global_config, working_dir) - elif mode == 'train_v2': - train_v2(global_config, working_dir) - - -def cmd_parser_train(parser): - ag = parser - ag.add_argument('input_yaml', help=input_yaml_help, type=str) - ag.add_argument( - '-m', - '--mode', - choices=['train_v1', 'train_v2'], - default='train_v2', - help=mode_help, - type=str, - ) - ag.add_argument( - '-cueq', - '--enable_cueq', - help='(Not stable!) use cuEquivariance for training', - action='store_true', - ) - ag.add_argument( - '-w', - '--working_dir', - nargs='?', - const=os.getcwd(), - help=working_dir_help, - type=str, - ) - ag.add_argument( - '-l', - '--log', - default='log.sevenn', - help='name of logfile, default is log.sevenn', - type=str, - ) - ag.add_argument('-s', '--screen', help=screen_help, action='store_true') - ag.add_argument( - '-d', '--distributed', help=distributed_help, action='store_true' - ) - ag.add_argument( - '--distributed_backend', - help=distributed_backend_help, - type=str, - default='nccl', - choices=['nccl', 'mpi'], - ) - - -def add_parser(subparsers): - ag = subparsers.add_parser('train', help=description) - cmd_parser_train(ag) - - -def set_default_subparser(self, name, args=None, positional_args=0): - """default subparser selection. 
Call after setup, just before parse_args() - name: is the name of the subparser to call by default - args: if set is the argument list handed to parse_args() - - Hack copied from stack overflow - """ - subparser_found = False - for arg in sys.argv[1:]: - if arg in ['-h', '--help']: # global help if no subparser - break - else: - for x in self._subparsers._actions: - if not isinstance(x, argparse._SubParsersAction): - continue - for sp_name in x._name_parser_map.keys(): - if sp_name in sys.argv[1:]: - subparser_found = True - if not subparser_found: - # insert default in last position before global positional - # arguments, this implies no global options are specified after - # first positional argument - if args is None: - sys.argv.insert(len(sys.argv) - positional_args, name) - else: - args.insert(len(args) - positional_args, name) - - -argparse.ArgumentParser.set_default_subparser = set_default_subparser # type: ignore - - -def main(): - import sevenn.main.sevenn_cp as checkpoint_cmd - import sevenn.main.sevenn_get_model as get_model_cmd - import sevenn.main.sevenn_graph_build as graph_build_cmd - import sevenn.main.sevenn_inference as inference_cmd - import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd - import sevenn.main.sevenn_preset as preset_cmd - - ag = argparse.ArgumentParser(f'SevenNet version={__version__}') - - subparsers = ag.add_subparsers(dest='command', help='Sub-commands') - add_parser(subparsers) # add 'train' - checkpoint_cmd.add_parser(subparsers) - inference_cmd.add_parser(subparsers) - graph_build_cmd.add_parser(subparsers) - preset_cmd.add_parser(subparsers) - get_model_cmd.add_parser(subparsers) - patch_lammps_cmd.add_parser(subparsers) - - ag.set_default_subparser('train') # type: ignore - args = ag.parse_args() - - if args.command is None: # backward compatibility - args.command = 'train' - - if args.command == 'train': - run(args) - elif args.command == 'preset': - preset_cmd.run(args) - - -if __name__ == '__main__': - main() +import argparse +import os +import sys +import time + +from sevenn import __version__ + +description = 'train a model given the input.yaml' + +input_yaml_help = 'input.yaml for training' +mode_help = 'main training script to run. Default is train.' +working_dir_help = 'path to write output. Default is cwd.' +screen_help = 'print log to stdout' +distributed_help = 'set this flag if it is distributed training' +distributed_backend_help = 'backend for distributed training. 
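The `set_default_subparser` hack above scans `argv` for any registered sub-command and, finding none, splices the default in so the legacy `sevenn input.yaml` invocation still means `sevenn train input.yaml`. A simplified standalone version of the splice — the real helper inspects the parser's `_SubParsersAction` and honors `positional_args` rather than hard-coding names:

```python
# Simplified argv rewrite (sub-command names hard-coded for illustration).
known = {'train', 'checkpoint', 'inference', 'graph_build', 'preset', 'get_model'}
argv = ['sevenn', 'input.yaml']             # no sub-command given

if not known.intersection(argv[1:]) and not {'-h', '--help'} & set(argv[1:]):
    argv.insert(1, 'train')
print(argv)   # ['sevenn', 'train', 'input.yaml']
```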
Supported: nccl, mpi' + +# Metainfo will be saved to checkpoint +global_config = { + 'version': __version__, + 'when': time.ctime(), + '_model_type': 'E3_equivariant_model', +} + + +def run(args): + """ + main function of sevenn + """ + import random + import sys + + import torch + import torch.distributed as dist + + import sevenn._keys as KEY + from sevenn.logger import Logger + from sevenn.parse_input import read_config_yaml + from sevenn.scripts.train import train, train_v2 + from sevenn.util import unique_filepath + + input_yaml = args.input_yaml + mode = args.mode + working_dir = args.working_dir + log = args.log + screen = args.screen + distributed = args.distributed + distributed_backend = args.distributed_backend + use_cue = args.enable_cueq + + if use_cue: + import sevenn.nn.cue_helper + + if not sevenn.nn.cue_helper.is_cue_available(): + raise ImportError('cuEquivariance not installed.') + + if working_dir is None: + working_dir = os.getcwd() + elif not os.path.isdir(working_dir): + os.makedirs(working_dir, exist_ok=True) + + world_size = 1 + if distributed: + if distributed_backend == 'nccl': + local_rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + elif distributed_backend == 'mpi': + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + else: + raise ValueError(f'Unknown distributed backend: {distributed_backend}') + + dist.init_process_group( + backend=distributed_backend, world_size=world_size, rank=rank + ) + else: + local_rank, rank, world_size = 0, 0, 1 + + log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}') + with Logger(filename=log_fname, screen=screen, rank=rank) as logger: + logger.greeting() + + if distributed: + logger.writeline( + f'Distributed training enabled, total world size is {world_size}' + ) + + try: + model_config, train_config, data_config = read_config_yaml( + input_yaml, return_separately=True + ) + except Exception as e: + logger.writeline('Failed to parsing input.yaml') + logger.error(e) + sys.exit(1) + + train_config[KEY.IS_DDP] = distributed + train_config[KEY.DDP_BACKEND] = distributed_backend + train_config[KEY.LOCAL_RANK] = local_rank + train_config[KEY.RANK] = rank + train_config[KEY.WORLD_SIZE] = world_size + + if distributed: + torch.cuda.set_device(torch.device('cuda', local_rank)) + + if use_cue: + if KEY.CUEQUIVARIANCE_CONFIG not in model_config: + model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True} + else: + model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True}) + + logger.print_config(model_config, data_config, train_config) + # don't have to distinguish configs inside program + global_config.update(model_config) + global_config.update(train_config) + global_config.update(data_config) + + # Not implemented + if global_config[KEY.DTYPE] == 'double': + raise Exception('double precision is not implemented yet') + # torch.set_default_dtype(torch.double) + + seed = global_config[KEY.RANDOM_SEED] + random.seed(seed) + torch.manual_seed(seed) + + # run train + if mode == 'train_v1': + train(global_config, working_dir) + elif mode == 'train_v2': + train_v2(global_config, working_dir) + + +def cmd_parser_train(parser): + ag = parser + ag.add_argument('input_yaml', help=input_yaml_help, type=str) + ag.add_argument( + '-m', + '--mode', + choices=['train_v1', 'train_v2'], + default='train_v2', + help=mode_help, + type=str, + ) + ag.add_argument( 
+ '-cueq', + '--enable_cueq', + help='(Not stable!) use cuEquivariance for training', + action='store_true', + ) + ag.add_argument( + '-w', + '--working_dir', + nargs='?', + const=os.getcwd(), + help=working_dir_help, + type=str, + ) + ag.add_argument( + '-l', + '--log', + default='log.sevenn', + help='name of logfile, default is log.sevenn', + type=str, + ) + ag.add_argument('-s', '--screen', help=screen_help, action='store_true') + ag.add_argument( + '-d', '--distributed', help=distributed_help, action='store_true' + ) + ag.add_argument( + '--distributed_backend', + help=distributed_backend_help, + type=str, + default='nccl', + choices=['nccl', 'mpi'], + ) + + +def add_parser(subparsers): + ag = subparsers.add_parser('train', help=description) + cmd_parser_train(ag) + + +def set_default_subparser(self, name, args=None, positional_args=0): + """default subparser selection. Call after setup, just before parse_args() + name: is the name of the subparser to call by default + args: if set is the argument list handed to parse_args() + + Hack copied from stack overflow + """ + subparser_found = False + for arg in sys.argv[1:]: + if arg in ['-h', '--help']: # global help if no subparser + break + else: + for x in self._subparsers._actions: + if not isinstance(x, argparse._SubParsersAction): + continue + for sp_name in x._name_parser_map.keys(): + if sp_name in sys.argv[1:]: + subparser_found = True + if not subparser_found: + # insert default in last position before global positional + # arguments, this implies no global options are specified after + # first positional argument + if args is None: + sys.argv.insert(len(sys.argv) - positional_args, name) + else: + args.insert(len(args) - positional_args, name) + + +argparse.ArgumentParser.set_default_subparser = set_default_subparser # type: ignore + + +def main(): + import sevenn.main.sevenn_cp as checkpoint_cmd + import sevenn.main.sevenn_get_model as get_model_cmd + import sevenn.main.sevenn_graph_build as graph_build_cmd + import sevenn.main.sevenn_inference as inference_cmd + import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd + import sevenn.main.sevenn_preset as preset_cmd + + ag = argparse.ArgumentParser(f'SevenNet version={__version__}') + + subparsers = ag.add_subparsers(dest='command', help='Sub-commands') + add_parser(subparsers) # add 'train' + checkpoint_cmd.add_parser(subparsers) + inference_cmd.add_parser(subparsers) + graph_build_cmd.add_parser(subparsers) + preset_cmd.add_parser(subparsers) + get_model_cmd.add_parser(subparsers) + patch_lammps_cmd.add_parser(subparsers) + + ag.set_default_subparser('train') # type: ignore + args = ag.parse_args() + + if args.command is None: # backward compatibility + args.command = 'train' + + if args.command == 'train': + run(args) + elif args.command == 'preset': + preset_cmd.run(args) + + +if __name__ == '__main__': + main() diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py index 29edf5be04f3638a148932d18932b4235d78a27a..319cb1a4b93507dd211ded4c888750a9c04f6570 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py @@ -1,92 +1,92 @@ -import argparse -import os.path as osp - -from sevenn import __version__ - -description = ( - 'tool box for sevennet checkpoints' -) - - -def add_parser(subparsers): - ag = subparsers.add_parser('checkpoint', help=description, aliases=['cp']) - add_args(ag) - - -def add_args(parser): - ag = parser - - 
ag.add_argument('checkpoint', help='checkpoint or pretrained', type=str) - - group = ag.add_mutually_exclusive_group(required=False) - group.add_argument( - '--get_yaml', - choices=['reproduce', 'continue', 'continue_modal'], - help='create input.yaml based on the given checkpoint', - type=str, - ) - - group.add_argument( - '--append_modal_yaml', - help='append modality with given yaml.', - type=str, - ) - ag.add_argument( - '--original_modal_name', - help=( - 'when the append_modal is used and checkpoint is not multi-modal, ' - + 'used to name previously trained modality. defaults to "origin"' - ), - default='origin', - type=str, - ) - - -def run(args): - import torch - import yaml - - from sevenn.parse_input import read_config_yaml - from sevenn.util import load_checkpoint - - checkpoint = load_checkpoint(args.checkpoint) - if args.get_yaml: - mode = args.get_yaml - cfg = checkpoint.yaml_dict(mode) - print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False)) - elif args.append_modal_yaml: - dst_yaml = args.append_modal_yaml - if not osp.exists(dst_yaml): - raise FileNotFoundError(f'No yaml file {dst_yaml}') - - dst_config = read_config_yaml(dst_yaml, return_separately=False) - model_state_dict = checkpoint.append_modal( - dst_config, args.original_modal_name - ) - - to_save = checkpoint.get_checkpoint_dict() - to_save.update({'config': dst_config, 'model_state_dict': model_state_dict}) - - torch.save(to_save, 'checkpoint_modal_appended.pth') - print('checkpoint_modal_appended.pth is successfully saved.') - print(f'update continue of {dst_yaml} as blow (recommend) to continue') - cont_dct = { - 'continue': { - 'checkpoint': 'checkpoint_modal_appended.pth', - 'reset_epoch': True, - 'reset_optimizer': True, - 'reset_scheduler': True, - } - } - print( - yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False) - ) - - else: - print(checkpoint) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import os.path as osp + +from sevenn import __version__ + +description = ( + 'tool box for sevennet checkpoints' +) + + +def add_parser(subparsers): + ag = subparsers.add_parser('checkpoint', help=description, aliases=['cp']) + add_args(ag) + + +def add_args(parser): + ag = parser + + ag.add_argument('checkpoint', help='checkpoint or pretrained', type=str) + + group = ag.add_mutually_exclusive_group(required=False) + group.add_argument( + '--get_yaml', + choices=['reproduce', 'continue', 'continue_modal'], + help='create input.yaml based on the given checkpoint', + type=str, + ) + + group.add_argument( + '--append_modal_yaml', + help='append modality with given yaml.', + type=str, + ) + ag.add_argument( + '--original_modal_name', + help=( + 'when the append_modal is used and checkpoint is not multi-modal, ' + + 'used to name previously trained modality. 
defaults to "origin"' + ), + default='origin', + type=str, + ) + + +def run(args): + import torch + import yaml + + from sevenn.parse_input import read_config_yaml + from sevenn.util import load_checkpoint + + checkpoint = load_checkpoint(args.checkpoint) + if args.get_yaml: + mode = args.get_yaml + cfg = checkpoint.yaml_dict(mode) + print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False)) + elif args.append_modal_yaml: + dst_yaml = args.append_modal_yaml + if not osp.exists(dst_yaml): + raise FileNotFoundError(f'No yaml file {dst_yaml}') + + dst_config = read_config_yaml(dst_yaml, return_separately=False) + model_state_dict = checkpoint.append_modal( + dst_config, args.original_modal_name + ) + + to_save = checkpoint.get_checkpoint_dict() + to_save.update({'config': dst_config, 'model_state_dict': model_state_dict}) + + torch.save(to_save, 'checkpoint_modal_appended.pth') + print('checkpoint_modal_appended.pth is successfully saved.') + print(f'update continue of {dst_yaml} as blow (recommend) to continue') + cont_dct = { + 'continue': { + 'checkpoint': 'checkpoint_modal_appended.pth', + 'reset_epoch': True, + 'reset_optimizer': True, + 'reset_scheduler': True, + } + } + print( + yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False) + ) + + else: + print(checkpoint) + + +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py index f2b7b0af01a55711f387ef2dfcef2f0f65d70103..f8b78e88f98d4dc9d92d686f77037c2f2ac57f18 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py @@ -1,70 +1,70 @@ -import argparse -import os - -from sevenn import __version__ - -description_get_model = ( - 'deploy LAMMPS model from the checkpoint' -) -checkpoint_help = ( - 'path to the checkpoint | SevenNet-0 | 7net-0 |' - ' {SevenNet-0|7net-0}_{11July2024|22May2024}' -) -output_name_help = 'filename prefix' -get_parallel_help = 'deploy parallel model' - - -def add_parser(subparsers): - ag = subparsers.add_parser( - 'get_model', help=description_get_model, aliases=['deploy'] - ) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('checkpoint', help=checkpoint_help, type=str) - ag.add_argument( - '-o', '--output_prefix', nargs='?', help=output_name_help, type=str - ) - ag.add_argument( - '-p', '--get_parallel', help=get_parallel_help, action='store_true' - ) - ag.add_argument( - '-m', - '--modal', - help='Modality of multi-modal model', - type=str, - ) - - -def run(args): - import sevenn.util - from sevenn.scripts.deploy import deploy, deploy_parallel - - checkpoint = args.checkpoint - output_prefix = args.output_prefix - get_parallel = args.get_parallel - get_serial = not get_parallel - modal = args.modal - - if output_prefix is None: - output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' - - checkpoint_path = None - if os.path.isfile(checkpoint): - checkpoint_path = checkpoint - else: - checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) - - if get_serial: - deploy(checkpoint_path, output_prefix, modal) - else: - deploy_parallel(checkpoint_path, output_prefix, modal) - - -# legacy way -def main(): - ag = argparse.ArgumentParser(description=description_get_model) - add_args(ag) - run(ag.parse_args()) +import argparse +import os + +from sevenn 
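After writing `checkpoint_modal_appended.pth`, `run()` above prints a ready-to-paste `continue` block; resuming with all three reset flags keeps the appended weights but restarts epoch counting, optimizer moments, and the LR schedule. Reproducing the printed block with PyYAML, as the code above does:

```python
import yaml

# The continue block sevenn_cp prints after appending a modality.
cont = {
    'continue': {
        'checkpoint': 'checkpoint_modal_appended.pth',
        'reset_epoch': True,
        'reset_optimizer': True,
        'reset_scheduler': True,
    }
}
print(yaml.dump(cont, indent=4, sort_keys=False, default_flow_style=False))
```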
import __version__ + +description_get_model = ( + 'deploy LAMMPS model from the checkpoint' +) +checkpoint_help = ( + 'path to the checkpoint | SevenNet-0 | 7net-0 |' + ' {SevenNet-0|7net-0}_{11July2024|22May2024}' +) +output_name_help = 'filename prefix' +get_parallel_help = 'deploy parallel model' + + +def add_parser(subparsers): + ag = subparsers.add_parser( + 'get_model', help=description_get_model, aliases=['deploy'] + ) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument('checkpoint', help=checkpoint_help, type=str) + ag.add_argument( + '-o', '--output_prefix', nargs='?', help=output_name_help, type=str + ) + ag.add_argument( + '-p', '--get_parallel', help=get_parallel_help, action='store_true' + ) + ag.add_argument( + '-m', + '--modal', + help='Modality of multi-modal model', + type=str, + ) + + +def run(args): + import sevenn.util + from sevenn.scripts.deploy import deploy, deploy_parallel + + checkpoint = args.checkpoint + output_prefix = args.output_prefix + get_parallel = args.get_parallel + get_serial = not get_parallel + modal = args.modal + + if output_prefix is None: + output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' + + checkpoint_path = None + if os.path.isfile(checkpoint): + checkpoint_path = checkpoint + else: + checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) + + if get_serial: + deploy(checkpoint_path, output_prefix, modal) + else: + deploy_parallel(checkpoint_path, output_prefix, modal) + + +# legacy way +def main(): + ag = argparse.ArgumentParser(description=description_get_model) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py index da533c951c8df1e62023796ff66f840ffb0c076d..fd9eaeef104305e5182d7eb01d3495cd538c4e0f 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py @@ -1,130 +1,130 @@ -import argparse -import glob -import os -import sys -from datetime import datetime - -from sevenn import __version__ - -description = 'create `sevenn_data/dataset.pt` from ase readable' - -source_help = 'source data to build graph, knows *' -cutoff_help = 'cutoff radius of edges in Angstrom' -filename_help = ( - 'Name of the dataset, default is graph.pt. ' - + 'The dataset will be written under "sevenn_data", ' - + 'for example, {out}/sevenn_data/graph.pt.' 
-) -legacy_help = 'build legacy .sevenn_data' - - -def add_parser(subparsers): - ag = subparsers.add_parser('graph_build', help=description) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('source', help=source_help, type=str) - ag.add_argument('cutoff', help=cutoff_help, type=float) - ag.add_argument( - '-n', - '--num_cores', - help='number of cores to build graph in parallel', - default=1, - type=int, - ) - ag.add_argument( - '-o', - '--out', - help='Existing path to write outputs.', - type=str, - default='./', - ) - ag.add_argument( - '-f', - '--filename', - help=filename_help, - type=str, - default='graph.pt', - ) - ag.add_argument( - '--legacy', - help=legacy_help, - action='store_true', - ) - ag.add_argument( - '-s', - '--screen', - help='print log to the screen', - action='store_true', - ) - ag.add_argument( - '--kwargs', - nargs=argparse.REMAINDER, - help='will be passed to ase.io.read, or can be used to specify EFS key', - ) - - -def run(args): - import sevenn.scripts.graph_build as graph_build - from sevenn.logger import Logger - - source = glob.glob(args.source) - cutoff = args.cutoff - num_cores = args.num_cores - filename = args.filename - out = args.out - legacy = args.legacy - fmt_kwargs = {} - if args.kwargs: - for kwarg in args.kwargs: - k, v = kwarg.split('=') - fmt_kwargs[k] = v - - if len(source) == 0: - print('Source has zero len, nothing to read') - sys.exit(0) - - if not os.path.isdir(out): - raise NotADirectoryError(f'No such directory: {out}') - - to_be_written = os.path.join(out, 'sevenn_data', filename) - if os.path.isfile(to_be_written): - raise FileExistsError(f'File already exist: {to_be_written}') - - metadata = { - 'sevenn_version': __version__, - 'when': datetime.now().strftime('%Y-%m-%d'), - 'cutoff': cutoff, - } - - with Logger(filename=None, screen=args.screen) as logger: - logger.writeline(description) - - if not legacy: - graph_build.build_sevennet_graph_dataset( - source, - cutoff, - num_cores, - out, - filename, - metadata, - **fmt_kwargs, - ) - else: - out = os.path.join(out, filename.split('.')[0]) - graph_build.build_script( # build .sevenn_data - source, - cutoff, - num_cores, - out, - metadata, - **fmt_kwargs, - ) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import glob +import os +import sys +from datetime import datetime + +from sevenn import __version__ + +description = 'create `sevenn_data/dataset.pt` from ase readable' + +source_help = 'source data to build graph, knows *' +cutoff_help = 'cutoff radius of edges in Angstrom' +filename_help = ( + 'Name of the dataset, default is graph.pt. ' + + 'The dataset will be written under "sevenn_data", ' + + 'for example, {out}/sevenn_data/graph.pt.' 
+)
+legacy_help = 'build legacy .sevenn_data'
+
+
+def add_parser(subparsers):
+    ag = subparsers.add_parser('graph_build', help=description)
+    add_args(ag)
+
+
+def add_args(parser):
+    ag = parser
+    ag.add_argument('source', help=source_help, type=str)
+    ag.add_argument('cutoff', help=cutoff_help, type=float)
+    ag.add_argument(
+        '-n',
+        '--num_cores',
+        help='number of cores to build graph in parallel',
+        default=1,
+        type=int,
+    )
+    ag.add_argument(
+        '-o',
+        '--out',
+        help='Existing path to write outputs.',
+        type=str,
+        default='./',
+    )
+    ag.add_argument(
+        '-f',
+        '--filename',
+        help=filename_help,
+        type=str,
+        default='graph.pt',
+    )
+    ag.add_argument(
+        '--legacy',
+        help=legacy_help,
+        action='store_true',
+    )
+    ag.add_argument(
+        '-s',
+        '--screen',
+        help='print log to the screen',
+        action='store_true',
+    )
+    ag.add_argument(
+        '--kwargs',
+        nargs=argparse.REMAINDER,
+        help='will be passed to ase.io.read, or can be used to specify EFS key',
+    )
+
+
+def run(args):
+    import sevenn.scripts.graph_build as graph_build
+    from sevenn.logger import Logger
+
+    source = glob.glob(args.source)
+    cutoff = args.cutoff
+    num_cores = args.num_cores
+    filename = args.filename
+    out = args.out
+    legacy = args.legacy
+    fmt_kwargs = {}
+    if args.kwargs:
+        for kwarg in args.kwargs:
+            k, v = kwarg.split('=')
+            fmt_kwargs[k] = v
+
+    if len(source) == 0:
+        print('Source matched no files; nothing to read')
+        sys.exit(0)
+
+    if not os.path.isdir(out):
+        raise NotADirectoryError(f'No such directory: {out}')
+
+    to_be_written = os.path.join(out, 'sevenn_data', filename)
+    if os.path.isfile(to_be_written):
+        raise FileExistsError(f'File already exists: {to_be_written}')
+
+    metadata = {
+        'sevenn_version': __version__,
+        'when': datetime.now().strftime('%Y-%m-%d'),
+        'cutoff': cutoff,
+    }
+
+    with Logger(filename=None, screen=args.screen) as logger:
+        logger.writeline(description)
+
+        if not legacy:
+            graph_build.build_sevennet_graph_dataset(
+                source,
+                cutoff,
+                num_cores,
+                out,
+                filename,
+                metadata,
+                **fmt_kwargs,
+            )
+        else:
+            out = os.path.join(out, filename.split('.')[0])
+            graph_build.build_script(  # build .sevenn_data
+                source,
+                cutoff,
+                num_cores,
+                out,
+                metadata,
+                **fmt_kwargs,
+            )
+
+
+def main(args=None):
+    ag = argparse.ArgumentParser(description=description)
+    add_args(ag)
+    run(ag.parse_args())
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py
index 5a9cd903e7100bea350b0ee6d339361301ea7ce1..bfac2371435715d0290d14fb34326a63bf68dcdd 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py
@@ -1,129 +1,129 @@
-import argparse
-import glob
-import os
-import sys
-
-description = (
-    'evaluate sevenn_data/ase readable with a model (checkpoint).' 
-) -checkpoint_help = 'Checkpoint or pre-trained model name' -target_help = 'Target files to evaluate' - - -def add_parser(subparsers): - ag = subparsers.add_parser('inference', help=description, aliases=['inf']) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('checkpoint', type=str, help=checkpoint_help) - ag.add_argument('targets', type=str, nargs='+', help=target_help) - ag.add_argument( - '-d', - '--device', - type=str, - default='auto', - help='cpu/cuda/cuda:x', - ) - ag.add_argument( - '-nw', - '--nworkers', - type=int, - default=1, - help='Number of cores to build graph, defaults to 1', - ) - ag.add_argument( - '-o', - '--output', - type=str, - default='./inference_results', - help='A directory name to write outputs', - ) - ag.add_argument( - '-b', - '--batch', - type=int, - default='4', - help='batch size, useful for GPU' - ) - ag.add_argument( - '-s', - '--save_graph', - action='store_true', - help='Additionally, save preprocessed graph as sevenn_data' - ) - ag.add_argument( - '-au', - '--allow_unlabeled', - action='store_true', - help='Allow energy or force unlabeled data' - ) - ag.add_argument( - '-m', - '--modal', - type=str, - default=None, - help='modality for multi-modal inference', - ) - ag.add_argument( - '--kwargs', - nargs=argparse.REMAINDER, - help='will be passed to reader, or can be used to specify EFS key', - ) - - -def run(args): - import torch - - from sevenn.scripts.inference import inference - from sevenn.util import pretrained_name_to_path - - out = args.output - - if os.path.exists(out): - raise FileExistsError(f'Directory {out} already exists') - - device = args.device - if device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - targets = [] - for target in args.targets: - targets.extend(glob.glob(target)) - - if len(targets) == 0: - print('No targets (data to inference) are found') - sys.exit(0) - - cp = args.checkpoint - if not os.path.isfile(cp): - cp = pretrained_name_to_path(cp) # raises value error - - fmt_kwargs = {} - if args.kwargs: - for kwarg in args.kwargs: - k, v = kwarg.split('=') - fmt_kwargs[k] = v - - if args.save_graph and args.allow_unlabeled: - raise ValueError('save_graph and allow_unlabeled are mutually exclusive') - - inference( - cp, - targets, - out, - args.nworkers, - device, - args.batch, - args.save_graph, - args.allow_unlabeled, - args.modal, - **fmt_kwargs, - ) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import glob +import os +import sys + +description = ( + 'evaluate sevenn_data/ase readable with a model (checkpoint).' 
+)
+checkpoint_help = 'Checkpoint or pre-trained model name'
+target_help = 'Target files to evaluate'
+
+
+def add_parser(subparsers):
+    ag = subparsers.add_parser('inference', help=description, aliases=['inf'])
+    add_args(ag)
+
+
+def add_args(parser):
+    ag = parser
+    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
+    ag.add_argument('targets', type=str, nargs='+', help=target_help)
+    ag.add_argument(
+        '-d',
+        '--device',
+        type=str,
+        default='auto',
+        help='cpu/cuda/cuda:x',
+    )
+    ag.add_argument(
+        '-nw',
+        '--nworkers',
+        type=int,
+        default=1,
+        help='Number of cores to build graph, defaults to 1',
+    )
+    ag.add_argument(
+        '-o',
+        '--output',
+        type=str,
+        default='./inference_results',
+        help='A directory name to write outputs',
+    )
+    ag.add_argument(
+        '-b',
+        '--batch',
+        type=int,
+        default=4,
+        help='batch size, useful for GPU'
+    )
+    ag.add_argument(
+        '-s',
+        '--save_graph',
+        action='store_true',
+        help='Additionally, save preprocessed graph as sevenn_data'
+    )
+    ag.add_argument(
+        '-au',
+        '--allow_unlabeled',
+        action='store_true',
+        help='Allow energy or force unlabeled data'
+    )
+    ag.add_argument(
+        '-m',
+        '--modal',
+        type=str,
+        default=None,
+        help='modality for multi-modal inference',
+    )
+    ag.add_argument(
+        '--kwargs',
+        nargs=argparse.REMAINDER,
+        help='will be passed to reader, or can be used to specify EFS key',
+    )
+
+
+def run(args):
+    import torch
+
+    from sevenn.scripts.inference import inference
+    from sevenn.util import pretrained_name_to_path
+
+    out = args.output
+
+    if os.path.exists(out):
+        raise FileExistsError(f'Directory {out} already exists')
+
+    device = args.device
+    if device == 'auto':
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    targets = []
+    for target in args.targets:
+        targets.extend(glob.glob(target))
+
+    if len(targets) == 0:
+        print('No targets (data to run inference on) were found')
+        sys.exit(0)
+
+    cp = args.checkpoint
+    if not os.path.isfile(cp):
+        cp = pretrained_name_to_path(cp)  # raises value error
+
+    fmt_kwargs = {}
+    if args.kwargs:
+        for kwarg in args.kwargs:
+            k, v = kwarg.split('=')
+            fmt_kwargs[k] = v
+
+    if args.save_graph and args.allow_unlabeled:
+        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')
+
+    inference(
+        cp,
+        targets,
+        out,
+        args.nworkers,
+        device,
+        args.batch,
+        args.save_graph,
+        args.allow_unlabeled,
+        args.modal,
+        **fmt_kwargs,
+    )
+
+
+def main(args=None):
+    ag = argparse.ArgumentParser(description=description)
+    add_args(ag)
+    run(ag.parse_args())
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py
index 9ede5a6ecf7b541cb3d0c23155c40d84c83ca1b4..38f7299ad816600b4bc16f3eafe269a3bc62c434 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py
@@ -1,55 +1,55 @@
-import argparse
-import os
-import subprocess
-
-from sevenn import __version__
-
-# python wrapper of patch_lammps.sh script
-# importlib.resources is correct way to do these things
-# but it changes so frequently to use
-pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
-
-description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
-
-
-def add_parser(subparsers):
-    ag = subparsers.add_parser('patch_lammps', help=description)
-    add_args(ag)
-
-
-def add_args(parser):
-    ag = parser
-    ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
-    
ag.add_argument('--d3', help='Enable D3 support', action='store_true')
-    # cxx_standard is detected automatically
-
-
-def run(args):
-    lammps_dir = os.path.abspath(args.lammps_dir)
-
-    print('Patching LAMMPS with the following settings:')
-    print(' - LAMMPS source directory:', lammps_dir)
-
-    cxx_standard = '17'  # always 17
-
-    if args.d3:
-        d3_support = '1'
-        print(' - D3 support enabled')
-    else:
-        d3_support = '0'
-        print(' - D3 support disabled')
-
-    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
-    cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}'
-    res = subprocess.run(cmd.split())
-    return res.returncode  # is it meaningless?
-
-
-def main(args=None):
-    ag = argparse.ArgumentParser(description=description)
-    add_args(ag)
-    run(ag.parse_args())
-
-
-if __name__ == '__main__':
-    main()
+import argparse
+import os
+import subprocess
+
+from sevenn import __version__
+
+# python wrapper of the patch_lammps.sh script
+# importlib.resources would be the proper way to locate these files,
+# but its API changes too frequently to rely on
+pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
+
+description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
+
+
+def add_parser(subparsers):
+    ag = subparsers.add_parser('patch_lammps', help=description)
+    add_args(ag)
+
+
+def add_args(parser):
+    ag = parser
+    ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
+    ag.add_argument('--d3', help='Enable D3 support', action='store_true')
+    # cxx_standard is detected automatically
+
+
+def run(args):
+    lammps_dir = os.path.abspath(args.lammps_dir)
+
+    print('Patching LAMMPS with the following settings:')
+    print(' - LAMMPS source directory:', lammps_dir)
+
+    cxx_standard = '17'  # always 17
+
+    if args.d3:
+        d3_support = '1'
+        print(' - D3 support enabled')
+    else:
+        d3_support = '0'
+        print(' - D3 support disabled')
+
+    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
+    cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}'
+    res = subprocess.run(cmd.split())
+    return res.returncode  # is it meaningless?
+
+
+def main(args=None):
+    ag = argparse.ArgumentParser(description=description)
+    add_args(ag)
+    run(ag.parse_args())
+
+
+if __name__ == '__main__':
+    main()
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py
index fdacea5e7301be4eec0d52b6651b07e83200fe84..f8587d1b8fccf85bcb5633ec4464a4ff7207d4f3 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py
@@ -1,45 +1,45 @@
-import argparse
-import os
-
-from sevenn import __version__
-
-description = (
-    'print the selected preset for training. 
' - + 'ex) sevennet_preset fine_tune > my_input.yaml' -) - -preset_help = 'Name of preset' - - -def add_parser(subparsers): - ag = subparsers.add_parser('preset', help=description) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument( - 'preset', choices=[ - 'fine_tune', - 'fine_tune_le', - 'sevennet-0', - 'sevennet-l3i5', - 'base', - 'multi_modal' - ], - help=preset_help - ) - - -def run(args): - preset = args.preset - prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets') - with open(f'{prefix}/{preset}.yaml', 'r') as f: - print(f.read()) - - -# When executed as sevenn_preset (legacy way) -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import os + +from sevenn import __version__ + +description = ( + 'print the selected preset for training. ' + + 'ex) sevennet_preset fine_tune > my_input.yaml' +) + +preset_help = 'Name of preset' + + +def add_parser(subparsers): + ag = subparsers.add_parser('preset', help=description) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument( + 'preset', choices=[ + 'fine_tune', + 'fine_tune_le', + 'sevennet-0', + 'sevennet-l3i5', + 'base', + 'multi_modal' + ], + help=preset_help + ) + + +def run(args): + preset = args.preset + prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets') + with open(f'{prefix}/{preset}.yaml', 'r') as f: + print(f.read()) + + +# When executed as sevenn_preset (legacy way) +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/model_build.py b/mace-bench/3rdparty/SevenNet/sevenn/model_build.py index 91136723aa47c240e955c0c0b28820b29509b684..2b8701d50eefffe82bac95119dd5417ced217934 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/model_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/model_build.py @@ -1,556 +1,556 @@ -import copy -import warnings -from collections import OrderedDict -from typing import List, Literal, Union, overload - -from e3nn.o3 import Irreps - -import sevenn._const as _const -import sevenn._keys as KEY -import sevenn.util as util - -from .nn.convolution import IrrepsConvolution -from .nn.edge_embedding import ( - BesselBasis, - EdgeEmbedding, - PolynomialCutoff, - SphericalEncoding, - XPLORCutoff, -) -from .nn.force_output import ForceStressOutputFromEdge -from .nn.interaction_blocks import NequIP_interaction_block -from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear -from .nn.node_embedding import OnehotEmbedding -from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale -from .nn.self_connection import ( - SelfConnectionIntro, - SelfConnectionLinearIntro, - SelfConnectionOutro, -) -from .nn.sequential import AtomGraphSequential - -# warning from PyTorch, about e3nn type annotations -warnings.filterwarnings( - 'ignore', - message=( - "The TorchScript type system doesn't " 'support instance-level annotations' - ), -) - - -def _insert_after(module_name_after, key_module_pair, layers): - idx = -1 - for i, (key, _) in enumerate(layers): - if key == module_name_after: - idx = i - break - if idx == -1: - return layers # do nothing if not found - layers.insert(idx + 1, key_module_pair) - return layers - - -def init_self_connection(config): - self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE] - num_conv = config[KEY.NUM_CONVOLUTION] - if isinstance(self_connection_type_list, str): - self_connection_type_list = 
[self_connection_type_list] * num_conv - - io_pair_list = [] - for sc_type in self_connection_type_list: - if sc_type == 'none': - io_pair = None - elif sc_type == 'nequip': - io_pair = SelfConnectionIntro, SelfConnectionOutro - elif sc_type == 'linear': - io_pair = SelfConnectionLinearIntro, SelfConnectionOutro - else: - raise ValueError(f'Unknown self_connection_type found: {sc_type}') - io_pair_list.append(io_pair) - return io_pair_list - - -def init_edge_embedding(config): - _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]} - rbf, env, sph = None, None, None - - rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS]) - rbf_dct.update(_cutoff_param) - rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME) - if rbf_name == 'bessel': - rbf = BesselBasis(**rbf_dct) - - envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION]) - envelop_dct.update(_cutoff_param) - envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME) - if envelop_name == 'poly_cut': - env = PolynomialCutoff(**envelop_dct) - elif envelop_name == 'XPLOR': - env = XPLORCutoff(**envelop_dct) - - lmax_edge = config[KEY.LMAX] - if config[KEY.LMAX_EDGE] > 0: - lmax_edge = config[KEY.LMAX_EDGE] - parity = -1 if config[KEY.IS_PARITY] else 1 - _normalize_sph = config[KEY._NORMALIZE_SPH] - sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph) - - return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph) - - -def init_feature_reduce(config, irreps_x): - # features per node to scalar per node - layers = OrderedDict() - if config[KEY.READOUT_AS_FCN] is False: - hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))]) - layers.update( - { - 'reduce_input_to_hidden': IrrepsLinear( - irreps_x, - hidden_irreps, - data_key_in=KEY.NODE_FEATURE, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - 'reduce_hidden_to_energy': IrrepsLinear( - hidden_irreps, - Irreps([(1, (0, 1))]), - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.SCALED_ATOMIC_ENERGY, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - } - ) - else: - act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]] - hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS] - layers.update( - { - 'readout_FCN': FCN_e3nn( - dim_out=1, - hidden_neurons=hidden_neurons, - activation=act, - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.SCALED_ATOMIC_ENERGY, - irreps_in=irreps_x, - ) - } - ) - return layers - - -def init_shift_scale(config): - # for mm, ex, shift: modal_idx -> shifts - shift_scale = [] - train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE] - type_map = config[KEY.TYPE_MAP] - - # in case of modal, shift or scale has more dims [][] - # correct typing (I really want static python) - for s in (config[KEY.SHIFT], config[KEY.SCALE]): - if hasattr(s, 'tolist'): # numpy or torch - s = s.tolist() - if isinstance(s, dict): - s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()} - if isinstance(s, list) and len(s) == 1: - s = s[0] - shift_scale.append(s) - shift, scale = shift_scale - - rescale_module = None - if config.get(KEY.USE_MODALITY, False): - rescale_module = ModalWiseRescale.from_mappers( # type: ignore - shift, - scale, - config[KEY.USE_MODAL_WISE_SHIFT], - config[KEY.USE_MODAL_WISE_SCALE], - type_map=type_map, - modal_map=config[KEY.MODAL_MAP], - train_shift_scale=train_shift_scale, - ) - elif all([isinstance(s, float) for s in shift_scale]): - rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale) - elif any([isinstance(s, list) for s in shift_scale]): - rescale_module = SpeciesWiseRescale.from_mappers( # type: ignore 
- shift, scale, type_map=type_map, train_shift_scale=train_shift_scale - ) - else: - raise ValueError('shift, scale should be list of float or float') - - return rescale_module - - -def patch_modality(layers: OrderedDict, config): - """ - Postprocess 7net-model to multimodal model. - 1. prepend modality one-hot embedding layer - 2. patch modalities of IrrepsLinear layers - Modality aware shift scale is handled by init_shift_scale, not here - """ - cfg = config - if not cfg.get(KEY.USE_MODALITY, False): - return layers - - _layers = list(layers.items()) - _layers = _insert_after( - 'onehot_idx_to_onehot', - ( - 'one_hot_modality', - OnehotEmbedding( - num_classes=config[KEY.NUM_MODALITIES], - data_key_x=KEY.MODAL_TYPE, - data_key_out=KEY.MODAL_ATTR, - data_key_save=None, - data_key_additional=None, - ), - ), - _layers, - ) - layers = OrderedDict(_layers) - - num_modal = config[KEY.NUM_MODALITIES] - for k, module in layers.items(): - if not isinstance(module, IrrepsLinear): - continue - if ( - (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x')) - or ( - cfg[KEY.USE_MODAL_SELF_INTER_INTRO] - and k.endswith('self_interaction_1') - ) - or ( - cfg[KEY.USE_MODAL_SELF_INTER_OUTRO] - and k.endswith('self_interaction_2') - ) - or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden') - ): - module.set_num_modalities(num_modal) - return layers - - -def patch_cue(layers: OrderedDict, config): - import sevenn.nn.cue_helper as cue_helper - - cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {})) - - if not cue_cfg.pop('use', False): - return layers - - if not cue_helper.is_cue_available(): - warnings.warn( - ( - 'cuEquivariance is requested, but the package is not installed. ' - + 'Fallback to original code.' - ) - ) - return layers - - if not cue_helper.is_cue_cuda_available_model(config): - return layers - - group = 'O3' if config[KEY.IS_PARITY] else 'SO3' - cueq_module_params = dict(layout='mul_ir') - cueq_module_params.update(cue_cfg) - updates = {} - for k, module in layers.items(): - if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)): - if k == 'reduce_hidden_to_energy': # TODO: has bug with 0 shape - continue - module_patched = cue_helper.patch_linear( - module, group, **cueq_module_params - ) - updates[k] = module_patched - elif isinstance(module, SelfConnectionIntro): - module_patched = cue_helper.patch_fully_connected( - module, group, **cueq_module_params - ) - updates[k] = module_patched - elif isinstance(module, IrrepsConvolution): - module_patched = cue_helper.patch_convolution( - module, group, **cueq_module_params - ) - updates[k] = module_patched - - layers.update(updates) - return layers - - -def patch_modules(layers: OrderedDict, config): - layers = patch_modality(layers, config) - layers = patch_cue(layers, config) - return layers - - -def _to_parallel_model(layers: OrderedDict, config): - num_classes = layers['onehot_idx_to_onehot'].num_classes - one_hot_irreps = Irreps(f'{num_classes}x0e') - irreps_node_zero = layers['onehot_to_feature_x'].irreps_out - - _layers = list(layers.items()) - layers_list = [] - - num_convolution_layer = config[KEY.NUM_CONVOLUTION] - - def slice_until_this(module_name, layers): - idx = -1 - for i, (key, _) in enumerate(layers): - if key == module_name: - idx = i - break - first_to = layers[: idx + 1] - remain = layers[idx + 1 :] - return first_to, remain - - _layers = _insert_after( - 'onehot_to_feature_x', - ( - 'one_hot_ghost', - OnehotEmbedding( - data_key_x=KEY.NODE_FEATURE_GHOST, - 
num_classes=num_classes, - data_key_save=None, - data_key_additional=None, - ), - ), - _layers, - ) - _layers = _insert_after( - 'one_hot_ghost', - ( - 'ghost_onehot_to_feature_x', - IrrepsLinear( - irreps_in=one_hot_irreps, - irreps_out=irreps_node_zero, - data_key_in=KEY.NODE_FEATURE_GHOST, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - ), - _layers, - ) - _layers = _insert_after( - '0_self_interaction_1', - ( - 'ghost_0_self_interaction_1', - IrrepsLinear( - irreps_node_zero, - irreps_node_zero, - data_key_in=KEY.NODE_FEATURE_GHOST, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - ), - _layers, - ) - # assign modules (before first communications) - # initialize edge related to retain position gradients - for i in range(1, num_convolution_layer): - sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers) - layers_list.append(OrderedDict(sliced)) - _layers.insert(0, ('edge_embedding', init_edge_embedding(config))) - - layers_list.append(OrderedDict(_layers)) - del layers_list[-1]['force_output'] # done in LAMMPS - return layers_list - - -@overload -def build_E3_equivariant_model( - config: dict, parallel: Literal[False] = False -) -> AtomGraphSequential: # noqa - ... - - -@overload -def build_E3_equivariant_model( - config: dict, parallel: Literal[True] -) -> List[AtomGraphSequential]: # noqa - ... - - -def build_E3_equivariant_model( - config: dict, parallel: bool = False -) -> Union[AtomGraphSequential, List[AtomGraphSequential]]: - """ - output shapes (w/o batch) - - PRED_TOTAL_ENERGY: (), - ATOMIC_ENERGY: (natoms, 1), # intended - PRED_FORCE: (natoms, 3), - PRED_STRESS: (6,), - - for data w/o cell volume, pred_stress has garbage values - """ - layers = OrderedDict() - - cutoff = config[KEY.CUTOFF] - num_species = config[KEY.NUM_SPECIES] - feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY] - num_convolution_layer = config[KEY.NUM_CONVOLUTION] - interaction_type = config[KEY.INTERACTION_TYPE] - use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR] - - lmax_node = config[KEY.LMAX] # ignore second (lmax_edge) - # if config[KEY.LMAX_EDGE] > 0: # not yet used - # _ = config[KEY.LMAX_EDGE] - if config[KEY.LMAX_NODE] > 0: - lmax_node = config[KEY.LMAX_NODE] - - act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]] - self_connection_pair_list = init_self_connection(config) - - irreps_manual = None - if config[KEY.IRREPS_MANUAL] is not False: - irreps_manual = config[KEY.IRREPS_MANUAL] - try: - irreps_manual = [Irreps(irr) for irr in irreps_manual] - assert len(irreps_manual) == num_convolution_layer + 1 - except Exception: - raise RuntimeError('invalid irreps_manual input given') - - conv_denominator = config[KEY.CONV_DENOMINATOR] - if not isinstance(conv_denominator, list): - conv_denominator = [conv_denominator] * num_convolution_layer - train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR] - - edge_embedding = init_edge_embedding(config) - irreps_filter = edge_embedding.spherical.irreps_out - radial_basis_num = edge_embedding.basis_function.num_basis - layers.update({'edge_embedding': edge_embedding}) - - one_hot_irreps = Irreps(f'{num_species}x0e') - irreps_x = ( - Irreps(f'{feature_multiplicity}x0e') - if irreps_manual is None - else irreps_manual[0] - ) - - layers.update( - { - 'onehot_idx_to_onehot': OnehotEmbedding( - num_classes=num_species, - data_key_x=KEY.NODE_FEATURE, - data_key_out=KEY.NODE_FEATURE, - data_key_save=KEY.ATOM_TYPE, # atomic numbers - data_key_additional=KEY.NODE_ATTR, # one-hot embeddings - ), - 'onehot_to_feature_x': IrrepsLinear( - 
irreps_in=one_hot_irreps, - irreps_out=irreps_x, - data_key_in=KEY.NODE_FEATURE, - biases=use_bias_in_linear, - ), - } - ) - - weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS] - weight_nn_layers = [radial_basis_num] + weight_nn_hidden - - param_interaction_block = { - 'irreps_filter': irreps_filter, - 'weight_nn_layers': weight_nn_layers, - 'train_conv_denominator': train_conv_denominator, - 'act_radial': act_radial, - 'bias_in_linear': use_bias_in_linear, - 'num_species': num_species, - 'parallel': parallel, - } - - interaction_builder = None - - if interaction_type in ['nequip']: - act_scalar = {} - act_gate = {} - for k, v in config[KEY.ACTIVATION_SCARLAR].items(): - act_scalar[k] = _const.ACTIVATION_DICT[k][v] - for k, v in config[KEY.ACTIVATION_GATE].items(): - act_gate[k] = _const.ACTIVATION_DICT[k][v] - param_interaction_block.update( - { - 'act_scalar': act_scalar, - 'act_gate': act_gate, - } - ) - - if interaction_type == 'nequip': - interaction_builder = NequIP_interaction_block - else: - raise ValueError(f'Unknown interaction type: {interaction_type}') - - for t in range(num_convolution_layer): - param_interaction_block.update( - { - 'irreps_x': irreps_x, - 't': t, - 'conv_denominator': conv_denominator[t], - 'self_connection_pair': self_connection_pair_list[t], - } - ) - if interaction_type == 'nequip': - parity_mode = 'full' - fix_multiplicity = False - if t == num_convolution_layer - 1: - lmax_node = 0 - parity_mode = 'even' - # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out - irreps_out = ( - util.infer_irreps_out( - irreps_x, # type: ignore - irreps_filter, - lmax_node, # type: ignore - parity_mode, - fix_multiplicity=feature_multiplicity, - ) - if irreps_manual is None - else irreps_manual[t + 1] - ) - irreps_out_tp = util.infer_irreps_out( - irreps_x, # type: ignore - irreps_filter, - irreps_out.lmax, # type: ignore - parity_mode, - fix_multiplicity, - ) - else: - raise ValueError(f'Unknown interaction type: {interaction_type}') - param_interaction_block.update( - { - 'irreps_out_tp': irreps_out_tp, - 'irreps_out': irreps_out, - } - ) - layers.update(interaction_builder(**param_interaction_block)) - irreps_x = irreps_out - - layers.update(init_feature_reduce(config, irreps_x)) - - layers.update( - { - 'rescale_atomic_energy': init_shift_scale(config), - 'reduce_total_enegy': AtomReduce( - data_key_in=KEY.ATOMIC_ENERGY, - data_key_out=KEY.PRED_TOTAL_ENERGY, - ), - } - ) - - gradient_module = ForceStressOutputFromEdge() - grad_key = gradient_module.get_grad_key() - layers.update({'force_output': gradient_module}) - - common_args = { - 'cutoff': cutoff, - 'type_map': config[KEY.TYPE_MAP], - 'modal_map': config.get(KEY.MODAL_MAP, None), - 'eval_type_map': False if parallel else True, - 'eval_modal_map': False - if not config.get(KEY.USE_MODALITY, False) or parallel - else True, - 'data_key_grad': grad_key, - } - - if parallel: - layers_list = _to_parallel_model(layers, config) - return [ - AtomGraphSequential(patch_modules(layers, config), **common_args) - for layers in layers_list - ] - else: - return AtomGraphSequential(patch_modules(layers, config), **common_args) +import copy +import warnings +from collections import OrderedDict +from typing import List, Literal, Union, overload + +from e3nn.o3 import Irreps + +import sevenn._const as _const +import sevenn._keys as KEY +import sevenn.util as util + +from .nn.convolution import IrrepsConvolution +from .nn.edge_embedding import ( + BesselBasis, + EdgeEmbedding, + PolynomialCutoff, 
+ SphericalEncoding, + XPLORCutoff, +) +from .nn.force_output import ForceStressOutputFromEdge +from .nn.interaction_blocks import NequIP_interaction_block +from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear +from .nn.node_embedding import OnehotEmbedding +from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale +from .nn.self_connection import ( + SelfConnectionIntro, + SelfConnectionLinearIntro, + SelfConnectionOutro, +) +from .nn.sequential import AtomGraphSequential + +# warning from PyTorch, about e3nn type annotations +warnings.filterwarnings( + 'ignore', + message=( + "The TorchScript type system doesn't " 'support instance-level annotations' + ), +) + + +def _insert_after(module_name_after, key_module_pair, layers): + idx = -1 + for i, (key, _) in enumerate(layers): + if key == module_name_after: + idx = i + break + if idx == -1: + return layers # do nothing if not found + layers.insert(idx + 1, key_module_pair) + return layers + + +def init_self_connection(config): + self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE] + num_conv = config[KEY.NUM_CONVOLUTION] + if isinstance(self_connection_type_list, str): + self_connection_type_list = [self_connection_type_list] * num_conv + + io_pair_list = [] + for sc_type in self_connection_type_list: + if sc_type == 'none': + io_pair = None + elif sc_type == 'nequip': + io_pair = SelfConnectionIntro, SelfConnectionOutro + elif sc_type == 'linear': + io_pair = SelfConnectionLinearIntro, SelfConnectionOutro + else: + raise ValueError(f'Unknown self_connection_type found: {sc_type}') + io_pair_list.append(io_pair) + return io_pair_list + + +def init_edge_embedding(config): + _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]} + rbf, env, sph = None, None, None + + rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS]) + rbf_dct.update(_cutoff_param) + rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME) + if rbf_name == 'bessel': + rbf = BesselBasis(**rbf_dct) + + envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION]) + envelop_dct.update(_cutoff_param) + envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME) + if envelop_name == 'poly_cut': + env = PolynomialCutoff(**envelop_dct) + elif envelop_name == 'XPLOR': + env = XPLORCutoff(**envelop_dct) + + lmax_edge = config[KEY.LMAX] + if config[KEY.LMAX_EDGE] > 0: + lmax_edge = config[KEY.LMAX_EDGE] + parity = -1 if config[KEY.IS_PARITY] else 1 + _normalize_sph = config[KEY._NORMALIZE_SPH] + sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph) + + return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph) + + +def init_feature_reduce(config, irreps_x): + # features per node to scalar per node + layers = OrderedDict() + if config[KEY.READOUT_AS_FCN] is False: + hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))]) + layers.update( + { + 'reduce_input_to_hidden': IrrepsLinear( + irreps_x, + hidden_irreps, + data_key_in=KEY.NODE_FEATURE, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + 'reduce_hidden_to_energy': IrrepsLinear( + hidden_irreps, + Irreps([(1, (0, 1))]), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.SCALED_ATOMIC_ENERGY, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + } + ) + else: + act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]] + hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS] + layers.update( + { + 'readout_FCN': FCN_e3nn( + dim_out=1, + hidden_neurons=hidden_neurons, + activation=act, + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.SCALED_ATOMIC_ENERGY, + irreps_in=irreps_x, + ) + } + ) + return 
layers + + +def init_shift_scale(config): + # for mm, ex, shift: modal_idx -> shifts + shift_scale = [] + train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE] + type_map = config[KEY.TYPE_MAP] + + # in case of modal, shift or scale has more dims [][] + # correct typing (I really want static python) + for s in (config[KEY.SHIFT], config[KEY.SCALE]): + if hasattr(s, 'tolist'): # numpy or torch + s = s.tolist() + if isinstance(s, dict): + s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()} + if isinstance(s, list) and len(s) == 1: + s = s[0] + shift_scale.append(s) + shift, scale = shift_scale + + rescale_module = None + if config.get(KEY.USE_MODALITY, False): + rescale_module = ModalWiseRescale.from_mappers( # type: ignore + shift, + scale, + config[KEY.USE_MODAL_WISE_SHIFT], + config[KEY.USE_MODAL_WISE_SCALE], + type_map=type_map, + modal_map=config[KEY.MODAL_MAP], + train_shift_scale=train_shift_scale, + ) + elif all([isinstance(s, float) for s in shift_scale]): + rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale) + elif any([isinstance(s, list) for s in shift_scale]): + rescale_module = SpeciesWiseRescale.from_mappers( # type: ignore + shift, scale, type_map=type_map, train_shift_scale=train_shift_scale + ) + else: + raise ValueError('shift, scale should be list of float or float') + + return rescale_module + + +def patch_modality(layers: OrderedDict, config): + """ + Postprocess 7net-model to multimodal model. + 1. prepend modality one-hot embedding layer + 2. patch modalities of IrrepsLinear layers + Modality aware shift scale is handled by init_shift_scale, not here + """ + cfg = config + if not cfg.get(KEY.USE_MODALITY, False): + return layers + + _layers = list(layers.items()) + _layers = _insert_after( + 'onehot_idx_to_onehot', + ( + 'one_hot_modality', + OnehotEmbedding( + num_classes=config[KEY.NUM_MODALITIES], + data_key_x=KEY.MODAL_TYPE, + data_key_out=KEY.MODAL_ATTR, + data_key_save=None, + data_key_additional=None, + ), + ), + _layers, + ) + layers = OrderedDict(_layers) + + num_modal = config[KEY.NUM_MODALITIES] + for k, module in layers.items(): + if not isinstance(module, IrrepsLinear): + continue + if ( + (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x')) + or ( + cfg[KEY.USE_MODAL_SELF_INTER_INTRO] + and k.endswith('self_interaction_1') + ) + or ( + cfg[KEY.USE_MODAL_SELF_INTER_OUTRO] + and k.endswith('self_interaction_2') + ) + or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden') + ): + module.set_num_modalities(num_modal) + return layers + + +def patch_cue(layers: OrderedDict, config): + import sevenn.nn.cue_helper as cue_helper + + cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {})) + + if not cue_cfg.pop('use', False): + return layers + + if not cue_helper.is_cue_available(): + warnings.warn( + ( + 'cuEquivariance is requested, but the package is not installed. ' + + 'Fallback to original code.' 
+ ) + ) + return layers + + if not cue_helper.is_cue_cuda_available_model(config): + return layers + + group = 'O3' if config[KEY.IS_PARITY] else 'SO3' + cueq_module_params = dict(layout='mul_ir') + cueq_module_params.update(cue_cfg) + updates = {} + for k, module in layers.items(): + if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)): + if k == 'reduce_hidden_to_energy': # TODO: has bug with 0 shape + continue + module_patched = cue_helper.patch_linear( + module, group, **cueq_module_params + ) + updates[k] = module_patched + elif isinstance(module, SelfConnectionIntro): + module_patched = cue_helper.patch_fully_connected( + module, group, **cueq_module_params + ) + updates[k] = module_patched + elif isinstance(module, IrrepsConvolution): + module_patched = cue_helper.patch_convolution( + module, group, **cueq_module_params + ) + updates[k] = module_patched + + layers.update(updates) + return layers + + +def patch_modules(layers: OrderedDict, config): + layers = patch_modality(layers, config) + layers = patch_cue(layers, config) + return layers + + +def _to_parallel_model(layers: OrderedDict, config): + num_classes = layers['onehot_idx_to_onehot'].num_classes + one_hot_irreps = Irreps(f'{num_classes}x0e') + irreps_node_zero = layers['onehot_to_feature_x'].irreps_out + + _layers = list(layers.items()) + layers_list = [] + + num_convolution_layer = config[KEY.NUM_CONVOLUTION] + + def slice_until_this(module_name, layers): + idx = -1 + for i, (key, _) in enumerate(layers): + if key == module_name: + idx = i + break + first_to = layers[: idx + 1] + remain = layers[idx + 1 :] + return first_to, remain + + _layers = _insert_after( + 'onehot_to_feature_x', + ( + 'one_hot_ghost', + OnehotEmbedding( + data_key_x=KEY.NODE_FEATURE_GHOST, + num_classes=num_classes, + data_key_save=None, + data_key_additional=None, + ), + ), + _layers, + ) + _layers = _insert_after( + 'one_hot_ghost', + ( + 'ghost_onehot_to_feature_x', + IrrepsLinear( + irreps_in=one_hot_irreps, + irreps_out=irreps_node_zero, + data_key_in=KEY.NODE_FEATURE_GHOST, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + ), + _layers, + ) + _layers = _insert_after( + '0_self_interaction_1', + ( + 'ghost_0_self_interaction_1', + IrrepsLinear( + irreps_node_zero, + irreps_node_zero, + data_key_in=KEY.NODE_FEATURE_GHOST, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + ), + _layers, + ) + # assign modules (before first communications) + # initialize edge related to retain position gradients + for i in range(1, num_convolution_layer): + sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers) + layers_list.append(OrderedDict(sliced)) + _layers.insert(0, ('edge_embedding', init_edge_embedding(config))) + + layers_list.append(OrderedDict(_layers)) + del layers_list[-1]['force_output'] # done in LAMMPS + return layers_list + + +@overload +def build_E3_equivariant_model( + config: dict, parallel: Literal[False] = False +) -> AtomGraphSequential: # noqa + ... + + +@overload +def build_E3_equivariant_model( + config: dict, parallel: Literal[True] +) -> List[AtomGraphSequential]: # noqa + ... 
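+
+# A minimal usage sketch (illustrative; not part of this module). Checkpoints
+# written by this package store 'config' and 'model_state_dict' entries (see
+# sevenn_checkpoint.py above), so a model can be rebuilt from one roughly as
+# follows; 'checkpoint.pth' is a placeholder path, and sevenn.util's
+# load_checkpoint is the higher-level entry point for the same task.
+#
+#   import torch
+#   cp = torch.load('checkpoint.pth', map_location='cpu')
+#   model = build_E3_equivariant_model(cp['config'], parallel=False)
+#   model.load_state_dict(cp['model_state_dict'], strict=False)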
+ + +def build_E3_equivariant_model( + config: dict, parallel: bool = False +) -> Union[AtomGraphSequential, List[AtomGraphSequential]]: + """ + output shapes (w/o batch) + + PRED_TOTAL_ENERGY: (), + ATOMIC_ENERGY: (natoms, 1), # intended + PRED_FORCE: (natoms, 3), + PRED_STRESS: (6,), + + for data w/o cell volume, pred_stress has garbage values + """ + layers = OrderedDict() + + cutoff = config[KEY.CUTOFF] + num_species = config[KEY.NUM_SPECIES] + feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY] + num_convolution_layer = config[KEY.NUM_CONVOLUTION] + interaction_type = config[KEY.INTERACTION_TYPE] + use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR] + + lmax_node = config[KEY.LMAX] # ignore second (lmax_edge) + # if config[KEY.LMAX_EDGE] > 0: # not yet used + # _ = config[KEY.LMAX_EDGE] + if config[KEY.LMAX_NODE] > 0: + lmax_node = config[KEY.LMAX_NODE] + + act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]] + self_connection_pair_list = init_self_connection(config) + + irreps_manual = None + if config[KEY.IRREPS_MANUAL] is not False: + irreps_manual = config[KEY.IRREPS_MANUAL] + try: + irreps_manual = [Irreps(irr) for irr in irreps_manual] + assert len(irreps_manual) == num_convolution_layer + 1 + except Exception: + raise RuntimeError('invalid irreps_manual input given') + + conv_denominator = config[KEY.CONV_DENOMINATOR] + if not isinstance(conv_denominator, list): + conv_denominator = [conv_denominator] * num_convolution_layer + train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR] + + edge_embedding = init_edge_embedding(config) + irreps_filter = edge_embedding.spherical.irreps_out + radial_basis_num = edge_embedding.basis_function.num_basis + layers.update({'edge_embedding': edge_embedding}) + + one_hot_irreps = Irreps(f'{num_species}x0e') + irreps_x = ( + Irreps(f'{feature_multiplicity}x0e') + if irreps_manual is None + else irreps_manual[0] + ) + + layers.update( + { + 'onehot_idx_to_onehot': OnehotEmbedding( + num_classes=num_species, + data_key_x=KEY.NODE_FEATURE, + data_key_out=KEY.NODE_FEATURE, + data_key_save=KEY.ATOM_TYPE, # atomic numbers + data_key_additional=KEY.NODE_ATTR, # one-hot embeddings + ), + 'onehot_to_feature_x': IrrepsLinear( + irreps_in=one_hot_irreps, + irreps_out=irreps_x, + data_key_in=KEY.NODE_FEATURE, + biases=use_bias_in_linear, + ), + } + ) + + weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS] + weight_nn_layers = [radial_basis_num] + weight_nn_hidden + + param_interaction_block = { + 'irreps_filter': irreps_filter, + 'weight_nn_layers': weight_nn_layers, + 'train_conv_denominator': train_conv_denominator, + 'act_radial': act_radial, + 'bias_in_linear': use_bias_in_linear, + 'num_species': num_species, + 'parallel': parallel, + } + + interaction_builder = None + + if interaction_type in ['nequip']: + act_scalar = {} + act_gate = {} + for k, v in config[KEY.ACTIVATION_SCARLAR].items(): + act_scalar[k] = _const.ACTIVATION_DICT[k][v] + for k, v in config[KEY.ACTIVATION_GATE].items(): + act_gate[k] = _const.ACTIVATION_DICT[k][v] + param_interaction_block.update( + { + 'act_scalar': act_scalar, + 'act_gate': act_gate, + } + ) + + if interaction_type == 'nequip': + interaction_builder = NequIP_interaction_block + else: + raise ValueError(f'Unknown interaction type: {interaction_type}') + + for t in range(num_convolution_layer): + param_interaction_block.update( + { + 'irreps_x': irreps_x, + 't': t, + 'conv_denominator': conv_denominator[t], + 'self_connection_pair': self_connection_pair_list[t], + } + ) + if 
interaction_type == 'nequip': + parity_mode = 'full' + fix_multiplicity = False + if t == num_convolution_layer - 1: + lmax_node = 0 + parity_mode = 'even' + # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out + irreps_out = ( + util.infer_irreps_out( + irreps_x, # type: ignore + irreps_filter, + lmax_node, # type: ignore + parity_mode, + fix_multiplicity=feature_multiplicity, + ) + if irreps_manual is None + else irreps_manual[t + 1] + ) + irreps_out_tp = util.infer_irreps_out( + irreps_x, # type: ignore + irreps_filter, + irreps_out.lmax, # type: ignore + parity_mode, + fix_multiplicity, + ) + else: + raise ValueError(f'Unknown interaction type: {interaction_type}') + param_interaction_block.update( + { + 'irreps_out_tp': irreps_out_tp, + 'irreps_out': irreps_out, + } + ) + layers.update(interaction_builder(**param_interaction_block)) + irreps_x = irreps_out + + layers.update(init_feature_reduce(config, irreps_x)) + + layers.update( + { + 'rescale_atomic_energy': init_shift_scale(config), + 'reduce_total_enegy': AtomReduce( + data_key_in=KEY.ATOMIC_ENERGY, + data_key_out=KEY.PRED_TOTAL_ENERGY, + ), + } + ) + + gradient_module = ForceStressOutputFromEdge() + grad_key = gradient_module.get_grad_key() + layers.update({'force_output': gradient_module}) + + common_args = { + 'cutoff': cutoff, + 'type_map': config[KEY.TYPE_MAP], + 'modal_map': config.get(KEY.MODAL_MAP, None), + 'eval_type_map': False if parallel else True, + 'eval_modal_map': False + if not config.get(KEY.USE_MODALITY, False) or parallel + else True, + 'data_key_grad': grad_key, + } + + if parallel: + layers_list = _to_parallel_model(layers, config) + return [ + AtomGraphSequential(patch_modules(layers, config), **common_args) + for layers in layers_list + ] + else: + return AtomGraphSequential(patch_modules(layers, config), **common_args) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3dd949272dba255d4cb94e44d5c594209a25eaf5..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/activation.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/activation.cpython-310.pyc deleted file mode 100644 index 3816e640355365a2bc3e9b7a80744e72183e3e35..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/activation.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc deleted file mode 100644 index 56d12b0ab1a2b560af5ce37e9dd76a50a1e3e737..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc deleted file mode 100644 index ebbfa3761ce71d34b3b6e186c42b690ae59a8b1f..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/edge_embedding.cpython-310.pyc 
b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/edge_embedding.cpython-310.pyc deleted file mode 100644 index eaa819c1e45ffd4518bb1f4ec92d934b9275f635..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/edge_embedding.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/equivariant_gate.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/equivariant_gate.cpython-310.pyc deleted file mode 100644 index 0f0811028f417140947cdfebfcdbf04cb2bfce24..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/equivariant_gate.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc deleted file mode 100644 index c4652dd5a76086849b92fefb3d68c50e0c24cd3e..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc deleted file mode 100644 index 5bc82ade64600ebf995b460a7ca00dad1dad9e8d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/linear.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/linear.cpython-310.pyc deleted file mode 100644 index 27a6d2b66650f1eb6e5efbd9ed98481804a38805..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/linear.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/node_embedding.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/node_embedding.cpython-310.pyc deleted file mode 100644 index 2bb5637981cade9a2f9539464a53a0ca09e4b23d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/node_embedding.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc deleted file mode 100644 index f8329937e1cb613c90e43560555f44c3fe4ee1a1..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/self_connection.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/self_connection.cpython-310.pyc deleted file mode 100644 index ab5488554fbfe1592434028c4d13a3be1e395635..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/self_connection.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc deleted file mode 100644 index bba755543c76b1bb1846a1f558fdc2850a48b64a..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc and /dev/null 
differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc deleted file mode 100644 index b1cf45e398ed9ed6fdd64dfe1593c33611c3b866..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/activation.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/activation.py index ca4d4d84819a40134c13496bde04823a652b832d..ac86df922c5b4b1dc388169aa4db152ffc59d2fc 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/activation.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/activation.py @@ -1,8 +1,8 @@ -import math - -import torch - - -@torch.jit.script -def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.softplus(x) - math.log(2.0) +import math + +import torch + + +@torch.jit.script +def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.softplus(x) - math.log(2.0) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py index e62d6ad04c8bcfac4db6b9486113d84020e972fa..d5f1cd7bb79c98ea88d633eb3db52a5712b2dba7 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py @@ -1,141 +1,141 @@ -from typing import List - -import torch -import torch.nn as nn -from e3nn.nn import FullyConnectedNet -from e3nn.o3 import Irreps, TensorProduct -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - -from .activation import ShiftedSoftPlus -from .util import broadcast - - -def message_gather( - node_features: torch.Tensor, - edge_dst: torch.Tensor, - message: torch.Tensor -): - index = broadcast(edge_dst, message, 0) - out_shape = [len(node_features)] + list(message.shape[1:]) - out = torch.zeros( - out_shape, - dtype=node_features.dtype, - device=node_features.device - ) - out.scatter_reduce_(0, index, message, reduce='sum') - return out - - -@compile_mode('script') -class IrrepsConvolution(nn.Module): - """ - convolution of (fig 2.b), comm. 
in LAMMPS - """ - - def __init__( - self, - irreps_x: Irreps, - irreps_filter: Irreps, - irreps_out: Irreps, - weight_layer_input_to_hidden: List[int], - weight_layer_act=ShiftedSoftPlus, - denominator: float = 1.0, - train_denominator: bool = False, - data_key_x: str = KEY.NODE_FEATURE, - data_key_filter: str = KEY.EDGE_ATTR, - data_key_weight_input: str = KEY.EDGE_EMBEDDING, - data_key_edge_idx: str = KEY.EDGE_IDX, - lazy_layer_instantiate: bool = True, - is_parallel: bool = False, - ): - super().__init__() - self.denominator = nn.Parameter( - torch.FloatTensor([denominator]), requires_grad=train_denominator - ) - self.key_x = data_key_x - self.key_filter = data_key_filter - self.key_weight_input = data_key_weight_input - self.key_edge_idx = data_key_edge_idx - self.is_parallel = is_parallel - - instructions = [] - irreps_mid = [] - weight_numel = 0 - for i, (mul_x, ir_x) in enumerate(irreps_x): - for j, (_, ir_filter) in enumerate(irreps_filter): - for ir_out in ir_x * ir_filter: - if ir_out in irreps_out: # here we drop l > lmax - k = len(irreps_mid) - weight_numel += mul_x * 1 # path shape - irreps_mid.append((mul_x, ir_out)) - instructions.append((i, j, k, 'uvu', True)) - - irreps_mid = Irreps(irreps_mid) - irreps_mid, p, _ = irreps_mid.sort() # type: ignore - instructions = [ - (i_in1, i_in2, p[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - # From v0.11.x, to compatible with cuEquivariance - self._instructions_before_sort = instructions - instructions = sorted(instructions, key=lambda x: x[2]) - - self.convolution_kwargs = dict( - irreps_in1=irreps_x, - irreps_in2=irreps_filter, - irreps_out=irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - self.weight_nn_kwargs = dict( - hs=weight_layer_input_to_hidden + [weight_numel], - act=weight_layer_act - ) - - self.convolution = None - self.weight_nn = None - self.layer_instantiated = False - self.convolution_cls = TensorProduct - self.weight_nn_cls = FullyConnectedNet - - if not lazy_layer_instantiate: - self.instantiate() - - self._comm_size = irreps_x.dim # used in parallel - - def instantiate(self): - if self.convolution is not None: - raise ValueError('Convolution layer already exists') - if self.weight_nn is not None: - raise ValueError('Weight_nn layer already exists') - - self.convolution = self.convolution_cls(**self.convolution_kwargs) - self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.convolution is not None, 'Convolution is not instantiated' - assert self.weight_nn is not None, 'Weight_nn is not instantiated' - weight = self.weight_nn(data[self.key_weight_input]) - x = data[self.key_x] - if self.is_parallel: - x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]]) - - # note that 1 -> src 0 -> dst - edge_src = data[self.key_edge_idx][1] - edge_dst = data[self.key_edge_idx][0] - - message = self.convolution(x[edge_src], data[self.key_filter], weight) - - x = message_gather(x, edge_dst, message) - x = x.div(self.denominator) - if self.is_parallel: - x = torch.tensor_split(x, data[KEY.NLOCAL])[0] - data[self.key_x] = x - return data +from typing import List + +import torch +import torch.nn as nn +from e3nn.nn import FullyConnectedNet +from e3nn.o3 import Irreps, TensorProduct +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + +from .activation import 
ShiftedSoftPlus +from .util import broadcast + + +def message_gather( + node_features: torch.Tensor, + edge_dst: torch.Tensor, + message: torch.Tensor +): + index = broadcast(edge_dst, message, 0) + out_shape = [len(node_features)] + list(message.shape[1:]) + out = torch.zeros( + out_shape, + dtype=node_features.dtype, + device=node_features.device + ) + out.scatter_reduce_(0, index, message, reduce='sum') + return out + + +@compile_mode('script') +class IrrepsConvolution(nn.Module): + """ + convolution of (fig 2.b), comm. in LAMMPS + """ + + def __init__( + self, + irreps_x: Irreps, + irreps_filter: Irreps, + irreps_out: Irreps, + weight_layer_input_to_hidden: List[int], + weight_layer_act=ShiftedSoftPlus, + denominator: float = 1.0, + train_denominator: bool = False, + data_key_x: str = KEY.NODE_FEATURE, + data_key_filter: str = KEY.EDGE_ATTR, + data_key_weight_input: str = KEY.EDGE_EMBEDDING, + data_key_edge_idx: str = KEY.EDGE_IDX, + lazy_layer_instantiate: bool = True, + is_parallel: bool = False, + ): + super().__init__() + self.denominator = nn.Parameter( + torch.FloatTensor([denominator]), requires_grad=train_denominator + ) + self.key_x = data_key_x + self.key_filter = data_key_filter + self.key_weight_input = data_key_weight_input + self.key_edge_idx = data_key_edge_idx + self.is_parallel = is_parallel + + instructions = [] + irreps_mid = [] + weight_numel = 0 + for i, (mul_x, ir_x) in enumerate(irreps_x): + for j, (_, ir_filter) in enumerate(irreps_filter): + for ir_out in ir_x * ir_filter: + if ir_out in irreps_out: # here we drop l > lmax + k = len(irreps_mid) + weight_numel += mul_x * 1 # path shape + irreps_mid.append((mul_x, ir_out)) + instructions.append((i, j, k, 'uvu', True)) + + irreps_mid = Irreps(irreps_mid) + irreps_mid, p, _ = irreps_mid.sort() # type: ignore + instructions = [ + (i_in1, i_in2, p[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + # From v0.11.x, to compatible with cuEquivariance + self._instructions_before_sort = instructions + instructions = sorted(instructions, key=lambda x: x[2]) + + self.convolution_kwargs = dict( + irreps_in1=irreps_x, + irreps_in2=irreps_filter, + irreps_out=irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + self.weight_nn_kwargs = dict( + hs=weight_layer_input_to_hidden + [weight_numel], + act=weight_layer_act + ) + + self.convolution = None + self.weight_nn = None + self.layer_instantiated = False + self.convolution_cls = TensorProduct + self.weight_nn_cls = FullyConnectedNet + + if not lazy_layer_instantiate: + self.instantiate() + + self._comm_size = irreps_x.dim # used in parallel + + def instantiate(self): + if self.convolution is not None: + raise ValueError('Convolution layer already exists') + if self.weight_nn is not None: + raise ValueError('Weight_nn layer already exists') + + self.convolution = self.convolution_cls(**self.convolution_kwargs) + self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.convolution is not None, 'Convolution is not instantiated' + assert self.weight_nn is not None, 'Weight_nn is not instantiated' + weight = self.weight_nn(data[self.key_weight_input]) + x = data[self.key_x] + if self.is_parallel: + x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]]) + + # note that 1 -> src 0 -> dst + edge_src = data[self.key_edge_idx][1] + edge_dst = data[self.key_edge_idx][0] + + message = 
self.convolution(x[edge_src], data[self.key_filter], weight) + + x = message_gather(x, edge_dst, message) + x = x.div(self.denominator) + if self.is_parallel: + x = torch.tensor_split(x, data[KEY.NLOCAL])[0] + data[self.key_x] = x + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py index 1d0d0d6e875ea5fae1f5d04eb8000bbadde7be89..c798f40e943bb0033e12db1b4d1b6b44e7c84626 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py @@ -1,189 +1,189 @@ -import itertools -import warnings -from typing import Iterator, Literal, Union - -import e3nn.o3 as o3 -import numpy as np - -from .convolution import IrrepsConvolution -from .linear import IrrepsLinear -from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro - -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - - _CUE_AVAILABLE = True - - # Obatained from MACE - class O3_e3nn(cue.O3): - def __mul__( # type: ignore - rep1: 'O3_e3nn', rep2: 'O3_e3nn' - ) -> Iterator['O3_e3nn']: - return [ # type: ignore - O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2) - ] - - @classmethod - def clebsch_gordan( # type: ignore - cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn' - ) -> np.ndarray: - rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) - - if rep1.p * rep2.p == rep3.p: - return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( - rep3.dim - ) - return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - - def __lt__( # type: ignore - rep1: 'O3_e3nn', rep2: 'O3_e3nn' - ) -> bool: - rep2 = rep1._from(rep2) # type: ignore - return (rep1.l, rep1.p) < (rep2.l, rep2.p) - - @classmethod - def iterator(cls) -> Iterator['O3_e3nn']: - for l in itertools.count(0): - yield O3_e3nn(l=l, p=1 * (-1) ** l) - yield O3_e3nn(l=l, p=-1 * (-1) ** l) - -except ImportError: - _CUE_AVAILABLE = False - - -def is_cue_available(): - return _CUE_AVAILABLE - - -def cue_needed(func): - def wrapper(*args, **kwargs): - if is_cue_available(): - return func(*args, **kwargs) - else: - raise ImportError('cue is not available') - - return wrapper - - -def _check_may_not_compatible(orig_kwargs, defaults): - for k, v in defaults.items(): - v_given = orig_kwargs.pop(k, v) - if v_given != v: - warnings.warn(f'{k}: {v} is ignored to use cuEquivariance') - - -def is_cue_cuda_available_model(config): - if config.get('use_bias_in_linear', False): - warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn') - return False - else: - return True - - -@cue_needed -def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']): - """Convert e3nn irreps to given group's cue irreps""" - if group == 'SO3': - assert all(irrep.ir.p == 1 for irrep in irreps) - return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore - elif group == 'O3': - return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore - else: - raise ValueError(f'Unknown group: {group}') - - -@cue_needed -def patch_linear( - module: Union[IrrepsLinear, SelfConnectionLinearIntro], - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore - module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore - - orig_kwargs = module.linear_kwargs - - may_not_compatible_default = dict( - f_in=None, - f_out=None, - instructions=None, - biases=False, - 
path_normalization='element', - _optimize_einsums=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible(orig_kwargs, may_not_compatible_default) - - module.linear_cls = cuet.Linear # type: ignore - orig_kwargs.update(**cue_kwargs) - return module - - -@cue_needed -def patch_convolution( - module: IrrepsConvolution, - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - # conv_kwargs will be patched in place - conv_kwargs = module.convolution_kwargs - conv_kwargs.update( - dict( - irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group), - irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group), - filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group), - ) - ) - - inst_orig = conv_kwargs.pop('instructions') - inst_sorted = sorted(inst_orig, key=lambda x: x[2]) - assert all([a == b for a, b in zip(inst_orig, inst_sorted)]) - - may_not_compatible_default = dict( - in1_var=None, - in2_var=None, - out_var=None, - irrep_normalization=False, - path_normalization='element', - compile_left_right=True, - compile_right=False, - _specialized_code=None, - _optimize_einsums=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible(conv_kwargs, may_not_compatible_default) - - module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore - conv_kwargs.update(**cue_kwargs) - return module - - -@cue_needed -def patch_fully_connected( - module: SelfConnectionIntro, - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore - module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore - module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore - - may_not_compatible_default = dict( - irrep_normalization=None, - path_normalization=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible( - module.fc_tensor_product_kwargs, may_not_compatible_default - ) - - module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore - module.fc_tensor_product_kwargs.update(**cue_kwargs) - return module +import itertools +import warnings +from typing import Iterator, Literal, Union + +import e3nn.o3 as o3 +import numpy as np + +from .convolution import IrrepsConvolution +from .linear import IrrepsLinear +from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro + +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + + _CUE_AVAILABLE = True + + # Obatained from MACE + class O3_e3nn(cue.O3): + def __mul__( # type: ignore + rep1: 'O3_e3nn', rep2: 'O3_e3nn' + ) -> Iterator['O3_e3nn']: + return [ # type: ignore + O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2) + ] + + @classmethod + def clebsch_gordan( # type: ignore + cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn' + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__( # type: ignore + rep1: 'O3_e3nn', rep2: 'O3_e3nn' + ) -> bool: + rep2 = rep1._from(rep2) # type: ignore + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator['O3_e3nn']: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + +except ImportError: + _CUE_AVAILABLE 
= False + + +def is_cue_available(): + return _CUE_AVAILABLE + + +def cue_needed(func): + def wrapper(*args, **kwargs): + if is_cue_available(): + return func(*args, **kwargs) + else: + raise ImportError('cue is not available') + + return wrapper + + +def _check_may_not_compatible(orig_kwargs, defaults): + for k, v in defaults.items(): + v_given = orig_kwargs.pop(k, v) + if v_given != v: + warnings.warn(f'{k}: {v} is ignored to use cuEquivariance') + + +def is_cue_cuda_available_model(config): + if config.get('use_bias_in_linear', False): + warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn') + return False + else: + return True + + +@cue_needed +def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']): + """Convert e3nn irreps to given group's cue irreps""" + if group == 'SO3': + assert all(irrep.ir.p == 1 for irrep in irreps) + return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore + elif group == 'O3': + return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore + else: + raise ValueError(f'Unknown group: {group}') + + +@cue_needed +def patch_linear( + module: Union[IrrepsLinear, SelfConnectionLinearIntro], + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore + module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore + + orig_kwargs = module.linear_kwargs + + may_not_compatible_default = dict( + f_in=None, + f_out=None, + instructions=None, + biases=False, + path_normalization='element', + _optimize_einsums=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible(orig_kwargs, may_not_compatible_default) + + module.linear_cls = cuet.Linear # type: ignore + orig_kwargs.update(**cue_kwargs) + return module + + +@cue_needed +def patch_convolution( + module: IrrepsConvolution, + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + # conv_kwargs will be patched in place + conv_kwargs = module.convolution_kwargs + conv_kwargs.update( + dict( + irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group), + irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group), + filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group), + ) + ) + + inst_orig = conv_kwargs.pop('instructions') + inst_sorted = sorted(inst_orig, key=lambda x: x[2]) + assert all([a == b for a, b in zip(inst_orig, inst_sorted)]) + + may_not_compatible_default = dict( + in1_var=None, + in2_var=None, + out_var=None, + irrep_normalization=False, + path_normalization='element', + compile_left_right=True, + compile_right=False, + _specialized_code=None, + _optimize_einsums=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible(conv_kwargs, may_not_compatible_default) + + module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore + conv_kwargs.update(**cue_kwargs) + return module + + +@cue_needed +def patch_fully_connected( + module: SelfConnectionIntro, + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore + module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore + module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore + + may_not_compatible_default = dict( + irrep_normalization=None, + path_normalization=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible( + 
module.fc_tensor_product_kwargs, may_not_compatible_default + ) + + module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore + module.fc_tensor_product_kwargs.update(**cue_kwargs) + return module diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py index 8c5333920b0f3c589fe427aaf6de85dd96db149f..e738ef5375fc2546ccc426aad01d91b206b03a6d 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py @@ -1,217 +1,217 @@ -import math - -import torch -import torch.nn as nn -from e3nn.o3 import Irreps, SphericalHarmonics -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class EdgePreprocess(nn.Module): - """ - preprocessing pos to edge vectors and edge lengths - currently used in sevenn/scripts/deploy for lammps serial model - """ - - def __init__(self, is_stress: bool): - super().__init__() - # controlled by 'AtomGraphSequential' - self.is_stress = is_stress - self._is_batch_data = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - cell = data[KEY.CELL].view(-1, 3, 3) - else: - cell = data[KEY.CELL].view(3, 3) - cell_shift = data[KEY.CELL_SHIFT] - pos = data[KEY.POS] - - batch = data[KEY.BATCH] # for deploy, must be defined first - if self.is_stress: - if self._is_batch_data: - num_batch = int(batch.max().cpu().item()) + 1 - strain = torch.zeros( - (num_batch, 3, 3), - dtype=pos.dtype, - device=pos.device, - ) - strain.requires_grad_(True) - data['_strain'] = strain - - sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) - pos = pos + torch.bmm( - pos.unsqueeze(-2), sym_strain[batch] - ).squeeze(-2) - cell = cell + torch.bmm(cell, sym_strain) - else: - strain = torch.zeros( - (3, 3), - dtype=pos.dtype, - device=pos.device, - ) - strain.requires_grad_(True) - data['_strain'] = strain - - sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) - pos = pos + torch.mm(pos, sym_strain) - cell = cell + torch.mm(cell, sym_strain) - - idx_src = data[KEY.EDGE_IDX][0] - idx_dst = data[KEY.EDGE_IDX][1] - - edge_vec = pos[idx_dst] - pos[idx_src] - - if self._is_batch_data: - edge_vec = edge_vec + torch.einsum( - 'ni,nij->nj', cell_shift, cell[batch[idx_src]] - ) - else: - edge_vec = edge_vec + torch.einsum( - 'ni,ij->nj', cell_shift, cell.squeeze(0) - ) - data[KEY.EDGE_VEC] = edge_vec - data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1) - return data - - -class BesselBasis(nn.Module): - """ - f : (*, 1) -> (*, bessel_basis_num) - """ - - def __init__( - self, - cutoff_length: float, - bessel_basis_num: int = 8, - trainable_coeff: bool = True, - ): - super().__init__() - self.num_basis = bessel_basis_num - self.prefactor = 2.0 / cutoff_length - self.coeffs = torch.FloatTensor([ - n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1) - ]) - if trainable_coeff: - self.coeffs = nn.Parameter(self.coeffs) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - ur = r.unsqueeze(-1) # to fit dimension - return self.prefactor * torch.sin(self.coeffs * ur) / ur - - -class PolynomialCutoff(nn.Module): - """ - f : (*, 1) -> (*, 1) - https://arxiv.org/pdf/2003.03123.pdf - """ - - def __init__( - self, - cutoff_length: float, - poly_cut_p_value: int = 6, - ): - super().__init__() - p = poly_cut_p_value - self.cutoff_length = cutoff_length - self.p = p - self.coeff_p0 = (p + 1.0) * (p + 
2.0) / 2.0 - self.coeff_p1 = p * (p + 2.0) - self.coeff_p2 = p * (p + 1.0) / 2.0 - - def forward(self, r: torch.Tensor) -> torch.Tensor: - r = r / self.cutoff_length - return ( - 1 - - self.coeff_p0 * torch.pow(r, self.p) - + self.coeff_p1 * torch.pow(r, self.p + 1.0) - - self.coeff_p2 * torch.pow(r, self.p + 2.0) - ) - - -class XPLORCutoff(nn.Module): - """ - https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html - """ - - def __init__( - self, - cutoff_length: float, - cutoff_on: float, - ): - super().__init__() - self.r_on = cutoff_on - self.r_cut = cutoff_length - assert self.r_on < self.r_cut - - def forward(self, r: torch.Tensor) -> torch.Tensor: - r_sq = r * r - r_on_sq = self.r_on * self.r_on - r_cut_sq = self.r_cut * self.r_cut - return torch.where( - r < self.r_on, - 1.0, - (r_cut_sq - r_sq) ** 2 - * (r_cut_sq + 2 * r_sq - 3 * r_on_sq) - / (r_cut_sq - r_on_sq) ** 3, - ) - - -@compile_mode('script') -class SphericalEncoding(nn.Module): - def __init__( - self, - lmax: int, - parity: int = -1, - normalization: str = 'component', - normalize: bool = True, - ): - super().__init__() - self.lmax = lmax - self.normalization = normalization - self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e') - self.irreps_out = Irreps.spherical_harmonics(lmax, parity) - self.sph = SphericalHarmonics( - self.irreps_out, - normalize=normalize, - normalization=normalization, - irreps_in=self.irreps_in, - ) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - return self.sph(r) - - -@compile_mode('script') -class EdgeEmbedding(nn.Module): - """ - embedding layer of |r| by - RadialBasis(|r|)*CutOff(|r|) - f : (N_edge) -> (N_edge, basis_num) - """ - - def __init__( - self, - basis_module: nn.Module, - cutoff_module: nn.Module, - spherical_module: nn.Module, - ): - super().__init__() - self.basis_function = basis_module - self.cutoff_function = cutoff_module - self.spherical = spherical_module - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - rvec = data[KEY.EDGE_VEC] - r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1) - data[KEY.EDGE_LENGTH] = r - - data[KEY.EDGE_EMBEDDING] = self.basis_function( - r - ) * self.cutoff_function(r).unsqueeze(-1) - data[KEY.EDGE_ATTR] = self.spherical(rvec) - - return data +import math + +import torch +import torch.nn as nn +from e3nn.o3 import Irreps, SphericalHarmonics +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class EdgePreprocess(nn.Module): + """ + preprocessing pos to edge vectors and edge lengths + currently used in sevenn/scripts/deploy for lammps serial model + """ + + def __init__(self, is_stress: bool): + super().__init__() + # controlled by 'AtomGraphSequential' + self.is_stress = is_stress + self._is_batch_data = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + cell = data[KEY.CELL].view(-1, 3, 3) + else: + cell = data[KEY.CELL].view(3, 3) + cell_shift = data[KEY.CELL_SHIFT] + pos = data[KEY.POS] + + batch = data[KEY.BATCH] # for deploy, must be defined first + if self.is_stress: + if self._is_batch_data: + num_batch = int(batch.max().cpu().item()) + 1 + strain = torch.zeros( + (num_batch, 3, 3), + dtype=pos.dtype, + device=pos.device, + ) + strain.requires_grad_(True) + data['_strain'] = strain + + sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) + pos = pos + torch.bmm( + pos.unsqueeze(-2), sym_strain[batch] + ).squeeze(-2) + cell = cell + torch.bmm(cell, 
sym_strain) + else: + strain = torch.zeros( + (3, 3), + dtype=pos.dtype, + device=pos.device, + ) + strain.requires_grad_(True) + data['_strain'] = strain + + sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) + pos = pos + torch.mm(pos, sym_strain) + cell = cell + torch.mm(cell, sym_strain) + + idx_src = data[KEY.EDGE_IDX][0] + idx_dst = data[KEY.EDGE_IDX][1] + + edge_vec = pos[idx_dst] - pos[idx_src] + + if self._is_batch_data: + edge_vec = edge_vec + torch.einsum( + 'ni,nij->nj', cell_shift, cell[batch[idx_src]] + ) + else: + edge_vec = edge_vec + torch.einsum( + 'ni,ij->nj', cell_shift, cell.squeeze(0) + ) + data[KEY.EDGE_VEC] = edge_vec + data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1) + return data + + +class BesselBasis(nn.Module): + """ + f : (*, 1) -> (*, bessel_basis_num) + """ + + def __init__( + self, + cutoff_length: float, + bessel_basis_num: int = 8, + trainable_coeff: bool = True, + ): + super().__init__() + self.num_basis = bessel_basis_num + self.prefactor = 2.0 / cutoff_length + self.coeffs = torch.FloatTensor([ + n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1) + ]) + if trainable_coeff: + self.coeffs = nn.Parameter(self.coeffs) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + ur = r.unsqueeze(-1) # to fit dimension + return self.prefactor * torch.sin(self.coeffs * ur) / ur + + +class PolynomialCutoff(nn.Module): + """ + f : (*, 1) -> (*, 1) + https://arxiv.org/pdf/2003.03123.pdf + """ + + def __init__( + self, + cutoff_length: float, + poly_cut_p_value: int = 6, + ): + super().__init__() + p = poly_cut_p_value + self.cutoff_length = cutoff_length + self.p = p + self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0 + self.coeff_p1 = p * (p + 2.0) + self.coeff_p2 = p * (p + 1.0) / 2.0 + + def forward(self, r: torch.Tensor) -> torch.Tensor: + r = r / self.cutoff_length + return ( + 1 + - self.coeff_p0 * torch.pow(r, self.p) + + self.coeff_p1 * torch.pow(r, self.p + 1.0) + - self.coeff_p2 * torch.pow(r, self.p + 2.0) + ) + + +class XPLORCutoff(nn.Module): + """ + https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html + """ + + def __init__( + self, + cutoff_length: float, + cutoff_on: float, + ): + super().__init__() + self.r_on = cutoff_on + self.r_cut = cutoff_length + assert self.r_on < self.r_cut + + def forward(self, r: torch.Tensor) -> torch.Tensor: + r_sq = r * r + r_on_sq = self.r_on * self.r_on + r_cut_sq = self.r_cut * self.r_cut + return torch.where( + r < self.r_on, + 1.0, + (r_cut_sq - r_sq) ** 2 + * (r_cut_sq + 2 * r_sq - 3 * r_on_sq) + / (r_cut_sq - r_on_sq) ** 3, + ) + + +@compile_mode('script') +class SphericalEncoding(nn.Module): + def __init__( + self, + lmax: int, + parity: int = -1, + normalization: str = 'component', + normalize: bool = True, + ): + super().__init__() + self.lmax = lmax + self.normalization = normalization + self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e') + self.irreps_out = Irreps.spherical_harmonics(lmax, parity) + self.sph = SphericalHarmonics( + self.irreps_out, + normalize=normalize, + normalization=normalization, + irreps_in=self.irreps_in, + ) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + return self.sph(r) + + +@compile_mode('script') +class EdgeEmbedding(nn.Module): + """ + embedding layer of |r| by + RadialBasis(|r|)*CutOff(|r|) + f : (N_edge) -> (N_edge, basis_num) + """ + + def __init__( + self, + basis_module: nn.Module, + cutoff_module: nn.Module, + spherical_module: nn.Module, + ): + super().__init__() + self.basis_function = basis_module + 
self.cutoff_function = cutoff_module + self.spherical = spherical_module + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + rvec = data[KEY.EDGE_VEC] + r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1) + data[KEY.EDGE_LENGTH] = r + + data[KEY.EDGE_EMBEDDING] = self.basis_function( + r + ) * self.cutoff_function(r).unsqueeze(-1) + data[KEY.EDGE_ATTR] = self.spherical(rvec) + + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py index 5c36fe20504ce5015505cf858232e542530f945c..842240557b2694270881275011b6f1551f02e0f4 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py @@ -1,61 +1,61 @@ -from typing import Callable, Dict - -import torch.nn as nn -from e3nn.nn import Gate -from e3nn.o3 import Irreps -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class EquivariantGate(nn.Module): - def __init__( - self, - irreps_x: Irreps, - act_scalar_dict: Dict[int, Callable], - act_gate_dict: Dict[int, Callable], - data_key_x: str = KEY.NODE_FEATURE, - ): - super().__init__() - self.key_x = data_key_x - - parity_mapper = {'e': 1, 'o': -1} - act_scalar_dict = { - parity_mapper[k]: v for k, v in act_scalar_dict.items() - } - act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()} - - irreps_gated_elem = [] - irreps_scalars_elem = [] - # non scalar irreps > gated / scalar irreps > scalars - for mul, irreps in irreps_x: - if irreps.l > 0: - irreps_gated_elem.append((mul, irreps)) - else: - irreps_scalars_elem.append((mul, irreps)) - irreps_scalars = Irreps(irreps_scalars_elem) - irreps_gated = Irreps(irreps_gated_elem) - - irreps_gates_parity = 1 if '0e' in irreps_scalars else -1 - irreps_gates = Irreps( - [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated] - ) - - act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars] - act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates] - - self.gate = Gate( - irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated - ) - - def get_gate_irreps_in(self): - """ - user must call this function to get proper irreps in for forward - """ - return self.gate.irreps_in - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_x] = self.gate(data[self.key_x]) - return data +from typing import Callable, Dict + +import torch.nn as nn +from e3nn.nn import Gate +from e3nn.o3 import Irreps +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class EquivariantGate(nn.Module): + def __init__( + self, + irreps_x: Irreps, + act_scalar_dict: Dict[int, Callable], + act_gate_dict: Dict[int, Callable], + data_key_x: str = KEY.NODE_FEATURE, + ): + super().__init__() + self.key_x = data_key_x + + parity_mapper = {'e': 1, 'o': -1} + act_scalar_dict = { + parity_mapper[k]: v for k, v in act_scalar_dict.items() + } + act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()} + + irreps_gated_elem = [] + irreps_scalars_elem = [] + # non scalar irreps > gated / scalar irreps > scalars + for mul, irreps in irreps_x: + if irreps.l > 0: + irreps_gated_elem.append((mul, irreps)) + else: + irreps_scalars_elem.append((mul, irreps)) + irreps_scalars = Irreps(irreps_scalars_elem) + irreps_gated = Irreps(irreps_gated_elem) + + 
irreps_gates_parity = 1 if '0e' in irreps_scalars else -1 + irreps_gates = Irreps( + [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated] + ) + + act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars] + act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates] + + self.gate = Gate( + irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated + ) + + def get_gate_irreps_in(self): + """ + user must call this function to get proper irreps in for forward + """ + return self.gate.irreps_in + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_x] = self.gate(data[self.key_x]) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py index d1360c76016aa04ae2e9aea216378d4c64cd32d8..f6b90e588d8382059bdcadd46fb2aefa1880def8 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py @@ -1,224 +1,224 @@ -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - -from .util import broadcast - - -@compile_mode('script') -class ForceOutput(nn.Module): - """ - works when pos.requires_grad_ is True - """ - - def __init__( - self, - data_key_pos: str = KEY.POS, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - ): - super().__init__() - self.key_pos = data_key_pos - self.key_energy = data_key_energy - self.key_force = data_key_force - - def get_grad_key(self): - return self.key_pos - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - pos_tensor = [data[self.key_pos]] - energy = [(data[self.key_energy]).sum()] - - # `materialize_grads` not supported in low version of pytorch - # Also can not be deployed when using it. - # But not using it makes problem in - # force/stress inference in sparse systems - # TODO: use it only in sevennet_calculator? - grad = torch.autograd.grad( - energy, - pos_tensor, - create_graph=self.training, - allow_unused=True, - # materialize_grads=True, - )[0] - - # For torchscript - if grad is not None: - data[self.key_force] = torch.neg(grad) - return data - - -@compile_mode('script') -class ForceStressOutput(nn.Module): - """ - Compute stress and force from positions. - Used in serial torchscipt models - """ - def __init__( - self, - data_key_pos: str = KEY.POS, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - data_key_stress: str = KEY.PRED_STRESS, - data_key_cell_volume: str = KEY.CELL_VOLUME, - ): - - super().__init__() - self.key_pos = data_key_pos - self.key_energy = data_key_energy - self.key_force = data_key_force - self.key_stress = data_key_stress - self.key_cell_volume = data_key_cell_volume - self._is_batch_data = True - - def get_grad_key(self): - return self.key_pos - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - pos_tensor = data[self.key_pos] - energy = [(data[self.key_energy]).sum()] - - # `materialize_grads` not supported in low version of pytorch - # Also can not be deployed when using it. - # But not using it makes problem in - # force/stress inference in sparse systems - # TODO: use it only in sevennet_calculator? 
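# A minimal, self-contained sketch of the force-from-energy pattern this
# module implements: forces are the negative gradient of the summed
# predicted energy with respect to positions. The toy quadratic energy
# and all tensor names below are illustrative assumptions, not SevenNet code.
import torch

pos = torch.randn(8, 3, requires_grad=True)   # 8 "atoms", arbitrary positions
energy = 0.5 * (pos ** 2).sum()               # stand-in for the predicted total energy

# allow_unused=True mirrors the call below: an atom that does not
# influence the energy yields a None gradient instead of an error.
grad = torch.autograd.grad([energy], [pos], allow_unused=True)[0]
forces = torch.neg(grad) if grad is not None else None   # F = -dE/dr (here: -pos)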
- grad = torch.autograd.grad( - energy, - [pos_tensor, data['_strain']], - create_graph=self.training, - allow_unused=True, - # materialize_grads=True, - ) - - # make grad is not Optional[Tensor] - fgrad = grad[0] - if fgrad is not None: - data[self.key_force] = torch.neg(fgrad) - - sgrad = grad[1] - volume = data[self.key_cell_volume] - vlim = 1e-3 # for cell volume = 0 for non PBC structures - if self._is_batch_data: - volume[volume < vlim] = vlim - elif volume < vlim: - volume = torch.tensor(vlim) - - if sgrad is not None: - if self._is_batch_data: - stress = sgrad / volume.view(-1, 1, 1) - stress = torch.neg(stress) - virial_stress = torch.vstack(( - stress[:, 0, 0], - stress[:, 1, 1], - stress[:, 2, 2], - stress[:, 0, 1], - stress[:, 1, 2], - stress[:, 0, 2], - )) - data[self.key_stress] = virial_stress.transpose(0, 1) - else: - stress = sgrad / volume - stress = torch.neg(stress) - virial_stress = torch.stack(( - stress[0, 0], - stress[1, 1], - stress[2, 2], - stress[0, 1], - stress[1, 2], - stress[0, 2], - )) - data[self.key_stress] = virial_stress - - return data - - -@compile_mode('script') -class ForceStressOutputFromEdge(nn.Module): - """ - Compute stress and force from edge. - Used in parallel torchscipt models, and training - """ - def __init__( - self, - data_key_edge: str = KEY.EDGE_VEC, - data_key_edge_idx: str = KEY.EDGE_IDX, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - data_key_stress: str = KEY.PRED_STRESS, - data_key_cell_volume: str = KEY.CELL_VOLUME, - ): - - super().__init__() - self.key_edge = data_key_edge - self.key_edge_idx = data_key_edge_idx - self.key_energy = data_key_energy - self.key_force = data_key_force - self.key_stress = data_key_stress - self.key_cell_volume = data_key_cell_volume - self._is_batch_data = True - - def get_grad_key(self): - return self.key_edge - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item? 
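# A minimal, runnable sketch of the scatter-based force accumulation used
# just below: the per-edge energy gradient f_ij = dE/dr_ij is summed onto
# each edge's two atoms, and the difference of the two sums gives per-atom
# forces (action-reaction pairs). Sizes and values are illustrative
# assumptions, not SevenNet code.
import torch

n_atoms, n_edges = 5, 12
fij = torch.randn(n_edges, 3)                        # toy dE/dr_ij per edge
edge_idx = torch.randint(0, n_atoms, (2, n_edges))   # toy [src; dst] index pairs

# Expand the 1-D edge indices to (n_edges, 3) — what the `broadcast`
# helper imported above is assumed to do — so scatter_reduce_ can sum
# 3-vector contributions per atom.
idx_src = edge_idx[0].unsqueeze(-1).expand_as(fij)
idx_dst = edge_idx[1].unsqueeze(-1).expand_as(fij)

pf = torch.zeros(n_atoms, 3).scatter_reduce_(0, idx_src, fij, reduce='sum')
nf = torch.zeros(n_atoms, 3).scatter_reduce_(0, idx_dst, fij, reduce='sum')
forces = pf - nf   # each edge pushes its two atoms in opposite directions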
- rij = data[self.key_edge] - energy = [(data[self.key_energy]).sum()] - edge_idx = data[self.key_edge_idx] - - grad = torch.autograd.grad( - energy, - [rij], - create_graph=self.training, - allow_unused=True - ) - - # make grad is not Optional[Tensor] - fij = grad[0] - - if fij is not None: - # compute force - pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) - nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) - _edge_src = broadcast(edge_idx[0], fij, 0) - _edge_dst = broadcast(edge_idx[1], fij, 0) - pf.scatter_reduce_(0, _edge_src, fij, reduce='sum') - nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum') - data[self.key_force] = pf - nf - - # compute virial - diag = rij * fij - s12 = rij[..., 0] * fij[..., 1] - s23 = rij[..., 1] * fij[..., 2] - s31 = rij[..., 2] * fij[..., 0] - # cat last dimension - _virial = torch.cat([ - diag, - s12.unsqueeze(-1), - s23.unsqueeze(-1), - s31.unsqueeze(-1) - ], dim=-1) - - _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) - _edge_dst6 = broadcast(edge_idx[1], _virial, 0) - _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') - - if self._is_batch_data: - batch = data[KEY.BATCH] # for deploy, must be defined first - nbatch = int(batch.max().cpu().item()) + 1 - sout = torch.zeros( - (nbatch, 6), dtype=_virial.dtype, device=_virial.device - ) - _batch = broadcast(batch, _s, 0) - sout.scatter_reduce_(0, _batch, _s, reduce='sum') - else: - sout = torch.sum(_s, dim=0) - - data[self.key_stress] =\ - torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) - - return data +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + +from .util import broadcast + + +@compile_mode('script') +class ForceOutput(nn.Module): + """ + works when pos.requires_grad_ is True + """ + + def __init__( + self, + data_key_pos: str = KEY.POS, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + ): + super().__init__() + self.key_pos = data_key_pos + self.key_energy = data_key_energy + self.key_force = data_key_force + + def get_grad_key(self): + return self.key_pos + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + pos_tensor = [data[self.key_pos]] + energy = [(data[self.key_energy]).sum()] + + # `materialize_grads` not supported in low version of pytorch + # Also can not be deployed when using it. + # But not using it makes problem in + # force/stress inference in sparse systems + # TODO: use it only in sevennet_calculator? + grad = torch.autograd.grad( + energy, + pos_tensor, + create_graph=self.training, + allow_unused=True, + # materialize_grads=True, + )[0] + + # For torchscript + if grad is not None: + data[self.key_force] = torch.neg(grad) + return data + + +@compile_mode('script') +class ForceStressOutput(nn.Module): + """ + Compute stress and force from positions. 
+ Used in serial torchscipt models + """ + def __init__( + self, + data_key_pos: str = KEY.POS, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + data_key_stress: str = KEY.PRED_STRESS, + data_key_cell_volume: str = KEY.CELL_VOLUME, + ): + + super().__init__() + self.key_pos = data_key_pos + self.key_energy = data_key_energy + self.key_force = data_key_force + self.key_stress = data_key_stress + self.key_cell_volume = data_key_cell_volume + self._is_batch_data = True + + def get_grad_key(self): + return self.key_pos + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + pos_tensor = data[self.key_pos] + energy = [(data[self.key_energy]).sum()] + + # `materialize_grads` not supported in low version of pytorch + # Also can not be deployed when using it. + # But not using it makes problem in + # force/stress inference in sparse systems + # TODO: use it only in sevennet_calculator? + grad = torch.autograd.grad( + energy, + [pos_tensor, data['_strain']], + create_graph=self.training, + allow_unused=True, + # materialize_grads=True, + ) + + # make grad is not Optional[Tensor] + fgrad = grad[0] + if fgrad is not None: + data[self.key_force] = torch.neg(fgrad) + + sgrad = grad[1] + volume = data[self.key_cell_volume] + vlim = 1e-3 # for cell volume = 0 for non PBC structures + if self._is_batch_data: + volume[volume < vlim] = vlim + elif volume < vlim: + volume = torch.tensor(vlim) + + if sgrad is not None: + if self._is_batch_data: + stress = sgrad / volume.view(-1, 1, 1) + stress = torch.neg(stress) + virial_stress = torch.vstack(( + stress[:, 0, 0], + stress[:, 1, 1], + stress[:, 2, 2], + stress[:, 0, 1], + stress[:, 1, 2], + stress[:, 0, 2], + )) + data[self.key_stress] = virial_stress.transpose(0, 1) + else: + stress = sgrad / volume + stress = torch.neg(stress) + virial_stress = torch.stack(( + stress[0, 0], + stress[1, 1], + stress[2, 2], + stress[0, 1], + stress[1, 2], + stress[0, 2], + )) + data[self.key_stress] = virial_stress + + return data + + +@compile_mode('script') +class ForceStressOutputFromEdge(nn.Module): + """ + Compute stress and force from edge. + Used in parallel torchscipt models, and training + """ + def __init__( + self, + data_key_edge: str = KEY.EDGE_VEC, + data_key_edge_idx: str = KEY.EDGE_IDX, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + data_key_stress: str = KEY.PRED_STRESS, + data_key_cell_volume: str = KEY.CELL_VOLUME, + ): + + super().__init__() + self.key_edge = data_key_edge + self.key_edge_idx = data_key_edge_idx + self.key_energy = data_key_energy + self.key_force = data_key_force + self.key_stress = data_key_stress + self.key_cell_volume = data_key_cell_volume + self._is_batch_data = True + + def get_grad_key(self): + return self.key_edge + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item? 
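# A minimal sketch of the per-edge virial assembly performed below: six
# components are built from the edge vectors r_ij and the edge gradients
# f_ij = dE/dr_ij, reduced over edges, negated, and divided by the cell
# volume to give stress. Values and the single-structure reduction are
# illustrative assumptions, not SevenNet code.
import torch

n_edges = 12
rij = torch.randn(n_edges, 3)   # toy edge vectors
fij = torch.randn(n_edges, 3)   # toy per-edge energy gradients

diag = rij * fij                         # xx, yy, zz components
s12 = rij[..., 0] * fij[..., 1]          # xy
s23 = rij[..., 1] * fij[..., 2]          # yz
s31 = rij[..., 2] * fij[..., 0]          # zx
virial = torch.cat(
    [diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)], dim=-1
)                                        # shape: (n_edges, 6)

volume = torch.tensor(100.0)             # toy cell volume
stress = torch.neg(virial.sum(dim=0)) / volume   # 6-component stress vector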
+ rij = data[self.key_edge] + energy = [(data[self.key_energy]).sum()] + edge_idx = data[self.key_edge_idx] + + grad = torch.autograd.grad( + energy, + [rij], + create_graph=self.training, + allow_unused=True + ) + + # make grad is not Optional[Tensor] + fij = grad[0] + + if fij is not None: + # compute force + pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) + nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) + _edge_src = broadcast(edge_idx[0], fij, 0) + _edge_dst = broadcast(edge_idx[1], fij, 0) + pf.scatter_reduce_(0, _edge_src, fij, reduce='sum') + nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum') + data[self.key_force] = pf - nf + + # compute virial + diag = rij * fij + s12 = rij[..., 0] * fij[..., 1] + s23 = rij[..., 1] * fij[..., 2] + s31 = rij[..., 2] * fij[..., 0] + # cat last dimension + _virial = torch.cat([ + diag, + s12.unsqueeze(-1), + s23.unsqueeze(-1), + s31.unsqueeze(-1) + ], dim=-1) + + _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) + _edge_dst6 = broadcast(edge_idx[1], _virial, 0) + _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + + if self._is_batch_data: + batch = data[KEY.BATCH] # for deploy, must be defined first + nbatch = int(batch.max().cpu().item()) + 1 + sout = torch.zeros( + (nbatch, 6), dtype=_virial.dtype, device=_virial.device + ) + _batch = broadcast(batch, _s, 0) + sout.scatter_reduce_(0, _batch, _s, reduce='sum') + else: + sout = torch.sum(_s, dim=0) + + data[self.key_stress] =\ + torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) + + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py index 24caa8293013e9217098d09630b8634cd9a59d96..3f93768f913c447f062d6de3a31384730fd715df 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py @@ -1,76 +1,76 @@ -from typing import Callable, List, Tuple - -from e3nn.o3 import Irreps - -import sevenn._keys as KEY - -from .convolution import IrrepsConvolution -from .equivariant_gate import EquivariantGate -from .linear import IrrepsLinear - - -def NequIP_interaction_block( - irreps_x: Irreps, - irreps_filter: Irreps, - irreps_out_tp: Irreps, - irreps_out: Irreps, - weight_nn_layers: List[int], - conv_denominator: float, - train_conv_denominator: bool, - self_connection_pair: Tuple[Callable, Callable], - act_scalar: Callable, - act_gate: Callable, - act_radial: Callable, - bias_in_linear: bool, - num_species: int, - t: int, # interaction layer index - data_key_x: str = KEY.NODE_FEATURE, - data_key_weight_input: str = KEY.EDGE_EMBEDDING, - parallel: bool = False, - **conv_kwargs, -): - block = {} - irreps_node_attr = Irreps(f'{num_species}x0e') - sc_intro, sc_outro = self_connection_pair - - gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate) - irreps_for_gate_in = gate_layer.get_gate_irreps_in() - - block[f'{t}_self_connection_intro'] = sc_intro( - irreps_x, - irreps_operand=irreps_node_attr, - irreps_out=irreps_for_gate_in, - ) - - block[f'{t}_self_interaction_1'] = IrrepsLinear( - irreps_x, irreps_x, - data_key_in=data_key_x, - biases=bias_in_linear, - ) - - # convolution part, l>lmax is dropped as defined in irreps_out - block[f'{t}_convolution'] = IrrepsConvolution( - irreps_x=irreps_x, - irreps_filter=irreps_filter, - irreps_out=irreps_out_tp, - data_key_weight_input=data_key_weight_input, - weight_layer_input_to_hidden=weight_nn_layers, - 
weight_layer_act=act_radial, - denominator=conv_denominator, - train_denominator=train_conv_denominator, - is_parallel=parallel, - **conv_kwargs, - ) - - # irreps of x increase to gate_irreps_in - block[f'{t}_self_interaction_2'] = IrrepsLinear( - irreps_out_tp, - irreps_for_gate_in, - data_key_in=data_key_x, - biases=bias_in_linear, - ) - - block[f'{t}_self_connection_outro'] = sc_outro() - block[f'{t}_equivariant_gate'] = gate_layer - - return block +from typing import Callable, List, Tuple + +from e3nn.o3 import Irreps + +import sevenn._keys as KEY + +from .convolution import IrrepsConvolution +from .equivariant_gate import EquivariantGate +from .linear import IrrepsLinear + + +def NequIP_interaction_block( + irreps_x: Irreps, + irreps_filter: Irreps, + irreps_out_tp: Irreps, + irreps_out: Irreps, + weight_nn_layers: List[int], + conv_denominator: float, + train_conv_denominator: bool, + self_connection_pair: Tuple[Callable, Callable], + act_scalar: Callable, + act_gate: Callable, + act_radial: Callable, + bias_in_linear: bool, + num_species: int, + t: int, # interaction layer index + data_key_x: str = KEY.NODE_FEATURE, + data_key_weight_input: str = KEY.EDGE_EMBEDDING, + parallel: bool = False, + **conv_kwargs, +): + block = {} + irreps_node_attr = Irreps(f'{num_species}x0e') + sc_intro, sc_outro = self_connection_pair + + gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate) + irreps_for_gate_in = gate_layer.get_gate_irreps_in() + + block[f'{t}_self_connection_intro'] = sc_intro( + irreps_x, + irreps_operand=irreps_node_attr, + irreps_out=irreps_for_gate_in, + ) + + block[f'{t}_self_interaction_1'] = IrrepsLinear( + irreps_x, irreps_x, + data_key_in=data_key_x, + biases=bias_in_linear, + ) + + # convolution part, l>lmax is dropped as defined in irreps_out + block[f'{t}_convolution'] = IrrepsConvolution( + irreps_x=irreps_x, + irreps_filter=irreps_filter, + irreps_out=irreps_out_tp, + data_key_weight_input=data_key_weight_input, + weight_layer_input_to_hidden=weight_nn_layers, + weight_layer_act=act_radial, + denominator=conv_denominator, + train_denominator=train_conv_denominator, + is_parallel=parallel, + **conv_kwargs, + ) + + # irreps of x increase to gate_irreps_in + block[f'{t}_self_interaction_2'] = IrrepsLinear( + irreps_out_tp, + irreps_for_gate_in, + data_key_in=data_key_x, + biases=bias_in_linear, + ) + + block[f'{t}_self_connection_outro'] = sc_outro() + block[f'{t}_equivariant_gate'] = gate_layer + + return block diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py index 43faa297c0355f6d43a9dd97dcb52e00ebc4529e..b5b87d21e699051b5cdf92cfe1d1b044705d532b 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py @@ -1,180 +1,180 @@ -from typing import Callable, List, Optional - -import torch -import torch.nn as nn -from e3nn.nn import FullyConnectedNet -from e3nn.o3 import Irreps, Linear -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class IrrepsLinear(nn.Module): - """ - wrapper class of e3nn Linear to operate on AtomGraphData - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_out: Irreps, - data_key_in: str, - data_key_out: Optional[str] = None, - data_key_modal_attr: str = KEY.MODAL_ATTR, - num_modalities: int = 0, - lazy_layer_instantiate: bool = True, - **linear_kwargs, - ): - super().__init__() - self.key_input = data_key_in - if 
data_key_out is None: - self.key_output = data_key_in - else: - self.key_output = data_key_out - self.key_modal_attr = data_key_modal_attr - - self._irreps_in_wo_modal = irreps_in - self.irreps_in = irreps_in - self.irreps_out = irreps_out - self.linear_kwargs = linear_kwargs - - self.linear = None - self.layer_instantiated = False - self.num_modalities = num_modalities - self._is_batch_data = True - - # use getter setter - self.linear_cls = Linear - - if num_modalities > 1: # in case of multi-modal - self.set_num_modalities(num_modalities) - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.linear is not None: - raise ValueError('Linear layer already exists') - self.linear = self.linear_cls( - self.irreps_in, self.irreps_out, **self.linear_kwargs - ) - self.layer_instantiated = True - - def set_num_modalities(self, num_modalities): - if self.layer_instantiated: - raise ValueError('Layer already instantiated, can not change modalities') - irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e') - self.num_modalities = num_modalities - self.irreps_in = irreps_in - - def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - batch = data[KEY.BATCH] - batch_modality_onehot = data[self.key_modal_attr].reshape( - -1, self.num_modalities - ) - batch_modality_onehot = batch_modality_onehot.type( - data[self.key_input].dtype - ) - data[self.key_input] = torch.cat( - [data[self.key_input], batch_modality_onehot[batch]], dim=1 - ) - else: - modality_onehot = data[self.key_modal_attr].expand( - len(data[self.key_input]), -1 - ) - modality_onehot = modality_onehot.type(data[self.key_input].dtype) - data[self.key_input] = torch.cat( - [data[self.key_input], modality_onehot], dim=1 - ) - return data - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.linear is not None, 'Layer is not instantiated' - if self.num_modalities > 1: - data = self._patch_modal_to_data(data) - - data[self.key_output] = self.linear(data[self.key_input]) - return data - - -@compile_mode('script') -class AtomReduce(nn.Module): - """ - atomic energy -> total energy - constant is multiplied to data - """ - - def __init__( - self, - data_key_in: str, - data_key_out: str, - reduce: str = 'sum', - constant: float = 1.0, - ): - super().__init__() - - self.key_input = data_key_in - self.key_output = data_key_out - self.constant = constant - self.reduce = reduce - - # controlled by the upper most wrapper 'AtomGraphSequential' - self._is_batch_data = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - src = data[self.key_input].squeeze(1) - size = int(data[KEY.BATCH].max()) + 1 - output = torch.zeros( - (size), - dtype=src.dtype, - device=src.device, - ) - output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum') - data[self.key_output] = output * self.constant - else: - data[self.key_output] = torch.sum(data[self.key_input]) * self.constant - - return data - - -@compile_mode('script') -class FCN_e3nn(nn.Module): - """ - wrapper class of e3nn FullyConnectedNet - """ - - def __init__( - self, - irreps_in: Irreps, # confirm it is scalar & input size - dim_out: int, - hidden_neurons: List[int], - activation: Callable, - data_key_in: str, - data_key_out: Optional[str] = None, - **e3nn_kwargs, - ): - super().__init__() - self.key_input = data_key_in - self.irreps_in = irreps_in - if data_key_out is None: - self.key_output = data_key_in - else: - self.key_output 
= data_key_out - - for _, irrep in irreps_in: - assert irrep.is_scalar() - inp_dim = irreps_in.dim - - self.fcn = FullyConnectedNet( - [inp_dim] + hidden_neurons + [dim_out], - activation, - **e3nn_kwargs, - ) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_output] = self.fcn(data[self.key_input]) - return data +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from e3nn.nn import FullyConnectedNet +from e3nn.o3 import Irreps, Linear +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class IrrepsLinear(nn.Module): + """ + wrapper class of e3nn Linear to operate on AtomGraphData + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_out: Irreps, + data_key_in: str, + data_key_out: Optional[str] = None, + data_key_modal_attr: str = KEY.MODAL_ATTR, + num_modalities: int = 0, + lazy_layer_instantiate: bool = True, + **linear_kwargs, + ): + super().__init__() + self.key_input = data_key_in + if data_key_out is None: + self.key_output = data_key_in + else: + self.key_output = data_key_out + self.key_modal_attr = data_key_modal_attr + + self._irreps_in_wo_modal = irreps_in + self.irreps_in = irreps_in + self.irreps_out = irreps_out + self.linear_kwargs = linear_kwargs + + self.linear = None + self.layer_instantiated = False + self.num_modalities = num_modalities + self._is_batch_data = True + + # use getter setter + self.linear_cls = Linear + + if num_modalities > 1: # in case of multi-modal + self.set_num_modalities(num_modalities) + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.linear is not None: + raise ValueError('Linear layer already exists') + self.linear = self.linear_cls( + self.irreps_in, self.irreps_out, **self.linear_kwargs + ) + self.layer_instantiated = True + + def set_num_modalities(self, num_modalities): + if self.layer_instantiated: + raise ValueError('Layer already instantiated, can not change modalities') + irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e') + self.num_modalities = num_modalities + self.irreps_in = irreps_in + + def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + batch = data[KEY.BATCH] + batch_modality_onehot = data[self.key_modal_attr].reshape( + -1, self.num_modalities + ) + batch_modality_onehot = batch_modality_onehot.type( + data[self.key_input].dtype + ) + data[self.key_input] = torch.cat( + [data[self.key_input], batch_modality_onehot[batch]], dim=1 + ) + else: + modality_onehot = data[self.key_modal_attr].expand( + len(data[self.key_input]), -1 + ) + modality_onehot = modality_onehot.type(data[self.key_input].dtype) + data[self.key_input] = torch.cat( + [data[self.key_input], modality_onehot], dim=1 + ) + return data + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.linear is not None, 'Layer is not instantiated' + if self.num_modalities > 1: + data = self._patch_modal_to_data(data) + + data[self.key_output] = self.linear(data[self.key_input]) + return data + + +@compile_mode('script') +class AtomReduce(nn.Module): + """ + atomic energy -> total energy + constant is multiplied to data + """ + + def __init__( + self, + data_key_in: str, + data_key_out: str, + reduce: str = 'sum', + constant: float = 1.0, + ): + super().__init__() + + self.key_input = data_key_in + self.key_output = data_key_out + self.constant = constant + self.reduce = 
reduce + + # controlled by the upper most wrapper 'AtomGraphSequential' + self._is_batch_data = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + src = data[self.key_input].squeeze(1) + size = int(data[KEY.BATCH].max()) + 1 + output = torch.zeros( + (size), + dtype=src.dtype, + device=src.device, + ) + output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum') + data[self.key_output] = output * self.constant + else: + data[self.key_output] = torch.sum(data[self.key_input]) * self.constant + + return data + + +@compile_mode('script') +class FCN_e3nn(nn.Module): + """ + wrapper class of e3nn FullyConnectedNet + """ + + def __init__( + self, + irreps_in: Irreps, # confirm it is scalar & input size + dim_out: int, + hidden_neurons: List[int], + activation: Callable, + data_key_in: str, + data_key_out: Optional[str] = None, + **e3nn_kwargs, + ): + super().__init__() + self.key_input = data_key_in + self.irreps_in = irreps_in + if data_key_out is None: + self.key_output = data_key_in + else: + self.key_output = data_key_out + + for _, irrep in irreps_in: + assert irrep.is_scalar() + inp_dim = irreps_in.dim + + self.fcn = FullyConnectedNet( + [inp_dim] + hidden_neurons + [dim_out], + activation, + **e3nn_kwargs, + ) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_output] = self.fcn(data[self.key_input]) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py index 747d0970d380e0ac013acf98f077d5ac0e7492e5..a5f272ae8a8f2edbd7814ad52d2b381ffbe7a9ce 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py @@ -1,91 +1,91 @@ -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -import torch.nn.functional -from ase.symbols import symbols2numbers -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -# TODO: put this to model_build and do not preprocess data by onehot -@compile_mode('script') -class OnehotEmbedding(nn.Module): - """ - x : tensor of shape (N, 1) - x_after : tensor of shape (N, num_classes) - It overwrite data_key_x - and saves input to data_key_save and output to data_key_additional - I know this is strange but it is for compatibility with previous version - and to specie wise shift scale work - ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2) - """ - - def __init__( - self, - num_classes: int, - data_key_x: str = KEY.NODE_FEATURE, - data_key_out: Optional[str] = None, - data_key_save: Optional[str] = None, - data_key_additional: Optional[str] = None, # additional output - ): - super().__init__() - self.num_classes = num_classes - self.key_x = data_key_x - if data_key_out is None: - self.key_output = data_key_x - else: - self.key_output = data_key_out - self.key_save = data_key_save - self.key_additional_output = data_key_additional - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - inp = data[self.key_x] - embd = torch.nn.functional.one_hot(inp, self.num_classes) - embd = embd.float() - data[self.key_output] = embd - if self.key_additional_output is not None: - data[self.key_additional_output] = embd # for self-connection - if self.key_save is not None: - data[self.key_save] = inp # for elemwise shift scale - return data - - -def get_type_mapper_from_specie(specie_list: List[str]): - """ - from ['Hf', 'O'] - 
return {72: 0, 8: 1} - """ - specie_list = sorted(specie_list) - type_map = {} - unique_counter = 0 - for specie in specie_list: - atomic_num = symbols2numbers(specie)[0] - if atomic_num in type_map: - continue - type_map[atomic_num] = unique_counter - unique_counter += 1 - return type_map - - -# deprecated -def one_hot_atom_embedding( - atomic_numbers: List[int], type_map: Dict[int, int] -): - """ - atomic numbers from ase.get_atomic_numbers - type_map from get_type_mapper_from_specie() - """ - num_classes = len(type_map) - try: - type_numbers = torch.LongTensor( - [type_map[num] for num in atomic_numbers] - ) - except KeyError as e: - raise ValueError(f'Atomic number {e.args[0]} is not expected') - embd = torch.nn.functional.one_hot(type_numbers, num_classes) - embd = embd.to(torch.get_default_dtype()) - - return embd +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional +from ase.symbols import symbols2numbers +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +# TODO: put this to model_build and do not preprocess data by onehot +@compile_mode('script') +class OnehotEmbedding(nn.Module): + """ + x : tensor of shape (N, 1) + x_after : tensor of shape (N, num_classes) + It overwrite data_key_x + and saves input to data_key_save and output to data_key_additional + I know this is strange but it is for compatibility with previous version + and to specie wise shift scale work + ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2) + """ + + def __init__( + self, + num_classes: int, + data_key_x: str = KEY.NODE_FEATURE, + data_key_out: Optional[str] = None, + data_key_save: Optional[str] = None, + data_key_additional: Optional[str] = None, # additional output + ): + super().__init__() + self.num_classes = num_classes + self.key_x = data_key_x + if data_key_out is None: + self.key_output = data_key_x + else: + self.key_output = data_key_out + self.key_save = data_key_save + self.key_additional_output = data_key_additional + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + inp = data[self.key_x] + embd = torch.nn.functional.one_hot(inp, self.num_classes) + embd = embd.float() + data[self.key_output] = embd + if self.key_additional_output is not None: + data[self.key_additional_output] = embd # for self-connection + if self.key_save is not None: + data[self.key_save] = inp # for elemwise shift scale + return data + + +def get_type_mapper_from_specie(specie_list: List[str]): + """ + from ['Hf', 'O'] + return {72: 0, 8: 1} + """ + specie_list = sorted(specie_list) + type_map = {} + unique_counter = 0 + for specie in specie_list: + atomic_num = symbols2numbers(specie)[0] + if atomic_num in type_map: + continue + type_map[atomic_num] = unique_counter + unique_counter += 1 + return type_map + + +# deprecated +def one_hot_atom_embedding( + atomic_numbers: List[int], type_map: Dict[int, int] +): + """ + atomic numbers from ase.get_atomic_numbers + type_map from get_type_mapper_from_specie() + """ + num_classes = len(type_map) + try: + type_numbers = torch.LongTensor( + [type_map[num] for num in atomic_numbers] + ) + except KeyError as e: + raise ValueError(f'Atomic number {e.args[0]} is not expected') + embd = torch.nn.functional.one_hot(type_numbers, num_classes) + embd = embd.to(torch.get_default_dtype()) + + return embd diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py index 
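# ---------------------------------------------------------------------------
# Sketch of the node-embedding behaviour above (tensors invented): with
# num_classes=2, type indices [0, 1, 1, 0] become rows of a one-hot matrix,
# matching the OnehotEmbedding docstring example. get_type_mapper_from_specie
# sorts the symbols first, so ['Hf', 'O'] yields {72: 0, 8: 1}.
import torch
import torch.nn.functional as F

type_indices = torch.tensor([0, 1, 1, 0])
onehot = F.one_hot(type_indices, num_classes=2).float()
print(onehot)  # [[1., 0.], [0., 1.], [0., 1.], [1., 0.]]
# ---------------------------------------------------------------------------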
f593770134c4baa35854caf013270b8f083278ac..73da83563002d86b8b12401dd7c0032a2d7fb264 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py @@ -1,387 +1,387 @@ -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType - - -def _as_univ( - ss: List[float], type_map: Dict[int, int], default: float -) -> List[float]: - assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long' - return [ - ss[type_map[z]] if z in type_map else default - for z in range(NUM_UNIV_ELEMENT) - ] - - -@compile_mode('script') -class Rescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - """ - - def __init__( - self, - shift: float, - scale: float, - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - train_shift_scale: bool = False, - **kwargs, - ): - assert isinstance(shift, float) and isinstance(scale, float) - super().__init__() - self.shift = nn.Parameter( - torch.FloatTensor([shift]), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor([scale]), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - - def get_shift(self) -> float: - return self.shift.detach().cpu().tolist()[0] - - def get_scale(self) -> float: - return self.scale.detach().cpu().tolist()[0] - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_output] = data[self.key_input] * self.scale + self.shift - - return data - - -@compile_mode('script') -class SpeciesWiseRescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - Use as it is if given list, expand to list if one of them is float - If two lists are given and length is not the same, raise error - """ - - def __init__( - self, - shift: Union[List[float], float], - scale: Union[List[float], float], - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - data_key_indices: str = KEY.ATOM_TYPE, - train_shift_scale: bool = False, - ): - super().__init__() - assert isinstance(shift, float) or isinstance(shift, list) - assert isinstance(scale, float) or isinstance(scale, list) - - if ( - isinstance(shift, list) - and isinstance(scale, list) - and len(shift) != len(scale) - ): - raise ValueError('List length should be same') - - if isinstance(shift, list): - num_species = len(shift) - elif isinstance(scale, list): - num_species = len(scale) - else: - raise ValueError('Both shift and scale is not a list') - - shift = [shift] * num_species if isinstance(shift, float) else shift - scale = [scale] * num_species if isinstance(scale, float) else scale - - self.shift = nn.Parameter( - torch.FloatTensor(shift), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor(scale), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - self.key_indices = data_key_indices - - def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: - """ - Return shift in list of float. If type_map is given, return type_map reversed - shift, which index equals atomic_number. 
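# ---------------------------------------------------------------------------
# Numeric sketch of Rescale above (values invented): the module applies
# y = x * scale + shift with scalar parameters; forces and stress pick up the
# rescaling automatically because they are differentiated from the energy.
import torch

scale, shift = torch.tensor([2.0]), torch.tensor([-1.0])
scaled_atomic_energy = torch.tensor([[0.5], [1.5]])
print(scaled_atomic_energy * scale + shift)  # tensor([[0.], [2.]])
# ---------------------------------------------------------------------------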
0.0 is assigned for atomis not found - """ - shift = self.shift.detach().cpu().tolist() - if type_map: - shift = _as_univ(shift, type_map, 0.0) - return shift - - def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: - """ - Return scale in list of float. If type_map is given, return type_map reversed - scale, which index equals atomic_number. 1.0 is assigned for atomis not found - """ - scale = self.scale.detach().cpu().tolist() - if type_map: - scale = _as_univ(scale, type_map, 1.0) - return scale - - @staticmethod - def from_mappers( - shift: Union[float, List[float]], - scale: Union[float, List[float]], - type_map: Dict[int, int], - **kwargs, - ): - """ - Fit dimensions or mapping raw shift scale values to that is valid under - the given type_map: (atomic_numbers -> type_indices) - """ - shift_scale = [] - n_atom_types = len(type_map) - for s in (shift, scale): - if isinstance(s, list) and len(s) > n_atom_types: - if len(s) != NUM_UNIV_ELEMENT: - raise ValueError('given shift or scale is strange') - s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])] - # s = [s[z] for z in sorted(type_map, key=type_map.get)] - elif isinstance(s, float): - s = [s] * n_atom_types - elif isinstance(s, list) and len(s) == 1: - s = s * n_atom_types - shift_scale.append(s) - assert all([len(s) == n_atom_types for s in shift_scale]) - shift, scale = shift_scale - return SpeciesWiseRescale(shift, scale, **kwargs) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - indices = data[self.key_indices] - data[self.key_output] = data[self.key_input] * self.scale[indices].view( - -1, 1 - ) + self.shift[indices].view(-1, 1) - - return data - - -@compile_mode('script') -class ModalWiseRescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - Given shift or scale is either modal-wise and atom-wise or - not modal-wise but atom-wise. It is always interpreted as atom-wise. 
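# ---------------------------------------------------------------------------
# Sketch of the species-wise forward pass above (numbers invented): every atom
# gathers its own shift/scale through its type index before the affine map.
import torch

shift = torch.tensor([0.0, 1.0])        # per-species shift, 2 species
scale = torch.tensor([2.0, 3.0])        # per-species scale
atom_types = torch.tensor([0, 1, 1])    # type index per atom
x = torch.tensor([[1.0], [1.0], [2.0]])

y = x * scale[atom_types].view(-1, 1) + shift[atom_types].view(-1, 1)
print(y)  # tensor([[2.], [4.], [7.]])
# ---------------------------------------------------------------------------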
- """ - - def __init__( - self, - shift: List[List[float]], - scale: List[List[float]], - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - data_key_modal_indices: str = KEY.MODAL_TYPE, - data_key_atom_indices: str = KEY.ATOM_TYPE, - use_modal_wise_shift: bool = False, - use_modal_wise_scale: bool = False, - train_shift_scale: bool = False, - ): - super().__init__() - self.shift = nn.Parameter( - torch.FloatTensor(shift), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor(scale), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - self.key_atom_indices = data_key_atom_indices - self.key_modal_indices = data_key_modal_indices - self.use_modal_wise_shift = use_modal_wise_shift - self.use_modal_wise_scale = use_modal_wise_scale - self._is_batch_data = True - - def get_shift( - self, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - ) -> Union[List[float], Dict[str, List[float]]]: - """ - Nothing is given: return as it is - type_map is given but not modal wise shift: return univ shift - both type_map and modal_map is given and modal wise shift: return fully - resolved modalwise univ shift - """ - shift = self.shift.detach().cpu().tolist() - if type_map and not self.use_modal_wise_shift: - shift = _as_univ(shift, type_map, 0.0) - elif self.use_modal_wise_shift and modal_map and type_map: - shift = [_as_univ(s, type_map, 0.0) for s in shift] - shift = {modal: shift[idx] for modal, idx in modal_map.items()} - - return shift - - def get_scale( - self, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - ) -> Union[List[float], Dict[str, List[float]]]: - """ - Nothing is given: return as it is - type_map is given but not modal wise scale: return univ scale - both type_map and modal_map is given and modal wise scale: return fully - resolved modalwise univ scale - """ - scale = self.scale.detach().cpu().tolist() - if type_map and not self.use_modal_wise_scale: - scale = _as_univ(scale, type_map, 0.0) - elif self.use_modal_wise_scale and modal_map and type_map: - scale = [_as_univ(s, type_map, 0.0) for s in scale] - scale = {modal: scale[idx] for modal, idx in modal_map.items()} - return scale - - @staticmethod - def from_mappers( - shift: Union[float, List[float], Dict[str, Any]], - scale: Union[float, List[float], Dict[str, Any]], - use_modal_wise_shift: bool, - use_modal_wise_scale: bool, - type_map: Dict[int, int], - modal_map: Dict[str, int], - **kwargs, - ): - """ - Fit dimensions or mapping raw shift scale values to that is valid under - the given type_map: (atomic_numbers -> type_indices) - If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT - , assume it is element-wise list - otherwise, it is modal-wise list - """ - - def solve_mapper(arr, map): - # value is attr index and never overlap, key is either 'z' or modal str - return [arr[z] for z in sorted(map, key=lambda x: map[x])] - - shift_scale = [] - n_atom_types = len(type_map) - n_modals = len(modal_map) - - for s, use_mw in ( - (shift, use_modal_wise_shift), - (scale, use_modal_wise_scale), - ): - # solve elemewise, or broadcast - if isinstance(s, float): - # given, modal-wise: no, elem-wise: no => broadcast - shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,) - res = torch.full(shape, s).tolist() # TODO: w/o torch - elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT: - # given, modal-wise: no, 
elem-wise: yes(univ) => solve elem map - s = solve_mapper(s, type_map) - res = [s] * n_modals if use_mw else s - elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise - isinstance(s, list) - and isinstance(s[0], float) - and len(s) == n_modals - and use_mw - ): - res = [[v] * n_atom_types for v in s] - elif ( # given, modal-wise: no, elem-wise: yes => as it is - isinstance(s, list) - and isinstance(s[0], float) - and len(s) == n_atom_types - and not use_mw - ): - res = s - elif ( # given, modal-wise: yes, elem-wise: yes => as it is - isinstance(s, list) - and isinstance(s[0], list) - and len(s) == n_modals - and len(s[0]) == n_atom_types - and use_mw - ): - res = s - elif isinstance(s, dict) and use_mw: - # solve modal dict, modal-wise: yes - s = solve_mapper(s, modal_map) - res = [] - for v in s: - if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT: - # elem-wise: yes(univ) => solve elem map - v = solve_mapper(v, type_map) - elif isinstance(v, float): - # elem-wise: no => broadcast to elemwise - v = [v] * n_atom_types - else: - raise ValueError(f'Invalid shift or scale {s}') - res.append(v) - else: - raise ValueError(f'Invalid shift or scale {s}') - - if use_mw: - assert ( - isinstance(res, list) - and isinstance(res[0], list) - and len(res) == n_modals - ) - assert all([len(r) == n_atom_types for r in res]) # type: ignore - else: - assert ( - isinstance(res, list) - and isinstance(res[0], float) - and len(res) == n_atom_types - ) - shift_scale.append(res) - shift, scale = shift_scale - - return ModalWiseRescale( - shift, - scale, - use_modal_wise_shift=use_modal_wise_shift, - use_modal_wise_scale=use_modal_wise_scale, - **kwargs, - ) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - batch = data[KEY.BATCH] - modal_indices = data[self.key_modal_indices][batch] - else: - modal_indices = data[self.key_modal_indices] - atom_indices = data[self.key_atom_indices] - shift = ( - self.shift[modal_indices, atom_indices] - if self.use_modal_wise_shift - else self.shift[atom_indices] - ) - scale = ( - self.scale[modal_indices, atom_indices] - if self.use_modal_wise_scale - else self.scale[atom_indices] - ) - data[self.key_output] = data[self.key_input] * scale.view( - -1, 1 - ) + shift.view(-1, 1) - - return data - - -def get_resolved_shift_scale( - module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale], - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, -): - """ - Return resolved shift and scale from scale modules. For element wise case, - convert to list of floats where idx is atomic number. 
For modal wise case, return - dictionary of shift scale where key is modal name given in modal_map - - Return: - Tuple of solved shift and scale - """ - - if isinstance(module, Rescale): - return (module.get_shift(), module.get_scale()) - elif isinstance(module, SpeciesWiseRescale): - return (module.get_shift(type_map), module.get_scale(type_map)) - elif isinstance(module, ModalWiseRescale): - return ( - module.get_shift(type_map, modal_map), - module.get_scale(type_map, modal_map), - ) - raise ValueError('Not scale module') +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType + + +def _as_univ( + ss: List[float], type_map: Dict[int, int], default: float +) -> List[float]: + assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long' + return [ + ss[type_map[z]] if z in type_map else default + for z in range(NUM_UNIV_ELEMENT) + ] + + +@compile_mode('script') +class Rescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + """ + + def __init__( + self, + shift: float, + scale: float, + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + train_shift_scale: bool = False, + **kwargs, + ): + assert isinstance(shift, float) and isinstance(scale, float) + super().__init__() + self.shift = nn.Parameter( + torch.FloatTensor([shift]), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor([scale]), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + + def get_shift(self) -> float: + return self.shift.detach().cpu().tolist()[0] + + def get_scale(self) -> float: + return self.scale.detach().cpu().tolist()[0] + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_output] = data[self.key_input] * self.scale + self.shift + + return data + + +@compile_mode('script') +class SpeciesWiseRescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + Use as it is if given list, expand to list if one of them is float + If two lists are given and length is not the same, raise error + """ + + def __init__( + self, + shift: Union[List[float], float], + scale: Union[List[float], float], + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + data_key_indices: str = KEY.ATOM_TYPE, + train_shift_scale: bool = False, + ): + super().__init__() + assert isinstance(shift, float) or isinstance(shift, list) + assert isinstance(scale, float) or isinstance(scale, list) + + if ( + isinstance(shift, list) + and isinstance(scale, list) + and len(shift) != len(scale) + ): + raise ValueError('List length should be same') + + if isinstance(shift, list): + num_species = len(shift) + elif isinstance(scale, list): + num_species = len(scale) + else: + raise ValueError('Both shift and scale is not a list') + + shift = [shift] * num_species if isinstance(shift, float) else shift + scale = [scale] * num_species if isinstance(scale, float) else scale + + self.shift = nn.Parameter( + torch.FloatTensor(shift), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor(scale), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + self.key_indices = data_key_indices + + def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: + """ + 
Return shift in list of float. If type_map is given, return type_map reversed + shift, which index equals atomic_number. 0.0 is assigned for atomis not found + """ + shift = self.shift.detach().cpu().tolist() + if type_map: + shift = _as_univ(shift, type_map, 0.0) + return shift + + def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: + """ + Return scale in list of float. If type_map is given, return type_map reversed + scale, which index equals atomic_number. 1.0 is assigned for atomis not found + """ + scale = self.scale.detach().cpu().tolist() + if type_map: + scale = _as_univ(scale, type_map, 1.0) + return scale + + @staticmethod + def from_mappers( + shift: Union[float, List[float]], + scale: Union[float, List[float]], + type_map: Dict[int, int], + **kwargs, + ): + """ + Fit dimensions or mapping raw shift scale values to that is valid under + the given type_map: (atomic_numbers -> type_indices) + """ + shift_scale = [] + n_atom_types = len(type_map) + for s in (shift, scale): + if isinstance(s, list) and len(s) > n_atom_types: + if len(s) != NUM_UNIV_ELEMENT: + raise ValueError('given shift or scale is strange') + s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])] + # s = [s[z] for z in sorted(type_map, key=type_map.get)] + elif isinstance(s, float): + s = [s] * n_atom_types + elif isinstance(s, list) and len(s) == 1: + s = s * n_atom_types + shift_scale.append(s) + assert all([len(s) == n_atom_types for s in shift_scale]) + shift, scale = shift_scale + return SpeciesWiseRescale(shift, scale, **kwargs) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + indices = data[self.key_indices] + data[self.key_output] = data[self.key_input] * self.scale[indices].view( + -1, 1 + ) + self.shift[indices].view(-1, 1) + + return data + + +@compile_mode('script') +class ModalWiseRescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + Given shift or scale is either modal-wise and atom-wise or + not modal-wise but atom-wise. It is always interpreted as atom-wise. 
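# ---------------------------------------------------------------------------
# Sketch of the `_as_univ` expansion used by get_shift/get_scale above: a
# per-type list is spread onto a table indexed by atomic number, padding
# elements absent from type_map with a default (0.0 for shift, 1.0 for scale).
# NUM_UNIV below stands in for sevenn._const.NUM_UNIV_ELEMENT, whose exact
# value is not shown in this diff.
NUM_UNIV = 119  # assumed value, for illustration only

type_map = {72: 0, 8: 1}       # Hf -> slot 0, O -> slot 1
per_type_shift = [0.5, -0.3]
univ = [per_type_shift[type_map[z]] if z in type_map else 0.0
        for z in range(NUM_UNIV)]
assert univ[72] == 0.5 and univ[8] == -0.3 and univ[1] == 0.0
# ---------------------------------------------------------------------------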
+ """ + + def __init__( + self, + shift: List[List[float]], + scale: List[List[float]], + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + data_key_modal_indices: str = KEY.MODAL_TYPE, + data_key_atom_indices: str = KEY.ATOM_TYPE, + use_modal_wise_shift: bool = False, + use_modal_wise_scale: bool = False, + train_shift_scale: bool = False, + ): + super().__init__() + self.shift = nn.Parameter( + torch.FloatTensor(shift), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor(scale), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + self.key_atom_indices = data_key_atom_indices + self.key_modal_indices = data_key_modal_indices + self.use_modal_wise_shift = use_modal_wise_shift + self.use_modal_wise_scale = use_modal_wise_scale + self._is_batch_data = True + + def get_shift( + self, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + ) -> Union[List[float], Dict[str, List[float]]]: + """ + Nothing is given: return as it is + type_map is given but not modal wise shift: return univ shift + both type_map and modal_map is given and modal wise shift: return fully + resolved modalwise univ shift + """ + shift = self.shift.detach().cpu().tolist() + if type_map and not self.use_modal_wise_shift: + shift = _as_univ(shift, type_map, 0.0) + elif self.use_modal_wise_shift and modal_map and type_map: + shift = [_as_univ(s, type_map, 0.0) for s in shift] + shift = {modal: shift[idx] for modal, idx in modal_map.items()} + + return shift + + def get_scale( + self, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + ) -> Union[List[float], Dict[str, List[float]]]: + """ + Nothing is given: return as it is + type_map is given but not modal wise scale: return univ scale + both type_map and modal_map is given and modal wise scale: return fully + resolved modalwise univ scale + """ + scale = self.scale.detach().cpu().tolist() + if type_map and not self.use_modal_wise_scale: + scale = _as_univ(scale, type_map, 0.0) + elif self.use_modal_wise_scale and modal_map and type_map: + scale = [_as_univ(s, type_map, 0.0) for s in scale] + scale = {modal: scale[idx] for modal, idx in modal_map.items()} + return scale + + @staticmethod + def from_mappers( + shift: Union[float, List[float], Dict[str, Any]], + scale: Union[float, List[float], Dict[str, Any]], + use_modal_wise_shift: bool, + use_modal_wise_scale: bool, + type_map: Dict[int, int], + modal_map: Dict[str, int], + **kwargs, + ): + """ + Fit dimensions or mapping raw shift scale values to that is valid under + the given type_map: (atomic_numbers -> type_indices) + If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT + , assume it is element-wise list + otherwise, it is modal-wise list + """ + + def solve_mapper(arr, map): + # value is attr index and never overlap, key is either 'z' or modal str + return [arr[z] for z in sorted(map, key=lambda x: map[x])] + + shift_scale = [] + n_atom_types = len(type_map) + n_modals = len(modal_map) + + for s, use_mw in ( + (shift, use_modal_wise_shift), + (scale, use_modal_wise_scale), + ): + # solve elemewise, or broadcast + if isinstance(s, float): + # given, modal-wise: no, elem-wise: no => broadcast + shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,) + res = torch.full(shape, s).tolist() # TODO: w/o torch + elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT: + # given, modal-wise: no, 
elem-wise: yes(univ) => solve elem map + s = solve_mapper(s, type_map) + res = [s] * n_modals if use_mw else s + elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise + isinstance(s, list) + and isinstance(s[0], float) + and len(s) == n_modals + and use_mw + ): + res = [[v] * n_atom_types for v in s] + elif ( # given, modal-wise: no, elem-wise: yes => as it is + isinstance(s, list) + and isinstance(s[0], float) + and len(s) == n_atom_types + and not use_mw + ): + res = s + elif ( # given, modal-wise: yes, elem-wise: yes => as it is + isinstance(s, list) + and isinstance(s[0], list) + and len(s) == n_modals + and len(s[0]) == n_atom_types + and use_mw + ): + res = s + elif isinstance(s, dict) and use_mw: + # solve modal dict, modal-wise: yes + s = solve_mapper(s, modal_map) + res = [] + for v in s: + if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT: + # elem-wise: yes(univ) => solve elem map + v = solve_mapper(v, type_map) + elif isinstance(v, float): + # elem-wise: no => broadcast to elemwise + v = [v] * n_atom_types + else: + raise ValueError(f'Invalid shift or scale {s}') + res.append(v) + else: + raise ValueError(f'Invalid shift or scale {s}') + + if use_mw: + assert ( + isinstance(res, list) + and isinstance(res[0], list) + and len(res) == n_modals + ) + assert all([len(r) == n_atom_types for r in res]) # type: ignore + else: + assert ( + isinstance(res, list) + and isinstance(res[0], float) + and len(res) == n_atom_types + ) + shift_scale.append(res) + shift, scale = shift_scale + + return ModalWiseRescale( + shift, + scale, + use_modal_wise_shift=use_modal_wise_shift, + use_modal_wise_scale=use_modal_wise_scale, + **kwargs, + ) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + batch = data[KEY.BATCH] + modal_indices = data[self.key_modal_indices][batch] + else: + modal_indices = data[self.key_modal_indices] + atom_indices = data[self.key_atom_indices] + shift = ( + self.shift[modal_indices, atom_indices] + if self.use_modal_wise_shift + else self.shift[atom_indices] + ) + scale = ( + self.scale[modal_indices, atom_indices] + if self.use_modal_wise_scale + else self.scale[atom_indices] + ) + data[self.key_output] = data[self.key_input] * scale.view( + -1, 1 + ) + shift.view(-1, 1) + + return data + + +def get_resolved_shift_scale( + module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale], + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, +): + """ + Return resolved shift and scale from scale modules. For element wise case, + convert to list of floats where idx is atomic number. 
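# ---------------------------------------------------------------------------
# Condensed restatement of the from_mappers broadcasting above (sizes
# invented): a scalar is broadcast to every slot, a modal-wise list of floats
# becomes one constant row per modality, and fully resolved inputs pass
# through; the univ-length and dict branches additionally remap indices.
# Note: ModalWiseRescale.get_scale above pads missing elements with 0.0 while
# SpeciesWiseRescale.get_scale pads with 1.0; this looks like an upstream
# inconsistency and is kept as-is here.
n_modals, n_atom_types = 2, 3

def broadcast_entry(s, modal_wise):
    if isinstance(s, float):  # scalar -> fill everything (torch.full upstream)
        rows = [[s] * n_atom_types for _ in range(n_modals)]
        return rows if modal_wise else [s] * n_atom_types
    if (modal_wise and isinstance(s, list) and s
            and isinstance(s[0], float) and len(s) == n_modals):
        return [[v] * n_atom_types for v in s]  # constant row per modality
    return s  # already resolved; univ/dict cases elided in this sketch

print(broadcast_entry(0.1, True))         # [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]
print(broadcast_entry([1.0, 2.0], True))  # [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
# ---------------------------------------------------------------------------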
For modal wise case, return + dictionary of shift scale where key is modal name given in modal_map + + Return: + Tuple of solved shift and scale + """ + + if isinstance(module, Rescale): + return (module.get_shift(), module.get_scale()) + elif isinstance(module, SpeciesWiseRescale): + return (module.get_shift(type_map), module.get_scale(type_map)) + elif isinstance(module, ModalWiseRescale): + return ( + module.get_shift(type_map, modal_map), + module.get_scale(type_map, modal_map), + ) + raise ValueError('Not scale module') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py index 40b55a2e6259bdce1587e9b83038de1d8af3c3c1..ce731b51494438d3d82206be129958f5ea912a96 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py @@ -1,128 +1,128 @@ -import torch.nn as nn -from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class SelfConnectionIntro(nn.Module): - """ - do TensorProduct of x and some data(here attribute of x) - and save it (to concatenate updated x at SelfConnectionOutro) - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_operand: Irreps, - irreps_out: Irreps, - data_key_x: str = KEY.NODE_FEATURE, - data_key_operand: str = KEY.NODE_ATTR, - lazy_layer_instantiate: bool = True, - **kwargs, # for compatibility - ): - super().__init__() - - self.fc_tensor_product = FullyConnectedTensorProduct( - irreps_in, irreps_operand, irreps_out - ) - self.irreps_in1 = irreps_in - self.irreps_in2 = irreps_operand - self.irreps_out = irreps_out - - self.key_x = data_key_x - self.key_operand = data_key_operand - - self.fc_tensor_product = None - self.layer_instantiated = False - self.fc_tensor_product_cls = FullyConnectedTensorProduct - self.fc_tensor_product_kwargs = kwargs - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.fc_tensor_product is not None: - raise ValueError('fc_tensor_product layer already exists') - self.fc_tensor_product = self.fc_tensor_product_cls( - self.irreps_in1, - self.irreps_in2, - self.irreps_out, - shared_weights=True, - internal_weights=None, # same as True - **self.fc_tensor_product_kwargs, - ) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.fc_tensor_product is not None, 'Layer is not instantiated' - data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product( - data[self.key_x], data[self.key_operand] - ) - return data - - -@compile_mode('script') -class SelfConnectionLinearIntro(nn.Module): - """ - Linear style self connection update - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_out: Irreps, - data_key_x: str = KEY.NODE_FEATURE, - lazy_layer_instantiate: bool = True, - **kwargs, - ): - super().__init__() - self.irreps_in = irreps_in - self.irreps_out = irreps_out - self.key_x = data_key_x - - self.linear = None - self.layer_instantiated = False - self.linear_cls = Linear - - # TODO: better to have SelfConnectionIntro super class - kwargs.pop('irreps_operand') - self.linear_kwargs = kwargs - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.linear is not None: - raise ValueError('Linear layer already exists') - self.linear = self.linear_cls( - self.irreps_in, self.irreps_out, 
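# ---------------------------------------------------------------------------
# Usage note for get_resolved_shift_scale above (module construction elided;
# the type_map value is invented): it is a plain isinstance dispatch, e.g.
#   shift, scale = get_resolved_shift_scale(rescale)                # two floats
#   shift, scale = get_resolved_shift_scale(spw, type_map={72: 0})  # univ lists
# and it raises ValueError for anything that is not one of the three modules.
# ---------------------------------------------------------------------------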
**self.linear_kwargs - ) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.linear is not None, 'Layer is not instantiated' - data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x]) - return data - - -@compile_mode('script') -class SelfConnectionOutro(nn.Module): - """ - do TensorProduct of x and some data(here attribute of x) - and save it (to concatenate updated x at SelfConnectionOutro) - """ - - def __init__( - self, - data_key_x: str = KEY.NODE_FEATURE, - ): - super().__init__() - self.key_x = data_key_x - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP] - del data[KEY.SELF_CONNECTION_TEMP] - return data +import torch.nn as nn +from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class SelfConnectionIntro(nn.Module): + """ + do TensorProduct of x and some data(here attribute of x) + and save it (to concatenate updated x at SelfConnectionOutro) + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_operand: Irreps, + irreps_out: Irreps, + data_key_x: str = KEY.NODE_FEATURE, + data_key_operand: str = KEY.NODE_ATTR, + lazy_layer_instantiate: bool = True, + **kwargs, # for compatibility + ): + super().__init__() + + self.fc_tensor_product = FullyConnectedTensorProduct( + irreps_in, irreps_operand, irreps_out + ) + self.irreps_in1 = irreps_in + self.irreps_in2 = irreps_operand + self.irreps_out = irreps_out + + self.key_x = data_key_x + self.key_operand = data_key_operand + + self.fc_tensor_product = None + self.layer_instantiated = False + self.fc_tensor_product_cls = FullyConnectedTensorProduct + self.fc_tensor_product_kwargs = kwargs + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.fc_tensor_product is not None: + raise ValueError('fc_tensor_product layer already exists') + self.fc_tensor_product = self.fc_tensor_product_cls( + self.irreps_in1, + self.irreps_in2, + self.irreps_out, + shared_weights=True, + internal_weights=None, # same as True + **self.fc_tensor_product_kwargs, + ) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.fc_tensor_product is not None, 'Layer is not instantiated' + data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product( + data[self.key_x], data[self.key_operand] + ) + return data + + +@compile_mode('script') +class SelfConnectionLinearIntro(nn.Module): + """ + Linear style self connection update + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_out: Irreps, + data_key_x: str = KEY.NODE_FEATURE, + lazy_layer_instantiate: bool = True, + **kwargs, + ): + super().__init__() + self.irreps_in = irreps_in + self.irreps_out = irreps_out + self.key_x = data_key_x + + self.linear = None + self.layer_instantiated = False + self.linear_cls = Linear + + # TODO: better to have SelfConnectionIntro super class + kwargs.pop('irreps_operand') + self.linear_kwargs = kwargs + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.linear is not None: + raise ValueError('Linear layer already exists') + self.linear = self.linear_cls( + self.irreps_in, self.irreps_out, **self.linear_kwargs + ) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert 
self.linear is not None, 'Layer is not instantiated' + data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x]) + return data + + +@compile_mode('script') +class SelfConnectionOutro(nn.Module): + """ + do TensorProduct of x and some data(here attribute of x) + and save it (to concatenate updated x at SelfConnectionOutro) + """ + + def __init__( + self, + data_key_x: str = KEY.NODE_FEATURE, + ): + super().__init__() + self.key_x = data_key_x + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP] + del data[KEY.SELF_CONNECTION_TEMP] + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py index 1a91ae61117cbc5d9daa552e17f4d5fa67a45592..c300814b220112258136f5b1f722a3a7c89b630f 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py @@ -1,183 +1,183 @@ -import warnings -from collections import OrderedDict -from typing import Dict, Optional - -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -def _instantiate_modules(modules): - # see IrrepsLinear of linear.py - for module in modules.values(): - if not getattr(module, 'layer_instantiated', True): - module.instantiate() - - -@compile_mode('script') -class _ModalInputPrepare(nn.Module): - - def __init__( - self, - modal_idx: int - ): - super().__init__() - self.modal_idx = modal_idx - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[KEY.MODAL_TYPE] = torch.tensor( - self.modal_idx, - dtype=torch.int64, - device=data['x'].device, - ) - return data - - -@compile_mode('script') -class AtomGraphSequential(nn.Sequential): - """ - Wrapper of SevenNet model - - Args: - modules: OrderedDict of nn.Modules - cutoff: not used internally, but makes sense to have - type_map: atomic_numbers => onehot index (see nn/node_embedding.py) - eval_type_map: perform index mapping using type_map defaults to True - data_key_atomic_numbers: used when eval_type_map is True - data_key_node_feature: used when eval_type_map is True - data_key_grad: if given, sets its requires grad True before pred - """ - - def __init__( - self, - modules: Dict[str, nn.Module], - cutoff: float = 0.0, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - eval_type_map: bool = True, - eval_modal_map: bool = False, - data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS, - data_key_node_feature: str = KEY.NODE_FEATURE, - data_key_grad: Optional[str] = None, - ): - if not isinstance(modules, OrderedDict): # backward compat - modules = OrderedDict(modules) - self.cutoff = cutoff - self.type_map = type_map - self.eval_type_map = eval_type_map - self.is_batch_data = True - - if cutoff == 0.0: - warnings.warn('cutoff is 0.0 or not given', UserWarning) - - if self.type_map is None: - warnings.warn('type_map is not given', UserWarning) - self.eval_type_map = False - else: - z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long)) - for z, onehot in self.type_map.items(): - z_to_onehot_tensor[z] = onehot - self.z_to_onehot_tensor = z_to_onehot_tensor - - if eval_modal_map and modal_map is None: - raise ValueError('eval_modal_map is True but modal_map is None') - self.eval_modal_map = eval_modal_map - self.modal_map = modal_map - - self.key_atomic_numbers = data_key_atomic_numbers - 
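# ---------------------------------------------------------------------------
# Sketch of the Intro/Outro residual pattern above (dict keys and shapes
# invented; only `torch` assumed): Intro stashes a transformed copy of x,
# later layers update x, and Outro adds the stash back and deletes the key.
import torch

data = {'x': torch.ones(2, 4)}
data['sc_temp'] = 0.5 * data['x']             # Intro: save self-connection
data['x'] = data['x'] + torch.randn(2, 4)     # ...interaction layers run...
data['x'] = data['x'] + data.pop('sc_temp')   # Outro: residual add, key freed
# ---------------------------------------------------------------------------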
self.key_node_feature = data_key_node_feature - self.key_grad = data_key_grad - - _instantiate_modules(modules) - super().__init__(modules) - if not isinstance(self._modules, OrderedDict): # backward compat - self._modules = OrderedDict(self._modules) - - def set_is_batch_data(self, flag: bool): - # whether given data is batched or not some module have to change - # its behavior. checking whether data is batched or not inside - # forward function make problem harder when make it into torchscript - for module in self: - try: # Easier to ask for forgiveness than permission. - module._is_batch_data = flag # type: ignore - except AttributeError: - pass - self.is_batch_data = flag - - def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'): - tg_module = self._modules[modlue_name] - for m in tg_module.modules(): - try: - return repr(m.__getattribute__(attr_key)) - except AttributeError: - pass - return None - - def prepand_module(self, key: str, module: nn.Module): - self._modules.update({key: module}) - self._modules.move_to_end(key, last=False) # type: ignore - - def replace_module(self, key: str, module: nn.Module): - self._modules.update({key: module}) - - def delete_module_by_key(self, key: str): - if key in self._modules.keys(): - del self._modules[key] - - @torch.jit.unused - def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor): - assert atomic_numbers.dtype == torch.int64 - device = atomic_numbers.device - z_to_onehot_tensor = self.z_to_onehot_tensor.to(device) - return torch.index_select( - input=z_to_onehot_tensor, dim=0, index=atomic_numbers - ) - - @torch.jit.unused - def _eval_modal_map(self, data: AtomGraphDataType): - assert self.modal_map is not None - # modal_map: dict[str, int] - if not self.is_batch_data: - modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore - else: - modal_idx = [ - self.modal_map[ii] # type: ignore - for ii in data[KEY.DATA_MODALITY] - ] - modal_idx = torch.tensor( - modal_idx, - dtype=torch.int64, - device=data.x.device, # type: ignore - ) - data[KEY.MODAL_TYPE] = modal_idx - - def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self.eval_type_map: - atomic_numbers = data[self.key_atomic_numbers] - onehot = self._atomic_numbers_to_onehot(atomic_numbers) - data[self.key_node_feature] = onehot - - if self.eval_modal_map: - self._eval_modal_map(data) - - if self.key_grad is not None: - data[self.key_grad].requires_grad_(True) - - return data - - def prepare_modal_deploy(self, modal: str): - if self.modal_map is None: - return - self.eval_modal_map = False - self.set_is_batch_data(False) - modal_idx = self.modal_map[modal] # type: ignore - self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx)) - - def forward(self, input: AtomGraphDataType) -> AtomGraphDataType: - data = self._preprocess(input) - for module in self: - data = module(data) - return data +import warnings +from collections import OrderedDict +from typing import Dict, Optional + +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +def _instantiate_modules(modules): + # see IrrepsLinear of linear.py + for module in modules.values(): + if not getattr(module, 'layer_instantiated', True): + module.instantiate() + + +@compile_mode('script') +class _ModalInputPrepare(nn.Module): + + def __init__( + self, + modal_idx: int + ): + super().__init__() + self.modal_idx = modal_idx + + def forward(self, data: 
AtomGraphDataType) -> AtomGraphDataType: + data[KEY.MODAL_TYPE] = torch.tensor( + self.modal_idx, + dtype=torch.int64, + device=data['x'].device, + ) + return data + + +@compile_mode('script') +class AtomGraphSequential(nn.Sequential): + """ + Wrapper of SevenNet model + + Args: + modules: OrderedDict of nn.Modules + cutoff: not used internally, but makes sense to have + type_map: atomic_numbers => onehot index (see nn/node_embedding.py) + eval_type_map: perform index mapping using type_map defaults to True + data_key_atomic_numbers: used when eval_type_map is True + data_key_node_feature: used when eval_type_map is True + data_key_grad: if given, sets its requires grad True before pred + """ + + def __init__( + self, + modules: Dict[str, nn.Module], + cutoff: float = 0.0, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + eval_type_map: bool = True, + eval_modal_map: bool = False, + data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS, + data_key_node_feature: str = KEY.NODE_FEATURE, + data_key_grad: Optional[str] = None, + ): + if not isinstance(modules, OrderedDict): # backward compat + modules = OrderedDict(modules) + self.cutoff = cutoff + self.type_map = type_map + self.eval_type_map = eval_type_map + self.is_batch_data = True + + if cutoff == 0.0: + warnings.warn('cutoff is 0.0 or not given', UserWarning) + + if self.type_map is None: + warnings.warn('type_map is not given', UserWarning) + self.eval_type_map = False + else: + z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long)) + for z, onehot in self.type_map.items(): + z_to_onehot_tensor[z] = onehot + self.z_to_onehot_tensor = z_to_onehot_tensor + + if eval_modal_map and modal_map is None: + raise ValueError('eval_modal_map is True but modal_map is None') + self.eval_modal_map = eval_modal_map + self.modal_map = modal_map + + self.key_atomic_numbers = data_key_atomic_numbers + self.key_node_feature = data_key_node_feature + self.key_grad = data_key_grad + + _instantiate_modules(modules) + super().__init__(modules) + if not isinstance(self._modules, OrderedDict): # backward compat + self._modules = OrderedDict(self._modules) + + def set_is_batch_data(self, flag: bool): + # whether given data is batched or not some module have to change + # its behavior. checking whether data is batched or not inside + # forward function make problem harder when make it into torchscript + for module in self: + try: # Easier to ask for forgiveness than permission. 
+ module._is_batch_data = flag # type: ignore + except AttributeError: + pass + self.is_batch_data = flag + + def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'): + tg_module = self._modules[modlue_name] + for m in tg_module.modules(): + try: + return repr(m.__getattribute__(attr_key)) + except AttributeError: + pass + return None + + def prepand_module(self, key: str, module: nn.Module): + self._modules.update({key: module}) + self._modules.move_to_end(key, last=False) # type: ignore + + def replace_module(self, key: str, module: nn.Module): + self._modules.update({key: module}) + + def delete_module_by_key(self, key: str): + if key in self._modules.keys(): + del self._modules[key] + + @torch.jit.unused + def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor): + assert atomic_numbers.dtype == torch.int64 + device = atomic_numbers.device + z_to_onehot_tensor = self.z_to_onehot_tensor.to(device) + return torch.index_select( + input=z_to_onehot_tensor, dim=0, index=atomic_numbers + ) + + @torch.jit.unused + def _eval_modal_map(self, data: AtomGraphDataType): + assert self.modal_map is not None + # modal_map: dict[str, int] + if not self.is_batch_data: + modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore + else: + modal_idx = [ + self.modal_map[ii] # type: ignore + for ii in data[KEY.DATA_MODALITY] + ] + modal_idx = torch.tensor( + modal_idx, + dtype=torch.int64, + device=data.x.device, # type: ignore + ) + data[KEY.MODAL_TYPE] = modal_idx + + def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self.eval_type_map: + atomic_numbers = data[self.key_atomic_numbers] + onehot = self._atomic_numbers_to_onehot(atomic_numbers) + data[self.key_node_feature] = onehot + + if self.eval_modal_map: + self._eval_modal_map(data) + + if self.key_grad is not None: + data[self.key_grad].requires_grad_(True) + + return data + + def prepare_modal_deploy(self, modal: str): + if self.modal_map is None: + return + self.eval_modal_map = False + self.set_is_batch_data(False) + modal_idx = self.modal_map[modal] # type: ignore + self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx)) + + def forward(self, input: AtomGraphDataType) -> AtomGraphDataType: + data = self._preprocess(input) + for module in self: + data = module(data) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py index 411b6dc029919ae2c17dd214d72c84836442737e..cf29c969ff4dead378fde06c824e3e18a6a2f764 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py @@ -1,17 +1,17 @@ -import torch - - -def broadcast( - src: torch.Tensor, - other: torch.Tensor, - dim: int -): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand_as(other) - return src +import torch + + +def broadcast( + src: torch.Tensor, + other: torch.Tensor, + dim: int +): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src diff --git a/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh b/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh index b111d0ef40b90ae1f90bb65b29f0a2a03d54f0da..e6bc90d5459c000d2ecb81cfc7138e025ab32844 100644 --- 
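# ---------------------------------------------------------------------------
# Sketch of the atomic-number -> type-index lookup performed by
# AtomGraphSequential above (type_map values invented): a length-120 table is
# filled with -1 and overwritten from type_map, so a whole batch of Z values
# is mapped in a single index_select.
import torch

type_map = {72: 0, 8: 1}  # Hf, O
z_to_onehot = torch.neg(torch.ones(120, dtype=torch.long))
for z, onehot_idx in type_map.items():
    z_to_onehot[z] = onehot_idx

atomic_numbers = torch.tensor([72, 8, 8])
print(torch.index_select(z_to_onehot, 0, atomic_numbers))  # tensor([0, 1, 1])
# ---------------------------------------------------------------------------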
a/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh +++ b/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh @@ -1,154 +1,154 @@ -#!/bin/bash - -lammps_root=$1 -cxx_standard=$2 # 14, 17 -d3_support=$3 # 1, 0 -SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") - -########################################### -# Check if the given arguments are valid # -########################################### - -# Check the number of arguments -if [ "$#" -ne 3 ]; then - echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}" - echo " {lammps_root}: Root directory of LAMMPS source" - echo " {cxx_standard}: C++ standard (14, 17)" - echo " {d3_support}: Support for pair_d3 (1, 0)" - exit 1 -fi - -# Check if the lammps_root directory exists -if [ ! -d "$lammps_root" ]; then - echo "Error: No such directory: $lammps_root" - exit 1 -fi - -# Check if the given directory is the root of LAMMPS source -if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then - echo "Error: Given $lammps_root is not a root of LAMMPS source" - exit 1 -fi - -# Check if the script is being run from the root of SevenNet -if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then - echo "Error: Script executed in a wrong directory" - exit 1 -fi - -# Check if the patch is already applied -if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then - echo "----------------------------------------------------------" - echo "Seems like given LAMMPS is already patched." - echo "Try again after removing src/pair_e3gnn.cpp to force patch" - echo "----------------------------------------------------------" - echo "Example build commands, under LAMMPS root" - echo " mkdir build; cd build" - echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" - echo " make -j 4" - exit 0 -fi - -# Check if OpenMPI exists and if it is CUDA-aware -if command -v ompi_info &> /dev/null; then - cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value) - if [[ -z "$cuda_support" ]]; then - echo "OpenMPI not found, parallel performance is not optimal" - elif [[ "$cuda_support" == *"true" ]]; then - echo "OpenMPI is CUDA aware" - else - echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal" - fi -else - echo "OpenMPI not found, parallel performance is not optimal" -fi - -# Extract LAMMPS version and update -lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"') - -# Combine version and update -detected_version="$lammps_version" -required_version="2 Aug 2023" # Example required version - -# Check if the detected version is compatible -if [[ "$detected_version" != "$required_version" ]]; then - echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version" -fi - -########################################### -# Backup original LAMMPS source code # -########################################### - -# Create a backup directory if it doesn't exist -backup_dir="$lammps_root/_backups" -mkdir -p $backup_dir - -# Copy comm_* from original LAMMPS source as backup -cp $lammps_root/src/comm_brick.cpp $backup_dir/ -cp $lammps_root/src/comm_brick.h $backup_dir/ - -# Copy cmake/CMakeLists.txt from original source as backup -cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt - -########################################### -# Patch LAMMPS source code: e3gnn # -########################################### - -# 1. 
Copy pair_e3gnn files to LAMMPS source -cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ -cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ - -# 2. Patch cmake/CMakeLists.txt -sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt -cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" - -find_package(Torch REQUIRED) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") -target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") -EOF - -########################################### -# Patch LAMMPS source code: d3 # -########################################### - -if [ "$d3_support" -ne 0 ]; then - -# 1. Copy pair_d3 files to LAMMPS source -cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/ -cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/ -cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/ - -# 2. Patch cmake/CMakeLists.txt -sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt -sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt -cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" - -find_package(CUDA) -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3") -string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") -target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda) -EOF - -fi - -########################################### -# Print changes and backup file locations # -########################################### - -# Print changes and backup file locations -echo "Changes made:" -echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" -echo " - Copied contents of pair_e3gnn to $lammps_root/src/" -echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" -if [ "$d3_support" -ne 0 ]; then - echo " - Copied contents of pair_d3 to $lammps_root/src/" - echo " - Patched CMakeLists.txt: include CUDA" -fi - -# Provide example cmake command to the user -echo "Example build commands, under LAMMPS root" -echo " mkdir build; cd build" -echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" -echo " make -j 4" - -exit 0 +#!/bin/bash + +lammps_root=$1 +cxx_standard=$2 # 14, 17 +d3_support=$3 # 1, 0 +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") + +########################################### +# Check if the given arguments are valid # +########################################### + +# Check the number of arguments +if [ "$#" -ne 3 ]; then + echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}" + echo " {lammps_root}: Root directory of LAMMPS source" + echo " {cxx_standard}: C++ standard (14, 17)" + echo " {d3_support}: Support for pair_d3 (1, 0)" + exit 1 +fi + +# Check if the lammps_root directory exists +if [ ! -d "$lammps_root" ]; then + echo "Error: No such directory: $lammps_root" + exit 1 +fi + +# Check if the given directory is the root of LAMMPS source +if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then + echo "Error: Given $lammps_root is not a root of LAMMPS source" + exit 1 +fi + +# Check if the script is being run from the root of SevenNet +if [ ! 
-f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then + echo "Error: Script executed in a wrong directory" + exit 1 +fi + +# Check if the patch is already applied +if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then + echo "----------------------------------------------------------" + echo "Seems like given LAMMPS is already patched." + echo "Try again after removing src/pair_e3gnn.cpp to force patch" + echo "----------------------------------------------------------" + echo "Example build commands, under LAMMPS root" + echo " mkdir build; cd build" + echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" + echo " make -j 4" + exit 0 +fi + +# Check if OpenMPI exists and if it is CUDA-aware +if command -v ompi_info &> /dev/null; then + cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value) + if [[ -z "$cuda_support" ]]; then + echo "OpenMPI not found, parallel performance is not optimal" + elif [[ "$cuda_support" == *"true" ]]; then + echo "OpenMPI is CUDA aware" + else + echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal" + fi +else + echo "OpenMPI not found, parallel performance is not optimal" +fi + +# Extract LAMMPS version and update +lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"') + +# Combine version and update +detected_version="$lammps_version" +required_version="2 Aug 2023" # Example required version + +# Check if the detected version is compatible +if [[ "$detected_version" != "$required_version" ]]; then + echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version" +fi + +########################################### +# Backup original LAMMPS source code # +########################################### + +# Create a backup directory if it doesn't exist +backup_dir="$lammps_root/_backups" +mkdir -p $backup_dir + +# Copy comm_* from original LAMMPS source as backup +cp $lammps_root/src/comm_brick.cpp $backup_dir/ +cp $lammps_root/src/comm_brick.h $backup_dir/ + +# Copy cmake/CMakeLists.txt from original source as backup +cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt + +########################################### +# Patch LAMMPS source code: e3gnn # +########################################### + +# 1. Copy pair_e3gnn files to LAMMPS source +cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ +cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ + +# 2. Patch cmake/CMakeLists.txt +sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt +cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") +EOF + +########################################### +# Patch LAMMPS source code: d3 # +########################################### + +if [ "$d3_support" -ne 0 ]; then + +# 1. Copy pair_d3 files to LAMMPS source +cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/ +cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/ +cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/ + +# 2. 
Patch cmake/CMakeLists.txt +sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt +sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt +cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" + +find_package(CUDA) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3") +string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda) +EOF + +fi + +########################################### +# Print changes and backup file locations # +########################################### + +# Print changes and backup file locations +echo "Changes made:" +echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" +echo " - Copied contents of pair_e3gnn to $lammps_root/src/" +echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" +if [ "$d3_support" -ne 0 ]; then + echo " - Copied contents of pair_d3 to $lammps_root/src/" + echo " - Patched CMakeLists.txt: include CUDA" +fi + +# Provide example cmake command to the user +echo "Example build commands, under LAMMPS root" +echo " mkdir build; cd build" +echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" +echo " make -j 4" + +exit 0 diff --git a/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py b/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py index 62f3167002c6e01d8817f39ece375a88c7a9b080..f0406d8c6ada4303dbc9241beb9fb9b276e096cc 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py @@ -1,246 +1,246 @@ -import glob -import os -import warnings -from typing import Any, Callable, Dict - -import torch -import yaml - -import sevenn._const as _const -import sevenn._keys as KEY -import sevenn.util as util - - -def config_initialize( - key: str, - config: Dict, - default: Any, - conditions: Dict, -): - # default value exist & no user input -> return default - if key not in config.keys(): - return default - - # No validation method exist => accept user input - user_input = config[key] - if key in conditions: - condition = conditions[key] - else: - return user_input - - if type(default) is dict and isinstance(condition, dict): - for i_key, val in default.items(): - user_input[i_key] = config_initialize( - i_key, user_input, val, condition - ) - return user_input - elif isinstance(condition, type): - if isinstance(user_input, condition): - return user_input - else: - try: - return condition(user_input) # try type casting - except ValueError: - raise ValueError( - f"Expect '{user_input}' for '{key}' is {condition}" - ) - elif isinstance(condition, Callable) and condition(user_input): - return user_input - else: - raise ValueError( - f"Given input '{user_input}' for '{key}' is not valid" - ) - - -def init_model_config(config: Dict): - # defaults = _const.model_defaults(config) - model_meta = {} - - # init complicated ones - if KEY.CHEMICAL_SPECIES not in config.keys(): - raise ValueError('required key chemical_species not exist') - input_chem = config[KEY.CHEMICAL_SPECIES] - if isinstance(input_chem, str) and input_chem.lower() == 'auto': - model_meta[KEY.CHEMICAL_SPECIES] = 'auto' - model_meta[KEY.NUM_SPECIES] = 'auto' - model_meta[KEY.TYPE_MAP] = 'auto' - elif isinstance(input_chem, str) and 'univ' in input_chem.lower(): - 
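
The example cmake invocation printed at the end of the script relies on the installed torch wheel shipping its own CMake package files; the embedded one-liner can be run on its own to confirm the Torch install before building:

```python
# torch.utils.cmake_prefix_path is the directory handed to CMake via
# -DCMAKE_PREFIX_PATH so that find_package(Torch REQUIRED) succeeds.
import torch

print(torch.utils.cmake_prefix_path)
```
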
model_meta.update(util.chemical_species_preprocess([], universal=True)) - else: - if isinstance(input_chem, list) and all( - isinstance(x, str) for x in input_chem - ): - pass - elif isinstance(input_chem, str): - input_chem = ( - input_chem.replace('-', ',').replace(' ', ',').split(',') - ) - input_chem = [chem for chem in input_chem if len(chem) != 0] - else: - raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange') - model_meta.update(util.chemical_species_preprocess(input_chem)) - - # deprecation warnings - if KEY.AVG_NUM_NEIGH in config: - warnings.warn( - "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'." - ' We use the default, the average number of neighbors in the' - ' dataset, if not provided.', - UserWarning, - ) - config.pop(KEY.AVG_NUM_NEIGH) - if KEY.TRAIN_AVG_NUM_NEIGH in config: - warnings.warn( - "key 'train_avg_num_neigh' is deprecated. Please use" - " 'train_denominator'. We overwrite train_denominator as given" - ' train_avg_num_neigh', - UserWarning, - ) - config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH] - config.pop(KEY.TRAIN_AVG_NUM_NEIGH) - if KEY.OPTIMIZE_BY_REDUCE in config: - warnings.warn( - "key 'optimize_by_reduce' is deprecated. Always true", - UserWarning, - ) - config.pop(KEY.OPTIMIZE_BY_REDUCE) - - # init simpler ones - for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items(): - model_meta[key] = config_initialize( - key, config, default, _const.MODEL_CONFIG_CONDITION - ) - - unknown_keys = [ - key for key in config.keys() if key not in model_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected model keys: {unknown_keys} will be ignored', - UserWarning, - ) - - return model_meta - - -def init_train_config(config: Dict): - train_meta = {} - # defaults = _const.train_defaults(config) - - try: - device_input = config[KEY.DEVICE] - train_meta[KEY.DEVICE] = torch.device(device_input) - except KeyError: - train_meta[KEY.DEVICE] = ( - torch.device('cuda') - if torch.cuda.is_available() - else torch.device('cpu') - ) - train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE]) - - # init simpler ones - for key, default in _const.DEFAULT_TRAINING_CONFIG.items(): - train_meta[key] = config_initialize( - key, config, default, _const.TRAINING_CONFIG_CONDITION - ) - - if KEY.CONTINUE in config.keys(): - cnt_dct = config[KEY.CONTINUE] - if KEY.CHECKPOINT not in cnt_dct.keys(): - raise ValueError('no checkpoint is given in continue') - checkpoint = cnt_dct[KEY.CHECKPOINT] - if os.path.isfile(checkpoint): - checkpoint_file = checkpoint - else: - checkpoint_file = util.pretrained_name_to_path(checkpoint) - train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file}) - - unknown_keys = [ - key for key in config.keys() if key not in train_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected train keys: {unknown_keys} will be ignored', - UserWarning, - ) - return train_meta - - -def init_data_config(config: Dict): - data_meta = {} - # defaults = _const.data_defaults(config) - - load_data_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - load_data_keys.append(k) - - for load_data_key in load_data_keys: - if load_data_key in config.keys(): - inp = config[load_data_key] - extended = [] - if type(inp) not in [str, list]: - raise ValueError(f'unexpected input {inp} for sturcture_list') - if type(inp) is str: - extended = glob.glob(inp) - elif type(inp) is list: - for i in inp: - if isinstance(i, str): - extended.extend(glob.glob(i)) - elif 
isinstance(i, dict): - extended.append(i) - if len(extended) == 0: - raise ValueError( - f'Cannot find {inp} for {load_data_key}' - + ' or path is not given' - ) - data_meta[load_data_key] = extended - else: - data_meta[load_data_key] = False - - for key, default in _const.DEFAULT_DATA_CONFIG.items(): - data_meta[key] = config_initialize( - key, config, default, _const.DATA_CONFIG_CONDITION - ) - - unknown_keys = [ - key for key in config.keys() if key not in data_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected data keys: {unknown_keys} will be ignored', - UserWarning, - ) - return data_meta - - -def read_config_yaml(filename: str, return_separately: bool = False): - with open(filename, 'r') as fstream: - inputs = yaml.safe_load(fstream) - - model_meta, train_meta, data_meta = {}, {}, {} - for key, config in inputs.items(): - if key == 'model': - model_meta = init_model_config(config) - elif key == 'train': - train_meta = init_train_config(config) - elif key == 'data': - data_meta = init_data_config(config) - else: - raise ValueError(f'Unexpected input {key} given') - - if return_separately: - return model_meta, train_meta, data_meta - else: - model_meta.update(train_meta) - model_meta.update(data_meta) - return model_meta - - -def main(): - filename = './input.yaml' - read_config_yaml(filename) - - -if __name__ == '__main__': - main() +import glob +import os +import warnings +from typing import Any, Callable, Dict + +import torch +import yaml + +import sevenn._const as _const +import sevenn._keys as KEY +import sevenn.util as util + + +def config_initialize( + key: str, + config: Dict, + default: Any, + conditions: Dict, +): + # default value exist & no user input -> return default + if key not in config.keys(): + return default + + # No validation method exist => accept user input + user_input = config[key] + if key in conditions: + condition = conditions[key] + else: + return user_input + + if type(default) is dict and isinstance(condition, dict): + for i_key, val in default.items(): + user_input[i_key] = config_initialize( + i_key, user_input, val, condition + ) + return user_input + elif isinstance(condition, type): + if isinstance(user_input, condition): + return user_input + else: + try: + return condition(user_input) # try type casting + except ValueError: + raise ValueError( + f"Expect '{user_input}' for '{key}' is {condition}" + ) + elif isinstance(condition, Callable) and condition(user_input): + return user_input + else: + raise ValueError( + f"Given input '{user_input}' for '{key}' is not valid" + ) + + +def init_model_config(config: Dict): + # defaults = _const.model_defaults(config) + model_meta = {} + + # init complicated ones + if KEY.CHEMICAL_SPECIES not in config.keys(): + raise ValueError('required key chemical_species not exist') + input_chem = config[KEY.CHEMICAL_SPECIES] + if isinstance(input_chem, str) and input_chem.lower() == 'auto': + model_meta[KEY.CHEMICAL_SPECIES] = 'auto' + model_meta[KEY.NUM_SPECIES] = 'auto' + model_meta[KEY.TYPE_MAP] = 'auto' + elif isinstance(input_chem, str) and 'univ' in input_chem.lower(): + model_meta.update(util.chemical_species_preprocess([], universal=True)) + else: + if isinstance(input_chem, list) and all( + isinstance(x, str) for x in input_chem + ): + pass + elif isinstance(input_chem, str): + input_chem = ( + input_chem.replace('-', ',').replace(' ', ',').split(',') + ) + input_chem = [chem for chem in input_chem if len(chem) != 0] + else: + raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is 
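
`config_initialize` above resolves each key against three kinds of conditions: a nested dict (recurse), a type (accept or attempt a cast), or a callable predicate. A self-contained mimic of the type and predicate branches; the keys used here are made up for illustration:

```python
# Standalone mimic of config_initialize's validation; not the sevenn API.
def resolve(key, config, default, conditions):
    if key not in config:
        return default
    value = config[key]
    cond = conditions.get(key)
    if cond is None:
        return value
    if isinstance(cond, type):
        return value if isinstance(value, cond) else cond(value)  # try casting
    if callable(cond) and cond(value):
        return value
    raise ValueError(f"Given input '{value}' for '{key}' is not valid")

print(resolve('cutoff', {'cutoff': '5.0'}, 4.5, {'cutoff': float}))  # 5.0
print(resolve('epochs', {}, 100, {'epochs': lambda x: x > 0}))       # 100
```
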
strange') + model_meta.update(util.chemical_species_preprocess(input_chem)) + + # deprecation warnings + if KEY.AVG_NUM_NEIGH in config: + warnings.warn( + "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'." + ' We use the default, the average number of neighbors in the' + ' dataset, if not provided.', + UserWarning, + ) + config.pop(KEY.AVG_NUM_NEIGH) + if KEY.TRAIN_AVG_NUM_NEIGH in config: + warnings.warn( + "key 'train_avg_num_neigh' is deprecated. Please use" + " 'train_denominator'. We overwrite train_denominator as given" + ' train_avg_num_neigh', + UserWarning, + ) + config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH] + config.pop(KEY.TRAIN_AVG_NUM_NEIGH) + if KEY.OPTIMIZE_BY_REDUCE in config: + warnings.warn( + "key 'optimize_by_reduce' is deprecated. Always true", + UserWarning, + ) + config.pop(KEY.OPTIMIZE_BY_REDUCE) + + # init simpler ones + for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items(): + model_meta[key] = config_initialize( + key, config, default, _const.MODEL_CONFIG_CONDITION + ) + + unknown_keys = [ + key for key in config.keys() if key not in model_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected model keys: {unknown_keys} will be ignored', + UserWarning, + ) + + return model_meta + + +def init_train_config(config: Dict): + train_meta = {} + # defaults = _const.train_defaults(config) + + try: + device_input = config[KEY.DEVICE] + train_meta[KEY.DEVICE] = torch.device(device_input) + except KeyError: + train_meta[KEY.DEVICE] = ( + torch.device('cuda') + if torch.cuda.is_available() + else torch.device('cpu') + ) + train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE]) + + # init simpler ones + for key, default in _const.DEFAULT_TRAINING_CONFIG.items(): + train_meta[key] = config_initialize( + key, config, default, _const.TRAINING_CONFIG_CONDITION + ) + + if KEY.CONTINUE in config.keys(): + cnt_dct = config[KEY.CONTINUE] + if KEY.CHECKPOINT not in cnt_dct.keys(): + raise ValueError('no checkpoint is given in continue') + checkpoint = cnt_dct[KEY.CHECKPOINT] + if os.path.isfile(checkpoint): + checkpoint_file = checkpoint + else: + checkpoint_file = util.pretrained_name_to_path(checkpoint) + train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file}) + + unknown_keys = [ + key for key in config.keys() if key not in train_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected train keys: {unknown_keys} will be ignored', + UserWarning, + ) + return train_meta + + +def init_data_config(config: Dict): + data_meta = {} + # defaults = _const.data_defaults(config) + + load_data_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + load_data_keys.append(k) + + for load_data_key in load_data_keys: + if load_data_key in config.keys(): + inp = config[load_data_key] + extended = [] + if type(inp) not in [str, list]: + raise ValueError(f'unexpected input {inp} for sturcture_list') + if type(inp) is str: + extended = glob.glob(inp) + elif type(inp) is list: + for i in inp: + if isinstance(i, str): + extended.extend(glob.glob(i)) + elif isinstance(i, dict): + extended.append(i) + if len(extended) == 0: + raise ValueError( + f'Cannot find {inp} for {load_data_key}' + + ' or path is not given' + ) + data_meta[load_data_key] = extended + else: + data_meta[load_data_key] = False + + for key, default in _const.DEFAULT_DATA_CONFIG.items(): + data_meta[key] = config_initialize( + key, config, default, _const.DATA_CONFIG_CONDITION + ) + + unknown_keys = [ + key for 
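
`init_data_config` above funnels every `load_*_path` entry through `glob`, accepting either a single pattern string or a list mixing patterns with ready-made dict entries, and raises when nothing matches. The expansion in isolation:

```python
# Mimic of the load_*_path expansion; patterns with no matches contribute
# nothing, which is what triggers the 'Cannot find ...' error above.
import glob

def expand_paths(inp):
    if isinstance(inp, str):
        return glob.glob(inp)
    out = []
    for item in inp:
        if isinstance(item, str):
            out.extend(glob.glob(item))
        elif isinstance(item, dict):
            out.append(item)
    return out

print(expand_paths(['data/*.extxyz', {'inline': 'entry'}]))
```
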
key in config.keys() if key not in data_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected data keys: {unknown_keys} will be ignored', + UserWarning, + ) + return data_meta + + +def read_config_yaml(filename: str, return_separately: bool = False): + with open(filename, 'r') as fstream: + inputs = yaml.safe_load(fstream) + + model_meta, train_meta, data_meta = {}, {}, {} + for key, config in inputs.items(): + if key == 'model': + model_meta = init_model_config(config) + elif key == 'train': + train_meta = init_train_config(config) + elif key == 'data': + data_meta = init_data_config(config) + else: + raise ValueError(f'Unexpected input {key} given') + + if return_separately: + return model_meta, train_meta, data_meta + else: + model_meta.update(train_meta) + model_meta.update(data_meta) + return model_meta + + +def main(): + filename = './input.yaml' + read_config_yaml(filename) + + +if __name__ == '__main__': + main() diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 24894cc923674b16f0810c194bd4278fe4d4d34d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc deleted file mode 100644 index 141a08b4fe52ae09d43701ef6a9546b040c4a434..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py index b01f9b12a6a23d0b39cd59b5ce2cfa9840915797..b8e81b1a07ae31f97ec06983312f496e05e80198 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py @@ -1,184 +1,184 @@ -""" -Debt -keep old pre-trained checkpoints unchanged. 
-""" - -import copy - -import torch - -import sevenn._keys as KEY - - -def version_tuple(v1): - v1 = tuple(map(int, v1.split('.'))) - return v1 - - -def patch_old_config(config): - version = config.get('version', None) - if not version: - raise ValueError('No version found in config') - - major, minor, _ = version.split('.')[:3] - major, minor = int(major), int(minor) - - if major == 0 and minor <= 9: - if config[KEY.CUTOFF_FUNCTION][KEY.CUTOFF_FUNCTION_NAME] == 'XPLOR': - config[KEY.CUTOFF_FUNCTION].pop('poly_cut_p_value', None) - if KEY.TRAIN_DENOMINTAOR not in config: - config[KEY.TRAIN_DENOMINTAOR] = config.pop('train_avg_num_neigh', False) - _opt = config.pop('optimize_by_reduce', None) - if _opt is False: - raise ValueError( - 'This checkpoint(optimize_by_reduce: False) is no longer supported' - ) - if KEY.CONV_DENOMINATOR not in config: - config[KEY.CONV_DENOMINATOR] = 0.0 - if KEY._NORMALIZE_SPH not in config: - config[KEY._NORMALIZE_SPH] = False - - return config - - -def map_old_model(old_model_state_dict): - """ - For compatibility with old namings (before 'correct' branch merged 2404XX) - Map old model's module names to new model's module names - """ - _old_module_name_mapping = { - 'EdgeEmbedding': 'edge_embedding', - 'reducing nn input to hidden': 'reduce_input_to_hidden', - 'reducing nn hidden to energy': 'reduce_hidden_to_energy', - 'rescale atomic energy': 'rescale_atomic_energy', - } - for i in range(10): - _old_module_name_mapping[f'{i} self connection intro'] = ( - f'{i}_self_connection_intro' - ) - _old_module_name_mapping[f'{i} convolution'] = f'{i}_convolution' - _old_module_name_mapping[f'{i} self interaction 2'] = ( - f'{i}_self_interaction_2' - ) - _old_module_name_mapping[f'{i} equivariant gate'] = f'{i}_equivariant_gate' - - new_model_state_dict = {} - for k, v in old_model_state_dict.items(): - key_name = k.split('.')[0] - follower = '.'.join(k.split('.')[1:]) - if 'denumerator' in follower: - follower = follower.replace('denumerator', 'denominator') - if key_name in _old_module_name_mapping: - new_key_name = _old_module_name_mapping[key_name] + '.' + follower - new_model_state_dict[new_key_name] = v - else: - new_model_state_dict[k] = v - return new_model_state_dict - - -def sort_old_convolution(model_now, state_dict): - from e3nn.o3 import wigner_3j - - """ - Reason1: we have to sort instructions of convolution to be compatible with - cuEquivariance. (therefore, sort weight) - Reason2: some of old convolution module's w3j coeff has flipped sign. This also - has to be fixed to be compatible with cuEquivarinace. 
- """ - - def patch(stct): - inst_old = copy.copy(conv._instructions_before_sort) - inst_old = [(inst[0], inst[1], inst[2]) for inst in inst_old] - del conv._instructions_before_sort - - conv_args = conv.convolution_kwargs - irreps_in1 = conv_args['irreps_in1'] - irreps_in2 = conv_args['irreps_in2'] - irreps_out = conv_args.get('irreps_out', conv_args.get('filter_irreps_out')) - - inst_sorted = sorted(inst_old, key=lambda x: x[2]) - - inst_sorted = [ - # in1, in2, out, weights - (inst[0], inst[1], inst[2], irreps_in1[inst[0]].mul) - for inst in inst_sorted - ] - - n = len(weight_nn.hs) - 2 - ww_key = f'{conv_key}.weight_nn.layer{n}.weight' - ww = stct[ww_key] - ww_sorted = [None] * len(inst_old) - - _prev_idx = 0 - for ist_src in inst_old: - for j, ist_dst in enumerate(inst_sorted): - if not all(ist_src[ii] == ist_dst[ii] for ii in range(3)): - continue - - numel = ist_dst[3] # weight num - ww_src = ww[:, _prev_idx : _prev_idx + numel] - l1, l2, l3 = ( - irreps_in1[ist_src[0]].ir.l, - irreps_in2[ist_src[1]].ir.l, - irreps_out[ist_src[2]].ir.l, - ) - if l1 > 0 and l2 > 0 and l3 > 0: - w3j_key = f'_w3j_{l1}_{l2}_{l3}' - conv_w3j_key = ( - f'{conv_key}.convolution._compiled_main_left_right.{w3j_key}' - ) - w3j_old = stct[conv_w3j_key] - w3j_now = wigner_3j(l1, l2, l3) - if not torch.allclose(w3j_old.to(w3j_now.device), w3j_now): - assert torch.allclose( - w3j_old.to(w3j_now.device), -1 * w3j_now - ) - ww_src = -1 * ww_src - stct[conv_w3j_key] *= -1 # stct updated - _prev_idx += numel - ww_sorted[j] = ww_src - ww_sorted = torch.cat(ww_sorted, dim=1) # type: ignore - stct[ww_key] = ww_sorted.clone() # stct updated - - conv_dicts = {} - for k, v in state_dict.items(): - key_name = k.split('.')[0] - if key_name.split('_')[1] == 'convolution': - if key_name not in conv_dicts: - conv_dicts[key_name] = {} - conv_dicts[key_name].update({k: v}) - - new_state_dict = {} - new_state_dict.update(state_dict) - for conv_key, conv_state_dict in conv_dicts.items(): - conv = model_now._modules[conv_key] - weight_nn = conv.weight_nn - patch(conv_state_dict) - new_state_dict.update(conv_state_dict) - - return new_state_dict - - -def patch_state_dict_if_old(state_dict, config_cp, now_model): - version = config_cp.get('version', None) - if not version: - raise ValueError('No version found in config') - vs = version.split('.') - vsuffix = '' - if len(vs) == 4: - vsuffix = vs[-1] - vs = version_tuple('.'.join(vs[:3])) - else: - vs = version_tuple('.'.join(vs)) - - if vs < version_tuple('0.10.0'): - state_dict = map_old_model(state_dict) - - # TODO: change version criteria before release!!! - # it causes problem if model is sorted but this function is called - # ... more robust way? idk - if vs < version_tuple('0.11.0') or ( - vs == version_tuple('0.11.0') and vsuffix == 'dev0' - ): - state_dict = sort_old_convolution(now_model, state_dict) - return state_dict +""" +Debt +keep old pre-trained checkpoints unchanged. 
+""" + +import copy + +import torch + +import sevenn._keys as KEY + + +def version_tuple(v1): + v1 = tuple(map(int, v1.split('.'))) + return v1 + + +def patch_old_config(config): + version = config.get('version', None) + if not version: + raise ValueError('No version found in config') + + major, minor, _ = version.split('.')[:3] + major, minor = int(major), int(minor) + + if major == 0 and minor <= 9: + if config[KEY.CUTOFF_FUNCTION][KEY.CUTOFF_FUNCTION_NAME] == 'XPLOR': + config[KEY.CUTOFF_FUNCTION].pop('poly_cut_p_value', None) + if KEY.TRAIN_DENOMINTAOR not in config: + config[KEY.TRAIN_DENOMINTAOR] = config.pop('train_avg_num_neigh', False) + _opt = config.pop('optimize_by_reduce', None) + if _opt is False: + raise ValueError( + 'This checkpoint(optimize_by_reduce: False) is no longer supported' + ) + if KEY.CONV_DENOMINATOR not in config: + config[KEY.CONV_DENOMINATOR] = 0.0 + if KEY._NORMALIZE_SPH not in config: + config[KEY._NORMALIZE_SPH] = False + + return config + + +def map_old_model(old_model_state_dict): + """ + For compatibility with old namings (before 'correct' branch merged 2404XX) + Map old model's module names to new model's module names + """ + _old_module_name_mapping = { + 'EdgeEmbedding': 'edge_embedding', + 'reducing nn input to hidden': 'reduce_input_to_hidden', + 'reducing nn hidden to energy': 'reduce_hidden_to_energy', + 'rescale atomic energy': 'rescale_atomic_energy', + } + for i in range(10): + _old_module_name_mapping[f'{i} self connection intro'] = ( + f'{i}_self_connection_intro' + ) + _old_module_name_mapping[f'{i} convolution'] = f'{i}_convolution' + _old_module_name_mapping[f'{i} self interaction 2'] = ( + f'{i}_self_interaction_2' + ) + _old_module_name_mapping[f'{i} equivariant gate'] = f'{i}_equivariant_gate' + + new_model_state_dict = {} + for k, v in old_model_state_dict.items(): + key_name = k.split('.')[0] + follower = '.'.join(k.split('.')[1:]) + if 'denumerator' in follower: + follower = follower.replace('denumerator', 'denominator') + if key_name in _old_module_name_mapping: + new_key_name = _old_module_name_mapping[key_name] + '.' + follower + new_model_state_dict[new_key_name] = v + else: + new_model_state_dict[k] = v + return new_model_state_dict + + +def sort_old_convolution(model_now, state_dict): + from e3nn.o3 import wigner_3j + + """ + Reason1: we have to sort instructions of convolution to be compatible with + cuEquivariance. (therefore, sort weight) + Reason2: some of old convolution module's w3j coeff has flipped sign. This also + has to be fixed to be compatible with cuEquivarinace. 
+ """ + + def patch(stct): + inst_old = copy.copy(conv._instructions_before_sort) + inst_old = [(inst[0], inst[1], inst[2]) for inst in inst_old] + del conv._instructions_before_sort + + conv_args = conv.convolution_kwargs + irreps_in1 = conv_args['irreps_in1'] + irreps_in2 = conv_args['irreps_in2'] + irreps_out = conv_args.get('irreps_out', conv_args.get('filter_irreps_out')) + + inst_sorted = sorted(inst_old, key=lambda x: x[2]) + + inst_sorted = [ + # in1, in2, out, weights + (inst[0], inst[1], inst[2], irreps_in1[inst[0]].mul) + for inst in inst_sorted + ] + + n = len(weight_nn.hs) - 2 + ww_key = f'{conv_key}.weight_nn.layer{n}.weight' + ww = stct[ww_key] + ww_sorted = [None] * len(inst_old) + + _prev_idx = 0 + for ist_src in inst_old: + for j, ist_dst in enumerate(inst_sorted): + if not all(ist_src[ii] == ist_dst[ii] for ii in range(3)): + continue + + numel = ist_dst[3] # weight num + ww_src = ww[:, _prev_idx : _prev_idx + numel] + l1, l2, l3 = ( + irreps_in1[ist_src[0]].ir.l, + irreps_in2[ist_src[1]].ir.l, + irreps_out[ist_src[2]].ir.l, + ) + if l1 > 0 and l2 > 0 and l3 > 0: + w3j_key = f'_w3j_{l1}_{l2}_{l3}' + conv_w3j_key = ( + f'{conv_key}.convolution._compiled_main_left_right.{w3j_key}' + ) + w3j_old = stct[conv_w3j_key] + w3j_now = wigner_3j(l1, l2, l3) + if not torch.allclose(w3j_old.to(w3j_now.device), w3j_now): + assert torch.allclose( + w3j_old.to(w3j_now.device), -1 * w3j_now + ) + ww_src = -1 * ww_src + stct[conv_w3j_key] *= -1 # stct updated + _prev_idx += numel + ww_sorted[j] = ww_src + ww_sorted = torch.cat(ww_sorted, dim=1) # type: ignore + stct[ww_key] = ww_sorted.clone() # stct updated + + conv_dicts = {} + for k, v in state_dict.items(): + key_name = k.split('.')[0] + if key_name.split('_')[1] == 'convolution': + if key_name not in conv_dicts: + conv_dicts[key_name] = {} + conv_dicts[key_name].update({k: v}) + + new_state_dict = {} + new_state_dict.update(state_dict) + for conv_key, conv_state_dict in conv_dicts.items(): + conv = model_now._modules[conv_key] + weight_nn = conv.weight_nn + patch(conv_state_dict) + new_state_dict.update(conv_state_dict) + + return new_state_dict + + +def patch_state_dict_if_old(state_dict, config_cp, now_model): + version = config_cp.get('version', None) + if not version: + raise ValueError('No version found in config') + vs = version.split('.') + vsuffix = '' + if len(vs) == 4: + vsuffix = vs[-1] + vs = version_tuple('.'.join(vs[:3])) + else: + vs = version_tuple('.'.join(vs)) + + if vs < version_tuple('0.10.0'): + state_dict = map_old_model(state_dict) + + # TODO: change version criteria before release!!! + # it causes problem if model is sorted but this function is called + # ... more robust way? 
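
The sign handling inside `patch` compares each stored Wigner 3j block against a freshly generated one and, on mismatch, asserts the block is the exact negative before flipping both it and the matching weight slice. The core of that test in isolation (requires torch and e3nn):

```python
# An old checkpoint's w3j block may follow the opposite sign convention;
# allclose against both signs distinguishes 'flipped' from 'corrupt'.
import torch
from e3nn.o3 import wigner_3j

w3j_now = wigner_3j(1, 1, 2)
w3j_old = -w3j_now  # emulate a checkpoint saved with the flipped sign
if not torch.allclose(w3j_old, w3j_now):
    assert torch.allclose(w3j_old, -1 * w3j_now)
    # ...the caller then multiplies the stored block and weight slice by -1
```
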
idk + if vs < version_tuple('0.11.0') or ( + vs == version_tuple('0.11.0') and vsuffix == 'dev0' + ): + state_dict = sort_old_convolution(now_model, state_dict) + return state_dict diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py index 99882a6f9fcc687476f574b8f4ddbb7f68d194a7..5581e19661ddea5a90f6e5fd2ee7f1b0fa239845 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py @@ -1,301 +1,301 @@ -import math -from typing import List - -import torch -import torch.nn as nn -from e3nn.o3 import Irreps, Linear - -import sevenn._keys as KEY -from sevenn.model_build import build_E3_equivariant_model - -modal_module_dict = { - KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x', - KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1', - KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2', - KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden', -} - - -def _get_scalar_index(irreps: Irreps): - scalar_indices = [] - for idx, (_, (l, p)) in enumerate(irreps): # noqa - if ( - l == 0 and p == 1 - ): # get index of parameter for scalar (0e), which is used for modality - scalar_indices.append(idx) - - return scalar_indices - - -def _reshape_weight_of_linear( - irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor -) -> List[torch.Tensor]: - linear = Linear(irreps_in, irreps_out) - linear.weight = nn.Parameter(weight) - return list(linear.weight_views()) - - -def _erase_linear_modal_params( - model_state_dct: dict, - erase_modal_indices: List[int], - key: str, - irreps_in: Irreps, - irreps_out: Irreps, -): - orig_input_dim = irreps_in.count('0e') - new_input_dim = orig_input_dim - len(erase_modal_indices) - - orig_weight = model_state_dct[key + '.linear.weight'] - scalar_idx = _get_scalar_index(irreps_in) - linear_weight_list = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - ) - - new_weight_list = [] - - for idx, l_p_weight in enumerate(linear_weight_list[:-1]): - new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() - if idx in scalar_idx: - new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim) - - new_weight_list.append(new_weight) - - """ - Following works for normalization = `path`, which is not used in SEVENNet - for l_p_weight in linear_weight_list[:-1]: - new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze()) - """ - - flattened_weight = torch.cat(new_weight_list) - - return flattened_weight - - -def _get_modal_weight_as_bias( - model_state_dct: dict, - key: str, - ref_index: int, - irreps_in: Irreps, - irreps_out: Irreps, -): - assert ref_index != -1 - input_dim = irreps_in.count('0e') - output_dim = irreps_out.count('0e') - orig_weight = model_state_dct[key + '.linear.weight'] - orig_bias = model_state_dct[key + '.linear.bias'] - if len(orig_bias) == 0: - orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype) - - modal_weight = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - )[-1] - - new_bias = orig_bias + modal_weight[ref_index] / math.sqrt(input_dim) - - return new_bias - - -def _append_modal_weight( - model_state_dct: dict, # state dict to be targeted - key: str, # linear weight modune name - irreps_in: Irreps, # irreps_in before modality append - irreps_out: Irreps, - append_number: int, -): - # This works for normalization = `element`, default in SEVENNet. 
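
`patch_state_dict_if_old` applies the two passes strictly by checkpoint version: module renames below 0.10.0, the convolution re-sort up to and including 0.11.0.dev0. The gating condensed to a runnable snippet:

```python
# Same comparison the function performs, with an example version string.
def version_tuple(v):
    return tuple(map(int, v.split('.')))

vs, suffix = version_tuple('0.11.0'), 'dev0'
needs_rename = vs < version_tuple('0.10.0')
needs_sort = vs < version_tuple('0.11.0') or (
    vs == version_tuple('0.11.0') and suffix == 'dev0'
)
print(needs_rename, needs_sort)  # False True
```
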
- # (normalization = `path` is curruently deprecated in SEVENNet.) - input_dim = irreps_in.count('0e') - output_dim = irreps_out.count('0e') - new_input_dim = input_dim + append_number - orig_weight = model_state_dct[key + '.linear.weight'] - scalar_idx = _get_scalar_index(irreps_in) - linear_weight_list = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - ) - - new_weight_list = [] - - # TODO: combine following as function with _erase_linear_modal_params - - for idx, l_p_weight in enumerate(linear_weight_list): - new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() - if idx in scalar_idx: - new_weight = new_weight * math.sqrt(new_input_dim / input_dim) - - new_weight_list.append(new_weight) - - flattened_weight_list = [] - for l_p_weight in new_weight_list: - flattened_weight_list.append( - torch.reshape(l_p_weight, (1, -1)).squeeze() - ) - flattened_weight = torch.cat(flattened_weight_list) - - append_weight = torch.cat([ - flattened_weight, - torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype), - ]) # zeros: starting from common model - - return append_weight - - -def get_single_modal_model_dct( - model_state_dct: dict, - config: dict, - ref_modal: str, - from_processing_cp: bool = False, - is_deploy: bool = False, -): - """ - Convert multimodal model state dictionary to single modal model. - Modal is selected by `ref_modal` - - `model_state_dct`: model state dictionary from multimodal checkpoint file - `config`: dictionary containing configuration of the checkpoint model - `ref_modal`: modal that are going to be converted - `from_processing_cp`: if True, use modal_map of the checkpoint file - `is_deploy`: if True, model is build with single-modal shift and scale - """ - if ( - not from_processing_cp and not config[KEY.USE_MODALITY] - ): # model is already single modal - return model_state_dct - - config[KEY.USE_BIAS_IN_LINEAR] = True - config['_deploy'] = is_deploy - - model = build_E3_equivariant_model(config) - del config['_deploy'] - key_add = '_cp' if from_processing_cp else '' - modal_type_dict = config[KEY.MODAL_MAP + key_add] - erase_modal_indices = range(len(modal_type_dict.keys())) # starts with 0 - - if ref_modal != 'common': - try: - ref_modal_index = modal_type_dict[ref_modal] - except: - raise KeyError( - f'{ref_modal} not in modal type. Use one of' - f' {modal_type_dict.keys()}.' 
- ) - - for module_key in model._modules.keys(): - for ( - use_modal_module_key, - modal_module_name, - ) in modal_module_dict.items(): - irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out')) - # TODO: directly using "irreps_in" might not be compatible - # when changing `nn/linear.py` - output_dim = irreps_out.count('0e') - if ( - config[use_modal_module_key] - and modal_module_name in module_key - ): # this module is used for giving modality - - irreps_in = Irreps( - model.get_irreps_in(module_key, 'irreps_in') - ) - - new_bias = ( - torch.zeros(output_dim) - if ref_modal == 'common' - else _get_modal_weight_as_bias( - model_state_dct, - module_key, - ref_modal_index, - irreps_in, # type: ignore - irreps_out, # type: ignore - ) - ) - erased_modal_weight = _erase_linear_modal_params( - model_state_dct, - erase_modal_indices, - module_key, - irreps_in, # type: ignore - irreps_out, # type: ignore - ) - - model_state_dct[module_key + '.linear.weight'] = ( - erased_modal_weight - ) - model_state_dct[module_key + '.linear.bias'] = new_bias - elif modal_module_name in module_key: - model_state_dct[module_key + '.linear.bias'] = torch.zeros( - output_dim, - dtype=model_state_dct[module_key + '.linear.weight'].dtype, - ) - - final_block_key = 'reduce_hidden_to_energy' - model_state_dct[final_block_key + '.linear.bias'] = torch.tensor( - [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype - ) - - if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SHIFT]: - rescaler_names = [] - if config[KEY.USE_MODAL_WISE_SHIFT]: - rescaler_names.append('shift') - if config[KEY.USE_MODAL_WISE_SCALE]: - rescaler_names.append('scale') - config[KEY.USE_MODAL_WISE_SHIFT] = False - config[KEY.USE_MODAL_WISE_SCALE] = False - for rescaler_name in rescaler_names: - rescaler_key = 'rescale_atomic_energy.' + rescaler_name - rescaler = model_state_dct[rescaler_key][ref_modal_index] - model_state_dct.update({rescaler_key: rescaler}) - config.update({rescaler_name: rescaler}) - - config[KEY.USE_MODALITY] = False - - return model_state_dct - - -def append_modality_to_model_dct( - model_state_dct: dict, - config: dict, - orig_num_modal: int, - append_modal_length: int, -): - """ - Append modal-wise parameters to the original linear layers. - This enables expanding modal to single/multi modal model checkpoint. - - `model_state_dct`: model state dictionary from multimodal checkpoint file - `config`: dictionary containing configuration of the checkpoint model - + modality appended - `orig_num_modal`: Number of modality used in original checkpoint - `append_modal_length`: Number of modality to be appended in new checkpoint. 
- """ - config_num_modal = config[KEY.NUM_MODALITIES] - config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True}) - - model = build_E3_equivariant_model(config) - - for module_key in model._modules.keys(): - for ( - use_modal_module_key, - modal_module_name, - ) in modal_module_dict.items(): - if ( - config[use_modal_module_key] - and modal_module_name in module_key - ): # this module is used for giving modality - irreps_in = model.get_irreps_in( - module_key, 'irreps_in' - ) - # TODO: directly using "irreps_in" might not be compatible - # when changing `nn/linear.py` - irreps_out = model.get_irreps_in(module_key, 'irreps_out') - irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out) - - append_weight = _append_modal_weight( - model_state_dct, - module_key, - irreps_in, # type: ignore - irreps_out, # type: ignore - append_modal_length, - ) - model_state_dct[module_key + '.linear.weight'] = append_weight - config[KEY.NUM_MODALITIES] = config_num_modal - - return model_state_dct +import math +from typing import List + +import torch +import torch.nn as nn +from e3nn.o3 import Irreps, Linear + +import sevenn._keys as KEY +from sevenn.model_build import build_E3_equivariant_model + +modal_module_dict = { + KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x', + KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1', + KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2', + KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden', +} + + +def _get_scalar_index(irreps: Irreps): + scalar_indices = [] + for idx, (_, (l, p)) in enumerate(irreps): # noqa + if ( + l == 0 and p == 1 + ): # get index of parameter for scalar (0e), which is used for modality + scalar_indices.append(idx) + + return scalar_indices + + +def _reshape_weight_of_linear( + irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor +) -> List[torch.Tensor]: + linear = Linear(irreps_in, irreps_out) + linear.weight = nn.Parameter(weight) + return list(linear.weight_views()) + + +def _erase_linear_modal_params( + model_state_dct: dict, + erase_modal_indices: List[int], + key: str, + irreps_in: Irreps, + irreps_out: Irreps, +): + orig_input_dim = irreps_in.count('0e') + new_input_dim = orig_input_dim - len(erase_modal_indices) + + orig_weight = model_state_dct[key + '.linear.weight'] + scalar_idx = _get_scalar_index(irreps_in) + linear_weight_list = _reshape_weight_of_linear( + irreps_in, irreps_out, orig_weight + ) + + new_weight_list = [] + + for idx, l_p_weight in enumerate(linear_weight_list[:-1]): + new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() + if idx in scalar_idx: + new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim) + + new_weight_list.append(new_weight) + + """ + Following works for normalization = `path`, which is not used in SEVENNet + for l_p_weight in linear_weight_list[:-1]: + new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze()) + """ + + flattened_weight = torch.cat(new_weight_list) + + return flattened_weight + + +def _get_modal_weight_as_bias( + model_state_dct: dict, + key: str, + ref_index: int, + irreps_in: Irreps, + irreps_out: Irreps, +): + assert ref_index != -1 + input_dim = irreps_in.count('0e') + output_dim = irreps_out.count('0e') + orig_weight = model_state_dct[key + '.linear.weight'] + orig_bias = model_state_dct[key + '.linear.bias'] + if len(orig_bias) == 0: + orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype) + + modal_weight = _reshape_weight_of_linear( + irreps_in, irreps_out, orig_weight + )[-1] + + new_bias = orig_bias + 
modal_weight[ref_index] / math.sqrt(input_dim)
+
+    return new_bias
+
+
+def _append_modal_weight(
+    model_state_dct: dict,  # state dict to be targeted
+    key: str,  # linear weight module name
+    irreps_in: Irreps,  # irreps_in before modality append
+    irreps_out: Irreps,
+    append_number: int,
+):
+    # This works for normalization = `element`, default in SEVENNet.
+    # (normalization = `path` is currently deprecated in SEVENNet.)
+    input_dim = irreps_in.count('0e')
+    output_dim = irreps_out.count('0e')
+    new_input_dim = input_dim + append_number
+    orig_weight = model_state_dct[key + '.linear.weight']
+    scalar_idx = _get_scalar_index(irreps_in)
+    linear_weight_list = _reshape_weight_of_linear(
+        irreps_in, irreps_out, orig_weight
+    )
+
+    new_weight_list = []
+
+    # TODO: combine following as function with _erase_linear_modal_params
+
+    for idx, l_p_weight in enumerate(linear_weight_list):
+        new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
+        if idx in scalar_idx:
+            new_weight = new_weight * math.sqrt(new_input_dim / input_dim)
+
+        new_weight_list.append(new_weight)
+
+    flattened_weight_list = []
+    for l_p_weight in new_weight_list:
+        flattened_weight_list.append(
+            torch.reshape(l_p_weight, (1, -1)).squeeze()
+        )
+    flattened_weight = torch.cat(flattened_weight_list)
+
+    append_weight = torch.cat([
+        flattened_weight,
+        torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype),
+    ])  # zeros: starting from common model
+
+    return append_weight
+
+
+def get_single_modal_model_dct(
+    model_state_dct: dict,
+    config: dict,
+    ref_modal: str,
+    from_processing_cp: bool = False,
+    is_deploy: bool = False,
+):
+    """
+    Convert a multimodal model state dictionary to a single-modal one.
+    The modal is selected by `ref_modal`.
+
+    `model_state_dct`: model state dictionary from a multimodal checkpoint file
+    `config`: dictionary containing the configuration of the checkpoint model
+    `ref_modal`: the modal to convert to
+    `from_processing_cp`: if True, use the modal_map of the checkpoint file
+    `is_deploy`: if True, the model is built with single-modal shift and scale
+    """
+    if (
+        not from_processing_cp and not config[KEY.USE_MODALITY]
+    ):  # model is already single modal
+        return model_state_dct
+
+    config[KEY.USE_BIAS_IN_LINEAR] = True
+    config['_deploy'] = is_deploy
+
+    model = build_E3_equivariant_model(config)
+    del config['_deploy']
+    key_add = '_cp' if from_processing_cp else ''
+    modal_type_dict = config[KEY.MODAL_MAP + key_add]
+    erase_modal_indices = range(len(modal_type_dict.keys()))  # starts with 0
+
+    if ref_modal != 'common':
+        try:
+            ref_modal_index = modal_type_dict[ref_modal]
+        except KeyError:
+            raise KeyError(
+                f'{ref_modal} not in modal type. Use one of'
+                f' {modal_type_dict.keys()}.'
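
Both `_erase_linear_modal_params` and `_append_modal_weight` rescale scalar weights by the square root of the fan-in ratio. Assuming e3nn's `element` normalization divides each output by the square root of the fan-in, shrinking or growing the scalar input count must be compensated exactly this way to keep the contribution of the surviving inputs unchanged; a toy check of the factors:

```python
# sqrt(new_dim / old_dim): < 1 when modal scalars are erased, > 1 when added.
import math

old_dim, k = 32, 2  # hypothetical scalar fan-in and modal count
print(math.sqrt((old_dim - k) / old_dim))  # erase factor, ~0.9682
print(math.sqrt((old_dim + k) / old_dim))  # append factor, ~1.0308
```
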
+            )
+
+    for module_key in model._modules.keys():
+        for (
+            use_modal_module_key,
+            modal_module_name,
+        ) in modal_module_dict.items():
+            irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out'))
+            # TODO: directly using "irreps_in" might not be compatible
+            # when changing `nn/linear.py`
+            output_dim = irreps_out.count('0e')
+            if (
+                config[use_modal_module_key]
+                and modal_module_name in module_key
+            ):  # this module is used for giving modality
+
+                irreps_in = Irreps(
+                    model.get_irreps_in(module_key, 'irreps_in')
+                )
+
+                new_bias = (
+                    torch.zeros(output_dim)
+                    if ref_modal == 'common'
+                    else _get_modal_weight_as_bias(
+                        model_state_dct,
+                        module_key,
+                        ref_modal_index,
+                        irreps_in,  # type: ignore
+                        irreps_out,  # type: ignore
+                    )
+                )
+                erased_modal_weight = _erase_linear_modal_params(
+                    model_state_dct,
+                    erase_modal_indices,
+                    module_key,
+                    irreps_in,  # type: ignore
+                    irreps_out,  # type: ignore
+                )
+
+                model_state_dct[module_key + '.linear.weight'] = (
+                    erased_modal_weight
+                )
+                model_state_dct[module_key + '.linear.bias'] = new_bias
+            elif modal_module_name in module_key:
+                model_state_dct[module_key + '.linear.bias'] = torch.zeros(
+                    output_dim,
+                    dtype=model_state_dct[module_key + '.linear.weight'].dtype,
+                )
+
+    final_block_key = 'reduce_hidden_to_energy'
+    model_state_dct[final_block_key + '.linear.bias'] = torch.tensor(
+        [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype
+    )
+
+    if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE]:
+        rescaler_names = []
+        if config[KEY.USE_MODAL_WISE_SHIFT]:
+            rescaler_names.append('shift')
+        if config[KEY.USE_MODAL_WISE_SCALE]:
+            rescaler_names.append('scale')
+        config[KEY.USE_MODAL_WISE_SHIFT] = False
+        config[KEY.USE_MODAL_WISE_SCALE] = False
+        for rescaler_name in rescaler_names:
+            rescaler_key = 'rescale_atomic_energy.' + rescaler_name
+            rescaler = model_state_dct[rescaler_key][ref_modal_index]
+            model_state_dct.update({rescaler_key: rescaler})
+            config.update({rescaler_name: rescaler})
+
+    config[KEY.USE_MODALITY] = False
+
+    return model_state_dct
+
+
+def append_modality_to_model_dct(
+    model_state_dct: dict,
+    config: dict,
+    orig_num_modal: int,
+    append_modal_length: int,
+):
+    """
+    Append modal-wise parameters to the original linear layers.
+    This enables expanding the modality of a single- or multi-modal model
+    checkpoint.
+
+    `model_state_dct`: model state dictionary from a multimodal checkpoint file
+    `config`: dictionary containing the configuration of the checkpoint model,
+    with modality appended
+    `orig_num_modal`: number of modalities used in the original checkpoint
+    `append_modal_length`: number of modalities to append in the new checkpoint.
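
When the checkpoint stores modal-wise shift or scale vectors, the single-modal export reduces each vector to the scalar at `ref_modal_index` and writes it into both the state dict and the config. A toy stand-in with a fabricated tensor:

```python
# Fabricated example; the real key comes from the rescale_atomic_energy module.
import torch

state = {'rescale_atomic_energy.shift': torch.tensor([0.10, 0.30])}
ref_modal_index = 1  # e.g. the second modality in modal_map
state['rescale_atomic_energy.shift'] = (
    state['rescale_atomic_energy.shift'][ref_modal_index]
)
print(state['rescale_atomic_energy.shift'])  # tensor(0.3000)
```
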
+ """ + config_num_modal = config[KEY.NUM_MODALITIES] + config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True}) + + model = build_E3_equivariant_model(config) + + for module_key in model._modules.keys(): + for ( + use_modal_module_key, + modal_module_name, + ) in modal_module_dict.items(): + if ( + config[use_modal_module_key] + and modal_module_name in module_key + ): # this module is used for giving modality + irreps_in = model.get_irreps_in( + module_key, 'irreps_in' + ) + # TODO: directly using "irreps_in" might not be compatible + # when changing `nn/linear.py` + irreps_out = model.get_irreps_in(module_key, 'irreps_out') + irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out) + + append_weight = _append_modal_weight( + model_state_dct, + module_key, + irreps_in, # type: ignore + irreps_out, # type: ignore + append_modal_length, + ) + model_state_dct[module_key + '.linear.weight'] = append_weight + config[KEY.NUM_MODALITIES] = config_num_modal + + return model_state_dct diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py index c0695794d9c8ed5476fa8067532fff305368b104..51ded15c579768bb467b1f84c8bb178e390c16ff 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py @@ -1,148 +1,148 @@ -import os -from datetime import datetime -from typing import Optional - -import e3nn.util.jit -import torch -import torch.nn -from ase.data import chemical_symbols - -import sevenn._keys as KEY -from sevenn import __version__ -from sevenn.model_build import build_E3_equivariant_model -from sevenn.util import load_checkpoint - - -def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None): - """ - This method is messy to avoid changes in pair_e3gnn.cpp, while - refactoring python part. - If changes the behavior, and accordingly pair_e3gnn.cpp, - we have to recompile LAMMPS (which I always want to procrastinate) - """ - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput - - cp = load_checkpoint(checkpoint) - model, config = cp.build_model('e3nn'), cp.config - - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key - if hasattr(model, 'eval_type_map'): - setattr(model, 'eval_type_map', False) - - if modal: - model.prepare_modal_deploy(modal) - elif model.modal_map is not None and len(model.modal_map) >= 1: - raise ValueError( - f'Modal is not given. 
It has: {list(model.modal_map.keys())}' - ) - - model.set_is_batch_data(False) - model.eval() - - model = e3nn.util.jit.script(model) - model = torch.jit.freeze(model) - - # make some config need for md - md_configs = {} - type_map = config[KEY.TYPE_MAP] - chem_list = '' - for Z in type_map.keys(): - chem_list += chemical_symbols[Z] + ' ' - chem_list.strip() - md_configs.update({'chemical_symbols_to_index': chem_list}) - md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) - md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) - md_configs.update( - {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} - ) - md_configs.update({'version': __version__}) - md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) - md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - - if fname.endswith('.pt') is False: - fname += '.pt' - torch.jit.save(model, fname, _extra_files=md_configs) - - -# TODO: build model only once -def deploy_parallel( - checkpoint, fname='deployed_parallel', modal: Optional[str] = None -): - # Additional layer for ghost atom (and copy parameters from original) - GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1'] - - cp = load_checkpoint(checkpoint) - model, config = cp.build_model('e3nn'), cp.config - config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False} - model_state_dct = model.state_dict() - - model_list = build_E3_equivariant_model(config, parallel=True) - dct_temp = {} - copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS} - for ghost_layer_key in GHOST_LAYERS_KEYS: - for key, val in model_state_dct.items(): - if not key.startswith(ghost_layer_key): - continue - dct_temp.update({f'ghost_{key}': val}) - copy_counter[ghost_layer_key] += 1 - # Ensure reference weights are copied from state dict - assert all(x > 0 for x in copy_counter.values()) - - model_state_dct.update(dct_temp) - - for model_part in model_list: - missing, _ = model_part.load_state_dict(model_state_dct, strict=False) - if hasattr(model_part, 'eval_type_map'): - setattr(model_part, 'eval_type_map', False) - # Ensure all values are inserted - assert len(missing) == 0, missing - - if modal: - model_list[0].prepare_modal_deploy(modal) - elif model_list[0].modal_map is not None: - raise ValueError( - f'Modal is not given. 
It has: {list(model_list[0].modal_map.keys())}' - ) - - # prepare some extra information for MD - md_configs = {} - type_map = config[KEY.TYPE_MAP] - - chem_list = '' - for Z in type_map.keys(): - chem_list += chemical_symbols[Z] + ' ' - chem_list.strip() - - comm_size = max( - [ - seg._modules[f'{t}_convolution']._comm_size # type: ignore - for t, seg in enumerate(model_list) - ] - ) - - md_configs.update({'chemical_symbols_to_index': chem_list}) - md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) - md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) - md_configs.update({'comm_size': str(comm_size)}) - md_configs.update( - {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} - ) - md_configs.update({'version': __version__}) - md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) - md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - - os.makedirs(fname) - for idx, model in enumerate(model_list): - fname_full = f'{fname}/deployed_parallel_{idx}.pt' - model.set_is_batch_data(False) - model.eval() - - model = e3nn.util.jit.script(model) - model = torch.jit.freeze(model) - - torch.jit.save(model, fname_full, _extra_files=md_configs) +import os +from datetime import datetime +from typing import Optional + +import e3nn.util.jit +import torch +import torch.nn +from ase.data import chemical_symbols + +import sevenn._keys as KEY +from sevenn import __version__ +from sevenn.model_build import build_E3_equivariant_model +from sevenn.util import load_checkpoint + + +def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None): + """ + This method is messy to avoid changes in pair_e3gnn.cpp, while + refactoring python part. + If changes the behavior, and accordingly pair_e3gnn.cpp, + we have to recompile LAMMPS (which I always want to procrastinate) + """ + from sevenn.nn.edge_embedding import EdgePreprocess + from sevenn.nn.force_output import ForceStressOutput + + cp = load_checkpoint(checkpoint) + model, config = cp.build_model('e3nn'), cp.config + + model.prepand_module('edge_preprocess', EdgePreprocess(True)) + grad_module = ForceStressOutput() + model.replace_module('force_output', grad_module) + new_grad_key = grad_module.get_grad_key() + model.key_grad = new_grad_key + if hasattr(model, 'eval_type_map'): + setattr(model, 'eval_type_map', False) + + if modal: + model.prepare_modal_deploy(modal) + elif model.modal_map is not None and len(model.modal_map) >= 1: + raise ValueError( + f'Modal is not given. 
It has: {list(model.modal_map.keys())}' + ) + + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + # make some config need for md + md_configs = {} + type_map = config[KEY.TYPE_MAP] + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list.strip() + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + + if fname.endswith('.pt') is False: + fname += '.pt' + torch.jit.save(model, fname, _extra_files=md_configs) + + +# TODO: build model only once +def deploy_parallel( + checkpoint, fname='deployed_parallel', modal: Optional[str] = None +): + # Additional layer for ghost atom (and copy parameters from original) + GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1'] + + cp = load_checkpoint(checkpoint) + model, config = cp.build_model('e3nn'), cp.config + config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False} + model_state_dct = model.state_dict() + + model_list = build_E3_equivariant_model(config, parallel=True) + dct_temp = {} + copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS} + for ghost_layer_key in GHOST_LAYERS_KEYS: + for key, val in model_state_dct.items(): + if not key.startswith(ghost_layer_key): + continue + dct_temp.update({f'ghost_{key}': val}) + copy_counter[ghost_layer_key] += 1 + # Ensure reference weights are copied from state dict + assert all(x > 0 for x in copy_counter.values()) + + model_state_dct.update(dct_temp) + + for model_part in model_list: + missing, _ = model_part.load_state_dict(model_state_dct, strict=False) + if hasattr(model_part, 'eval_type_map'): + setattr(model_part, 'eval_type_map', False) + # Ensure all values are inserted + assert len(missing) == 0, missing + + if modal: + model_list[0].prepare_modal_deploy(modal) + elif model_list[0].modal_map is not None: + raise ValueError( + f'Modal is not given. 
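
Both `deploy` and `deploy_parallel` end in the same TorchScript recipe: script the model, freeze it, and save it with string metadata attached through `_extra_files`, which the LAMMPS pair style can read back at load time. A generic, self-contained sketch of that recipe; the toy model and metadata values are placeholders:

```python
# Script -> freeze -> save with metadata; the torch.jit calls are standard,
# only the Linear model and the metadata strings are made up.
import torch

model = torch.jit.script(torch.nn.Linear(4, 1).eval())
model = torch.jit.freeze(model)
extra = {'cutoff': '5.0', 'version': '0.0.0'}
torch.jit.save(model, 'deployed_example.pt', _extra_files=extra)

# Reading the metadata back mirrors what the C++ side does.
extra_back = {'cutoff': '', 'version': ''}
torch.jit.load('deployed_example.pt', _extra_files=extra_back)
print(extra_back['cutoff'])  # '5.0' (bytes on some torch versions)
```
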
It has: {list(model_list[0].modal_map.keys())}' + ) + + # prepare some extra information for MD + md_configs = {} + type_map = config[KEY.TYPE_MAP] + + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list.strip() + + comm_size = max( + [ + seg._modules[f'{t}_convolution']._comm_size # type: ignore + for t, seg in enumerate(model_list) + ] + ) + + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update({'comm_size': str(comm_size)}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + + os.makedirs(fname) + for idx, model in enumerate(model_list): + fname_full = f'{fname}/deployed_parallel_{idx}.pt' + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + torch.jit.save(model, fname_full, _extra_files=md_configs) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py index f0364750b094d70cf24aacd1efc0c79fb835a34f..af1b162c99b18efd00c567d2256eac91710d584d 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py @@ -1,119 +1,119 @@ -import os -from typing import List, Optional - -from sevenn.logger import Logger -from sevenn.train.dataset import AtomGraphDataset -from sevenn.util import unique_filepath - - -def build_sevennet_graph_dataset( - source: List[str], - cutoff: float, - num_cores: int, - out: str, - filename: str, - metadata: Optional[dict] = None, - **fmt_kwargs, -): - from sevenn.train.graph_dataset import SevenNetGraphDataset - - log = Logger() - if metadata is None: - metadata = {} - - log.timer_start('graph_build') - db = SevenNetGraphDataset( - cutoff=cutoff, - root=out, - files=source, - processed_name=filename, - process_num_cores=num_cores, - **fmt_kwargs, - ) - log.timer_end('graph_build', 'graph build time') - log.writeline(f'Graph saved: {db.processed_paths[0]}') - - log.bar() - for k, v in metadata.items(): - log.format_k_v(k, v, write=True) - log.bar() - - log.writeline('Distribution:') - log.statistic_write(db.statistics) - log.format_k_v('# atoms (node)', db.natoms, write=True) - log.format_k_v('# structures (graph)', len(db), write=True) - - -def dataset_finalize(dataset, metadata, out): - """ - Deprecated - """ - natoms = dataset.get_natoms() - species = dataset.get_species() - metadata = { - **metadata, - 'natoms': natoms, - 'species': species, - } - dataset.meta = metadata - - if os.path.isdir(out): - out = os.path.join(out, 'graph_built.sevenn_data') - elif out.endswith('.sevenn_data') is False: - out = out + '.sevenn_data' - out = unique_filepath(out) - - log = Logger() - log.writeline('The metadata of the dataset is...') - for k, v in metadata.items(): - log.format_k_v(k, v, write=True) - dataset.save(out) - log.writeline(f'dataset is saved to {out}') - - return dataset - - -def build_script( - source: List[str], - cutoff: float, - num_cores: int, - out: str, - metadata: Optional[dict] = None, - **fmt_kwargs, -): - """ - Deprecated - """ - from sevenn.train.dataload import file_to_dataset, match_reader - - if metadata is None: - 
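
`build_sevennet_graph_dataset` is a thin logging wrapper around `SevenNetGraphDataset`; the keyword names below come straight from its signature in this diff, while the import path, file paths, and numbers are assumptions:

```python
# Hypothetical invocation; adjust paths, cutoff, and core count to your data.
from sevenn.scripts.graph_build import build_sevennet_graph_dataset

build_sevennet_graph_dataset(
    source=['data/train.extxyz'],  # files readable by the dataset builder
    cutoff=5.0,                    # neighbor cutoff in Angstrom
    num_cores=4,
    out='dataset_out',
    filename='train_graph.pt',
)
```
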
metadata = {} - log = Logger() - - dataset = AtomGraphDataset({}, cutoff) - common_args = { - 'cutoff': cutoff, - 'cores': num_cores, - 'label': 'graph_build', - } - log.timer_start('graph_build') - for path in source: - if os.path.isdir(path): - continue - log.writeline(f'Read: {path}') - basename = os.path.basename(path) - if 'structure_list' in basename: - fmt = 'structure_list' - else: - fmt = 'ase' - reader, rmeta = match_reader(fmt, **fmt_kwargs) - metadata.update(**rmeta) - dataset.augment( - file_to_dataset( - file=path, - reader=reader, - **common_args, - ) - ) - log.timer_end('graph_build', 'graph build time') - dataset_finalize(dataset, metadata, out) +import os +from typing import List, Optional + +from sevenn.logger import Logger +from sevenn.train.dataset import AtomGraphDataset +from sevenn.util import unique_filepath + + +def build_sevennet_graph_dataset( + source: List[str], + cutoff: float, + num_cores: int, + out: str, + filename: str, + metadata: Optional[dict] = None, + **fmt_kwargs, +): + from sevenn.train.graph_dataset import SevenNetGraphDataset + + log = Logger() + if metadata is None: + metadata = {} + + log.timer_start('graph_build') + db = SevenNetGraphDataset( + cutoff=cutoff, + root=out, + files=source, + processed_name=filename, + process_num_cores=num_cores, + **fmt_kwargs, + ) + log.timer_end('graph_build', 'graph build time') + log.writeline(f'Graph saved: {db.processed_paths[0]}') + + log.bar() + for k, v in metadata.items(): + log.format_k_v(k, v, write=True) + log.bar() + + log.writeline('Distribution:') + log.statistic_write(db.statistics) + log.format_k_v('# atoms (node)', db.natoms, write=True) + log.format_k_v('# structures (graph)', len(db), write=True) + + +def dataset_finalize(dataset, metadata, out): + """ + Deprecated + """ + natoms = dataset.get_natoms() + species = dataset.get_species() + metadata = { + **metadata, + 'natoms': natoms, + 'species': species, + } + dataset.meta = metadata + + if os.path.isdir(out): + out = os.path.join(out, 'graph_built.sevenn_data') + elif out.endswith('.sevenn_data') is False: + out = out + '.sevenn_data' + out = unique_filepath(out) + + log = Logger() + log.writeline('The metadata of the dataset is...') + for k, v in metadata.items(): + log.format_k_v(k, v, write=True) + dataset.save(out) + log.writeline(f'dataset is saved to {out}') + + return dataset + + +def build_script( + source: List[str], + cutoff: float, + num_cores: int, + out: str, + metadata: Optional[dict] = None, + **fmt_kwargs, +): + """ + Deprecated + """ + from sevenn.train.dataload import file_to_dataset, match_reader + + if metadata is None: + metadata = {} + log = Logger() + + dataset = AtomGraphDataset({}, cutoff) + common_args = { + 'cutoff': cutoff, + 'cores': num_cores, + 'label': 'graph_build', + } + log.timer_start('graph_build') + for path in source: + if os.path.isdir(path): + continue + log.writeline(f'Read: {path}') + basename = os.path.basename(path) + if 'structure_list' in basename: + fmt = 'structure_list' + else: + fmt = 'ase' + reader, rmeta = match_reader(fmt, **fmt_kwargs) + metadata.update(**rmeta) + dataset.augment( + file_to_dataset( + file=path, + reader=reader, + **common_args, + ) + ) + log.timer_end('graph_build', 'graph build time') + dataset_finalize(dataset, metadata, out) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py index 355cc25593663d8fdfc9ea9892528a5b15a8f060..93c998f4150b18695f672b79ff56297f8b671d12 100644 --- 
a/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py @@ -1,227 +1,227 @@ -import csv -import os -from typing import Iterable, List, Optional, Union - -import numpy as np -from torch_geometric.loader import DataLoader -from tqdm import tqdm - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.graph_dataset import SevenNetGraphDataset -from sevenn.train.modal_dataset import SevenNetMultiModalDataset - - -def write_inference_csv(output_list, out): - for i, output in enumerate(output_list): - output = output.fit_dimension() - output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 - output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 - output_list[i] = output.to_numpy_dict() - - per_graph_keys = [ - KEY.NUM_ATOMS, - KEY.USER_LABEL, - KEY.ENERGY, - KEY.PRED_TOTAL_ENERGY, - KEY.STRESS, - KEY.PRED_STRESS, - ] - - per_atom_keys = [ - KEY.ATOMIC_NUMBERS, - KEY.ATOMIC_ENERGY, - KEY.POS, - KEY.FORCE, - KEY.PRED_FORCE, - ] - - def unfold_dct_val(dct, keys, suffix_list=None): - res = {} - if suffix_list is None: - suffix_list = range(100) - for k in keys: - if k not in dct: - res[k] = '-' - elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0: - res.update( - {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])} - ) - else: - res[k] = dct[k] - return res - - def per_atom_dct_list(dct, keys): - sfx_list = ['x', 'y', 'z'] - res = [] - natoms = dct[KEY.NUM_ATOMS] - extracted = {k: dct[k] for k in keys} - for i in range(natoms): - raw = {} - raw.update({k: v[i] for k, v in extracted.items()}) - per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) - res.append(per_atom_dct) - return res - - try: - with open(f'{out}/info.csv', 'w', newline='') as f: - header = output_list[0][KEY.INFO].keys() - writer = csv.DictWriter(f, fieldnames=header) - writer.writeheader() - for output in output_list: - writer.writerow(output[KEY.INFO]) - except (KeyError, TypeError, AttributeError, csv.Error) as e: - print(e) - print('failed to write meta data, info.csv is not written') - - with open(f'{out}/per_graph.csv', 'w', newline='') as f: - sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] # for stress - writer = None - for output in output_list: - cell_dct = {KEY.CELL: output[KEY.CELL]} - cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c']) - data = { - **unfold_dct_val(output, per_graph_keys, sfx_list), - **cell_dct, - } - if writer is None: - writer = csv.DictWriter(f, fieldnames=data.keys()) - writer.writeheader() - writer.writerow(data) - - with open(f'{out}/per_atom.csv', 'w', newline='') as f: - writer = None - for i, output in enumerate(output_list): - list_of_dct = per_atom_dct_list(output, per_atom_keys) - for j, dct in enumerate(list_of_dct): - idx_dct = {'stct_id': i, 'atom_id': j} - data = {**idx_dct, **dct} - if writer is None: - writer = csv.DictWriter(f, fieldnames=data.keys()) - writer.writeheader() - writer.writerow(data) - - -def _patch_data_info( - graph_list: Iterable[AtomGraphData], full_file_list: List[str] -) -> None: - keys = set() - for graph, path in zip(graph_list, full_file_list): - if KEY.INFO not in graph: - graph[KEY.INFO] = {} - graph[KEY.INFO].update({'file': os.path.abspath(path)}) - keys.update(graph[KEY.INFO].keys()) - - # save only safe subset of info (for batching) - for graph in graph_list: - info_dict = graph[KEY.INFO] - info_dict.update({k: '' for k in keys if k not in info_dict}) - - -def inference( - 
checkpoint: str, - targets: Union[str, List[str]], - output_dir: str, - num_workers: int = 1, - device: str = 'cpu', - batch_size: int = 4, - save_graph: bool = False, - allow_unlabeled: bool = False, - modal: Optional[str] = None, - **data_kwargs, -) -> None: - """ - Inference model on the target dataset, writes - per_graph, per_atom inference results in csv format - to the output_dir - If a given target doesn't have EFS key, it puts dummy - values. - - Args: - checkpoint: model checkpoint path, - target: path, or list of path to evaluate. Supports - ASE readable, sevenn_data/*.pt, .sevenn_data, and - structure_list - output_dir: directory to write results - num_workers: number of workers to build graph - device: device to evaluate, defaults to 'auto' - batch_size: batch size for inference - save_grpah: if True, save preprocessed graph to output dir - data_kwargs: keyword arguments used when reading targets, - for example, given index='-1', only the last snapshot - will be evaluated if it was ASE readable. - While this function can handle different types of targets - at once, it will not work smoothly with data_kwargs - - """ - model, _ = util.model_from_checkpoint(checkpoint) - cutoff = model.cutoff - - if modal: - if model.modal_map is None: - raise ValueError('Modality given, but model has no modal_map') - if modal not in model.modal_map: - _modals = list(model.modal_map.keys()) - raise ValueError(f'Unknown modal {modal} (not in {_modals})') - - if isinstance(targets, str): - targets = [targets] - - full_file_list = [] - if save_graph: - dataset = SevenNetGraphDataset( - cutoff=cutoff, - root=output_dir, - files=targets, - process_num_cores=num_workers, - processed_name='saved_graph.pt', - **data_kwargs, - ) - full_file_list = dataset.full_file_list # TODO: not used currently - else: - dataset = [] - for file in targets: - tmplist = SevenNetGraphDataset.file_to_graph_list( - file, - cutoff=cutoff, - num_cores=num_workers, - allow_unlabeled=allow_unlabeled, - **data_kwargs, - ) - dataset.extend(tmplist) - full_file_list.extend([os.path.abspath(file)] * len(tmplist)) - if ( - full_file_list is not None - and len(full_file_list) == len(dataset) - and not isinstance(dataset, SevenNetGraphDataset) - ): - _patch_data_info(dataset, full_file_list) # type: ignore - - if modal: - dataset = SevenNetMultiModalDataset({modal: dataset}) # type: ignore - - loader = DataLoader(dataset, batch_size, shuffle=False) # type: ignore - - model.to(device) - model.set_is_batch_data(True) - model.eval() - - rec = util.get_error_recorder() - output_list = [] - - for batch in tqdm(loader): - batch = batch.to(device) - output = model(batch).detach().cpu() - rec.update(output) - output_list.extend(util.to_atom_graph_list(output)) - - errors = rec.epoch_forward() - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f: - for key, val in errors.items(): - f.write(f'{key}: {val}\n') - - write_inference_csv(output_list, output_dir) +import csv +import os +from typing import Iterable, List, Optional, Union + +import numpy as np +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.graph_dataset import SevenNetGraphDataset +from sevenn.train.modal_dataset import SevenNetMultiModalDataset + + +def write_inference_csv(output_list, out): + for i, output in enumerate(output_list): + output = 
output.fit_dimension() + output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 + output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 + output_list[i] = output.to_numpy_dict() + + per_graph_keys = [ + KEY.NUM_ATOMS, + KEY.USER_LABEL, + KEY.ENERGY, + KEY.PRED_TOTAL_ENERGY, + KEY.STRESS, + KEY.PRED_STRESS, + ] + + per_atom_keys = [ + KEY.ATOMIC_NUMBERS, + KEY.ATOMIC_ENERGY, + KEY.POS, + KEY.FORCE, + KEY.PRED_FORCE, + ] + + def unfold_dct_val(dct, keys, suffix_list=None): + res = {} + if suffix_list is None: + suffix_list = range(100) + for k in keys: + if k not in dct: + res[k] = '-' + elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0: + res.update( + {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])} + ) + else: + res[k] = dct[k] + return res + + def per_atom_dct_list(dct, keys): + sfx_list = ['x', 'y', 'z'] + res = [] + natoms = dct[KEY.NUM_ATOMS] + extracted = {k: dct[k] for k in keys} + for i in range(natoms): + raw = {} + raw.update({k: v[i] for k, v in extracted.items()}) + per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) + res.append(per_atom_dct) + return res + + try: + with open(f'{out}/info.csv', 'w', newline='') as f: + header = output_list[0][KEY.INFO].keys() + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + for output in output_list: + writer.writerow(output[KEY.INFO]) + except (KeyError, TypeError, AttributeError, csv.Error) as e: + print(e) + print('failed to write meta data, info.csv is not written') + + with open(f'{out}/per_graph.csv', 'w', newline='') as f: + sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] # for stress + writer = None + for output in output_list: + cell_dct = {KEY.CELL: output[KEY.CELL]} + cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c']) + data = { + **unfold_dct_val(output, per_graph_keys, sfx_list), + **cell_dct, + } + if writer is None: + writer = csv.DictWriter(f, fieldnames=data.keys()) + writer.writeheader() + writer.writerow(data) + + with open(f'{out}/per_atom.csv', 'w', newline='') as f: + writer = None + for i, output in enumerate(output_list): + list_of_dct = per_atom_dct_list(output, per_atom_keys) + for j, dct in enumerate(list_of_dct): + idx_dct = {'stct_id': i, 'atom_id': j} + data = {**idx_dct, **dct} + if writer is None: + writer = csv.DictWriter(f, fieldnames=data.keys()) + writer.writeheader() + writer.writerow(data) + + +def _patch_data_info( + graph_list: Iterable[AtomGraphData], full_file_list: List[str] +) -> None: + keys = set() + for graph, path in zip(graph_list, full_file_list): + if KEY.INFO not in graph: + graph[KEY.INFO] = {} + graph[KEY.INFO].update({'file': os.path.abspath(path)}) + keys.update(graph[KEY.INFO].keys()) + + # save only safe subset of info (for batching) + for graph in graph_list: + info_dict = graph[KEY.INFO] + info_dict.update({k: '' for k in keys if k not in info_dict}) + + +def inference( + checkpoint: str, + targets: Union[str, List[str]], + output_dir: str, + num_workers: int = 1, + device: str = 'cpu', + batch_size: int = 4, + save_graph: bool = False, + allow_unlabeled: bool = False, + modal: Optional[str] = None, + **data_kwargs, +) -> None: + """ + Inference model on the target dataset, writes + per_graph, per_atom inference results in csv format + to the output_dir + If a given target doesn't have EFS key, it puts dummy + values. + + Args: + checkpoint: model checkpoint path, + target: path, or list of path to evaluate. 
Supports
+            ASE readable, sevenn_data/*.pt, .sevenn_data, and
+            structure_list
+        output_dir: directory to write results
+        num_workers: number of workers to build graph
+        device: device to evaluate, defaults to 'cpu'
+        batch_size: batch size for inference
+        save_graph: if True, save preprocessed graph to output dir
+        data_kwargs: keyword arguments used when reading targets,
+            for example, given index='-1', only the last snapshot
+            will be evaluated if it was ASE readable.
+            While this function can handle different types of targets
+            at once, it will not work smoothly with data_kwargs
+
+    """
+    model, _ = util.model_from_checkpoint(checkpoint)
+    cutoff = model.cutoff
+
+    if modal:
+        if model.modal_map is None:
+            raise ValueError('Modality given, but model has no modal_map')
+        if modal not in model.modal_map:
+            _modals = list(model.modal_map.keys())
+            raise ValueError(f'Unknown modal {modal} (not in {_modals})')
+
+    if isinstance(targets, str):
+        targets = [targets]
+
+    full_file_list = []
+    if save_graph:
+        dataset = SevenNetGraphDataset(
+            cutoff=cutoff,
+            root=output_dir,
+            files=targets,
+            process_num_cores=num_workers,
+            processed_name='saved_graph.pt',
+            **data_kwargs,
+        )
+        full_file_list = dataset.full_file_list  # TODO: not used currently
+    else:
+        dataset = []
+        for file in targets:
+            tmplist = SevenNetGraphDataset.file_to_graph_list(
+                file,
+                cutoff=cutoff,
+                num_cores=num_workers,
+                allow_unlabeled=allow_unlabeled,
+                **data_kwargs,
+            )
+            dataset.extend(tmplist)
+            full_file_list.extend([os.path.abspath(file)] * len(tmplist))
+        if (
+            full_file_list is not None
+            and len(full_file_list) == len(dataset)
+            and not isinstance(dataset, SevenNetGraphDataset)
+        ):
+            _patch_data_info(dataset, full_file_list)  # type: ignore
+
+    if modal:
+        dataset = SevenNetMultiModalDataset({modal: dataset})  # type: ignore
+
+    loader = DataLoader(dataset, batch_size, shuffle=False)  # type: ignore
+
+    model.to(device)
+    model.set_is_batch_data(True)
+    model.eval()
+
+    rec = util.get_error_recorder()
+    output_list = []
+
+    for batch in tqdm(loader):
+        batch = batch.to(device)
+        output = model(batch).detach().cpu()
+        rec.update(output)
+        output_list.extend(util.to_atom_graph_list(output))
+
+    errors = rec.epoch_forward()
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f:
+        for key, val in errors.items():
+            f.write(f'{key}: {val}\n')
+
+    write_inference_csv(output_list, output_dir)
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py
index 8ad9db19a14904bdb0cdac600a180b42864b5f1f..2537684eea413c016596505d32af1c2f35b723d6 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py
@@ -1,273 +1,273 @@
-import os
-import warnings
-
-import torch
-
-import sevenn._keys as KEY
-import sevenn.util as util
-from sevenn.logger import Logger
-from sevenn.scripts.convert_model_modality import (
-    append_modality_to_model_dct,
-    get_single_modal_model_dct,
-)
-
-
-def processing_continue_v2(config):  # simpler
-    """
-    Replacement of processing_continue,
-    Skips model compatibility
-    """
-    log = Logger()
-    continue_dct = config[KEY.CONTINUE]
-    log.write('\nContinue found, loading checkpoint\n')
-
-    checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT])
-    model_cp = checkpoint.build_model()
-    config_cp = checkpoint.config
-
model_state_dict_cp = model_cp.state_dict() - - optimizer_state_dict_cp = ( - checkpoint.optimizer_state_dict - if not continue_dct[KEY.RESET_OPTIMIZER] - else None - ) - scheduler_state_dict_cp = ( - checkpoint.scheduler_state_dict - if not continue_dct[KEY.RESET_SCHEDULER] - else None - ) - - # use_statistic_value_of_checkpoint always True - # Overwrite config from model state dict, so graph_dataset.from_config - # will not put statistic values to shift, scale, and conv_denominator - config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist() - config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist() - conv_denom = [] - for i in range(config_cp[KEY.NUM_CONVOLUTION]): - conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item()) - config[KEY.CONV_DENOMINATOR] = conv_denom - log.writeline( - f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are ' - + 'overwritten by model_state_dict of checkpoint' - ) - - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - config.update({k: config_cp[k] for k in chem_keys}) - log.writeline( - 'chemical_species are overwritten by checkpoint. ' - + f'This model knows {config[KEY.NUM_SPECIES]} species' - ) - - if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY): - raise ValueError('use_modality is not same. Check sevenn_cp') - - modal_map = config_cp.get(KEY.MODAL_MAP, None) # dict | None - if modal_map and len(modal_map) > 0: - modalities = list(modal_map.keys()) - log.writeline(f'Multimodal model found: {modalities}') - log.writeline('use_modality: True') - config[KEY.USE_MODALITY] = True - - from_epoch = checkpoint.epoch or 0 - log.writeline(f'Checkpoint previous epoch was: {from_epoch}') - epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1 - log.writeline(f'epoch start from {epoch}') - - log.writeline('checkpoint loading successful') - - state_dicts = [ - model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ] - return state_dicts, epoch - - -def check_config_compatible(config, config_cp): - # TODO: check more - SHOULD_BE_SAME = [ - KEY.NODE_FEATURE_MULTIPLICITY, - KEY.LMAX, - KEY.IS_PARITY, - KEY.RADIAL_BASIS, - KEY.CUTOFF_FUNCTION, - KEY.CUTOFF, - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS, - KEY.NUM_CONVOLUTION, - KEY.USE_BIAS_IN_LINEAR, - KEY.SELF_CONNECTION_TYPE, - ] - for sbs in SHOULD_BE_SAME: - if config[sbs] == config_cp[sbs]: - continue - if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE': - warnings.warn( - 'We do not support this version of checkpoints to continue ' - "Please use self_connection_type='linear' in input.yaml " - 'and train from scratch', - UserWarning, - ) - raise ValueError( - f'Value of {sbs} should be same. 
{config[sbs]} != {config_cp[sbs]}' - ) - - try: - cntdct = config[KEY.CONTINUE] - except KeyError: - return - - TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE] - if ( - any((not cntdct[KEY.RESET_SCHEDULER], not cntdct[KEY.RESET_OPTIMIZER])) - and all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS) is False - ): - raise ValueError( - 'reset optimizer and scheduler if you want to change ' - + 'trainable configs' - ) - - # TODO add conition for changed optim/scheduler but not reset - - -def processing_continue(config): - log = Logger() - continue_dct = config[KEY.CONTINUE] - log.write('\nContinue found, loading checkpoint\n') - - checkpoint = torch.load( - continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False - ) - config_cp = checkpoint['config'] - - model_cp, config_cp = util.model_from_checkpoint(checkpoint) - model_state_dict_cp = model_cp.state_dict() - - # it will raise error if not compatible - check_config_compatible(config, config_cp) - log.write('Checkpoint config is compatible\n') - - # for backward compat. - config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]}) - - from_epoch = checkpoint['epoch'] - optimizer_state_dict_cp = ( - checkpoint['optimizer_state_dict'] - if not continue_dct[KEY.RESET_OPTIMIZER] - else None - ) - scheduler_state_dict_cp = ( - checkpoint['scheduler_state_dict'] - if not continue_dct[KEY.RESET_SCHEDULER] - else None - ) - - # These could be changed based on given continue_input.yaml - # ex) adapt to statistics of fine-tuning dataset - shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy() - del model_state_dict_cp['rescale_atomic_energy.shift'] - scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy() - del model_state_dict_cp['rescale_atomic_energy.scale'] - conv_denominators = [] - for i in range(config_cp[KEY.NUM_CONVOLUTION]): - conv_denominators.append( - (model_state_dict_cp[f'{i}_convolution.denominator']).item() - ) - del model_state_dict_cp[f'{i}_convolution.denominator'] - - # Further handled by processing_dataset.py - config.update({ - KEY.SHIFT + '_cp': shift_cp, - KEY.SCALE + '_cp': scale_cp, - KEY.CONV_DENOMINATOR + '_cp': conv_denominators, - }) - - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - config.update({k: config_cp[k] for k in chem_keys}) - - if ( - KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY] - ): # checkpoint model is multimodal - config.update({ - KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP], - KEY.USE_MODALITY + '_cp': True, - KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]), - }) - else: - config.update({ - KEY.MODAL_MAP + '_cp': {}, - KEY.USE_MODALITY + '_cp': False, - KEY.NUM_MODALITIES + '_cp': 0, - }) - - log.write(f'checkpoint previous epoch was: {from_epoch}\n') - - # decide start epoch - reset_epoch = continue_dct[KEY.RESET_EPOCH] - if reset_epoch: - start_epoch = 1 - log.write('epoch reset to 1\n') - else: - start_epoch = from_epoch + 1 - log.write(f'epoch start from {start_epoch}\n') - - # decide csv file to continue - init_csv = True - csv_fname = config_cp[KEY.CSV_LOG] - if os.path.isfile(csv_fname): - # I hope python compare dict well - if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]: - log.writeline('Same metric, csv file will be appended') - init_csv = False - else: - log.writeline(f'{csv_fname} file not found, new csv file will be created') - log.writeline('checkpoint loading was successful') - - state_dicts = [ - 
model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ] - return state_dicts, start_epoch, init_csv - - -def convert_modality_of_checkpoint_state_dct(config, state_dicts): - # TODO: this requires updating model state dict after seeing dataset - model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = ( - state_dicts - ) - - if config[KEY.USE_MODALITY]: # current model is multimodal - num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp']) - append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp - - model_state_dict_cp = append_modality_to_model_dct( - model_state_dict_cp, config, num_modalities_cp, append_modal_length - ) - - else: # current model is single modal - if config[KEY.USE_MODALITY + '_cp']: # checkpoint model is multimodal - # change model state dict to single modal, default = "common" - model_state_dict_cp = get_single_modal_model_dct( - model_state_dict_cp, - config, - config[KEY.DEFAULT_MODAL], - from_processing_cp=True, - ) - - state_dicts = ( - model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ) - - return state_dicts +import os +import warnings + +import torch + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.logger import Logger +from sevenn.scripts.convert_model_modality import ( + append_modality_to_model_dct, + get_single_modal_model_dct, +) + + +def processing_continue_v2(config): # simpler + """ + Replacement of processing_continue, + Skips model compatibility + """ + log = Logger() + continue_dct = config[KEY.CONTINUE] + log.write('\nContinue found, loading checkpoint\n') + + checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT]) + model_cp = checkpoint.build_model() + config_cp = checkpoint.config + model_state_dict_cp = model_cp.state_dict() + + optimizer_state_dict_cp = ( + checkpoint.optimizer_state_dict + if not continue_dct[KEY.RESET_OPTIMIZER] + else None + ) + scheduler_state_dict_cp = ( + checkpoint.scheduler_state_dict + if not continue_dct[KEY.RESET_SCHEDULER] + else None + ) + + # use_statistic_value_of_checkpoint always True + # Overwrite config from model state dict, so graph_dataset.from_config + # will not put statistic values to shift, scale, and conv_denominator + config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist() + config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist() + conv_denom = [] + for i in range(config_cp[KEY.NUM_CONVOLUTION]): + conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item()) + config[KEY.CONV_DENOMINATOR] = conv_denom + log.writeline( + f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are ' + + 'overwritten by model_state_dict of checkpoint' + ) + + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + config.update({k: config_cp[k] for k in chem_keys}) + log.writeline( + 'chemical_species are overwritten by checkpoint. ' + + f'This model knows {config[KEY.NUM_SPECIES]} species' + ) + + if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY): + raise ValueError('use_modality is not same. 
Check sevenn_cp')
+
+    modal_map = config_cp.get(KEY.MODAL_MAP, None)  # dict | None
+    if modal_map and len(modal_map) > 0:
+        modalities = list(modal_map.keys())
+        log.writeline(f'Multimodal model found: {modalities}')
+        log.writeline('use_modality: True')
+        config[KEY.USE_MODALITY] = True
+
+    from_epoch = checkpoint.epoch or 0
+    log.writeline(f'Checkpoint previous epoch was: {from_epoch}')
+    epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1
+    log.writeline(f'epoch starts from {epoch}')
+
+    log.writeline('checkpoint loading successful')
+
+    state_dicts = [
+        model_state_dict_cp,
+        optimizer_state_dict_cp,
+        scheduler_state_dict_cp,
+    ]
+    return state_dicts, epoch
+
+
+def check_config_compatible(config, config_cp):
+    # TODO: check more
+    SHOULD_BE_SAME = [
+        KEY.NODE_FEATURE_MULTIPLICITY,
+        KEY.LMAX,
+        KEY.IS_PARITY,
+        KEY.RADIAL_BASIS,
+        KEY.CUTOFF_FUNCTION,
+        KEY.CUTOFF,
+        KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS,
+        KEY.NUM_CONVOLUTION,
+        KEY.USE_BIAS_IN_LINEAR,
+        KEY.SELF_CONNECTION_TYPE,
+    ]
+    for sbs in SHOULD_BE_SAME:
+        if config[sbs] == config_cp[sbs]:
+            continue
+        if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE':
+            warnings.warn(
+                'We do not support this version of checkpoints to continue. '
+                "Please use self_connection_type='linear' in input.yaml "
+                'and train from scratch',
+                UserWarning,
+            )
+        raise ValueError(
+            f'Value of {sbs} should be the same. {config[sbs]} != {config_cp[sbs]}'
+        )
+
+    try:
+        cntdct = config[KEY.CONTINUE]
+    except KeyError:
+        return
+
+    TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE]
+    if (
+        any((not cntdct[KEY.RESET_SCHEDULER], not cntdct[KEY.RESET_OPTIMIZER]))
+        and not all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS)
+    ):
+        raise ValueError(
+            'reset optimizer and scheduler if you want to change '
+            + 'trainable configs'
+        )
+
+    # TODO: add condition for changed optim/scheduler but not reset
+
+
+def processing_continue(config):
+    log = Logger()
+    continue_dct = config[KEY.CONTINUE]
+    log.write('\nContinue found, loading checkpoint\n')
+
+    checkpoint = torch.load(
+        continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False
+    )
+    config_cp = checkpoint['config']
+
+    model_cp, config_cp = util.model_from_checkpoint(checkpoint)
+    model_state_dict_cp = model_cp.state_dict()
+
+    # raises an error if not compatible
+    check_config_compatible(config, config_cp)
+    log.write('Checkpoint config is compatible\n')
+
+    # for backward compat.
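
The two continuation paths above share one resume pattern: rebuild the model from the checkpoint, then restore optimizer and scheduler state unless a reset was requested. A minimal sketch of that pattern in plain PyTorch, assuming a checkpoint dict with `model_state_dict`, `optimizer_state_dict`, `scheduler_state_dict`, and `epoch` entries (the last three are the keys `processing_continue` reads below; `model_state_dict` is an assumed name, not taken from this file):

```python
import torch

def resume(path, model, optimizer, scheduler,
           reset_optimizer=False, reset_scheduler=False):
    # Load the whole checkpoint dict on CPU, as processing_continue() does.
    cp = torch.load(path, map_location='cpu', weights_only=False)
    model.load_state_dict(cp['model_state_dict'])  # assumed key name
    if not reset_optimizer and cp.get('optimizer_state_dict') is not None:
        optimizer.load_state_dict(cp['optimizer_state_dict'])
    if not reset_scheduler and cp.get('scheduler_state_dict') is not None:
        scheduler.load_state_dict(cp['scheduler_state_dict'])
    # Start after the recorded epoch, or from 1 when resetting/absent.
    return (cp.get('epoch') or 0) + 1
```
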
+ config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]}) + + from_epoch = checkpoint['epoch'] + optimizer_state_dict_cp = ( + checkpoint['optimizer_state_dict'] + if not continue_dct[KEY.RESET_OPTIMIZER] + else None + ) + scheduler_state_dict_cp = ( + checkpoint['scheduler_state_dict'] + if not continue_dct[KEY.RESET_SCHEDULER] + else None + ) + + # These could be changed based on given continue_input.yaml + # ex) adapt to statistics of fine-tuning dataset + shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy() + del model_state_dict_cp['rescale_atomic_energy.shift'] + scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy() + del model_state_dict_cp['rescale_atomic_energy.scale'] + conv_denominators = [] + for i in range(config_cp[KEY.NUM_CONVOLUTION]): + conv_denominators.append( + (model_state_dict_cp[f'{i}_convolution.denominator']).item() + ) + del model_state_dict_cp[f'{i}_convolution.denominator'] + + # Further handled by processing_dataset.py + config.update({ + KEY.SHIFT + '_cp': shift_cp, + KEY.SCALE + '_cp': scale_cp, + KEY.CONV_DENOMINATOR + '_cp': conv_denominators, + }) + + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + config.update({k: config_cp[k] for k in chem_keys}) + + if ( + KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY] + ): # checkpoint model is multimodal + config.update({ + KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP], + KEY.USE_MODALITY + '_cp': True, + KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]), + }) + else: + config.update({ + KEY.MODAL_MAP + '_cp': {}, + KEY.USE_MODALITY + '_cp': False, + KEY.NUM_MODALITIES + '_cp': 0, + }) + + log.write(f'checkpoint previous epoch was: {from_epoch}\n') + + # decide start epoch + reset_epoch = continue_dct[KEY.RESET_EPOCH] + if reset_epoch: + start_epoch = 1 + log.write('epoch reset to 1\n') + else: + start_epoch = from_epoch + 1 + log.write(f'epoch start from {start_epoch}\n') + + # decide csv file to continue + init_csv = True + csv_fname = config_cp[KEY.CSV_LOG] + if os.path.isfile(csv_fname): + # I hope python compare dict well + if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]: + log.writeline('Same metric, csv file will be appended') + init_csv = False + else: + log.writeline(f'{csv_fname} file not found, new csv file will be created') + log.writeline('checkpoint loading was successful') + + state_dicts = [ + model_state_dict_cp, + optimizer_state_dict_cp, + scheduler_state_dict_cp, + ] + return state_dicts, start_epoch, init_csv + + +def convert_modality_of_checkpoint_state_dct(config, state_dicts): + # TODO: this requires updating model state dict after seeing dataset + model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = ( + state_dicts + ) + + if config[KEY.USE_MODALITY]: # current model is multimodal + num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp']) + append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp + + model_state_dict_cp = append_modality_to_model_dct( + model_state_dict_cp, config, num_modalities_cp, append_modal_length + ) + + else: # current model is single modal + if config[KEY.USE_MODALITY + '_cp']: # checkpoint model is multimodal + # change model state dict to single modal, default = "common" + model_state_dict_cp = get_single_modal_model_dct( + model_state_dict_cp, + config, + config[KEY.DEFAULT_MODAL], + from_processing_cp=True, + ) + + state_dicts = ( + model_state_dict_cp, + 
optimizer_state_dict_cp, + scheduler_state_dict_cp, + ) + + return state_dicts diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py index fe1e1e432c8fff94589c8b93f53b549aecf966a5..64e79617664a092d6e1ca6e06fe10cd12f973572 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py @@ -1,481 +1,481 @@ -import os - -import torch -import torch.distributed as dist - -import sevenn._const as CONST -import sevenn._keys as KEY -from sevenn.logger import Logger -from sevenn.train.dataload import file_to_dataset, match_reader -from sevenn.train.dataset import AtomGraphDataset -from sevenn.util import chemical_species_preprocess, onehot_to_chem - - -def dataset_load(file: str, config): - """ - Wrapping of dataload.file_to_dataset to suppert - graph prebuilt sevenn_data - """ - log = Logger() - log.write(f'Loading {file}\n') - log.timer_start('loading dataset') - - if file.endswith('.sevenn_data'): - dataset = torch.load(file, map_location='cpu', weights_only=False) - else: - reader, _ = match_reader( - config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS] - ) - dataset = file_to_dataset( - file, - config[KEY.CUTOFF], - config[KEY.PREPROCESS_NUM_CORES], - reader=reader, - use_modality=config[KEY.USE_MODALITY], - use_weight=config[KEY.USE_WEIGHT], - ) - log.format_k_v('loaded dataset size is', dataset.len(), write=True) - log.timer_end('loading dataset', 'data set loading time') - return dataset - - -def calculate_shift_or_scale_from_key( - train_set: AtomGraphDataset, key_given, n_chem -): - _expand = True - use_species_wise_shift_scale = False - if key_given == 'per_atom_energy_mean': - shift_or_scale = train_set.get_per_atom_energy_mean() - elif key_given == 'elemwise_reference_energies': - shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem) - _expand = False - use_species_wise_shift_scale = True - - elif key_given == 'force_rms': - shift_or_scale = train_set.get_force_rms() - elif key_given == 'per_atom_energy_std': - shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][ - 'std' - ] - elif key_given == 'elemwise_force_rms': - shift_or_scale = train_set.get_species_wise_force_rms(n_chem) - _expand = False - use_species_wise_shift_scale = True - - return shift_or_scale, _expand, use_species_wise_shift_scale - - -def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given): - """ - Priority (first comes later to overwrite): - 1. Float given in yaml - 2. Use statistic values of checkpoint == True - 3. 
Plain options (provided as string) - """ - log = Logger() - shift, scale, conv_denominator = None, None, None - type_map = config[KEY.TYPE_MAP] - n_chem = len(type_map) - chem_strs = onehot_to_chem(list(range(n_chem)), type_map) - - log.writeline('\nCalculating statistic values from dataset') - - shift_given = config[KEY.SHIFT] - scale_given = config[KEY.SCALE] - _expand_shift = True - _expand_scale = True - use_species_wise_shift = False - use_species_wise_scale = False - - use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT] - use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE] - - if shift_given in CONST.IMPLEMENTED_SHIFT: - shift, _expand_shift, use_species_wise_shift = ( - calculate_shift_or_scale_from_key(train_set, shift_given, n_chem) - ) - - if scale_given in CONST.IMPLEMENTED_SCALE: - scale, _expand_scale, use_species_wise_scale = ( - calculate_shift_or_scale_from_key(train_set, scale_given, n_chem) - ) - - if use_modal_wise_shift or use_modal_wise_scale: - atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality() - modal_map = config[KEY.MODAL_MAP] - n_modal = len(modal_map) - cutoff = config[KEY.CUTOFF] - - if use_modal_wise_shift: - shift = torch.zeros((n_modal, n_chem)) - - if use_modal_wise_scale: - scale = torch.zeros((n_modal, n_chem)) - - for modal_key, data_list in atomdata_dict_sort_by_modal.items(): - modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True) - - if use_modal_wise_shift: - if shift_given == 'elemwise_reference_energies': - modal_shift, _expand_shift, use_species_wise_shift = ( - calculate_shift_or_scale_from_key( - modal_set, shift_given, n_chem - ) - ) - shift[modal_map[modal_key]] = torch.tensor( - modal_shift - ) # this is np.array - elif shift_given in CONST.IMPLEMENTED_SHIFT: - raise NotImplementedError( - 'Currently, modal-wise shift implemented for' - 'species-dependent case only.' - ) - - if use_modal_wise_scale: - if scale_given == 'elemwise_force_rms': - modal_scale, _expand_scale, use_species_wise_scale = ( - calculate_shift_or_scale_from_key( - modal_set, scale_given, n_chem - ) - ) - scale[modal_map[modal_key]] = modal_scale - elif scale_given in CONST.IMPLEMENTED_SCALE: - raise NotImplementedError( - 'Currently, modal-wise scale implemented for' - 'species-dependent case only.' 
- ) - - avg_num_neigh = train_set.get_avg_num_neigh() - log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True) - - if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh': - conv_denominator = avg_num_neigh - elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh': - conv_denominator = avg_num_neigh ** (0.5) - - if ( - checkpoint_given - and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT] - ): - log.writeline( - 'Overwrite shift, scale, conv_denominator from model checkpoint' - ) - # TODO: This needs refactoring - conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp'] - if not (use_modal_wise_shift or use_modal_wise_scale): - # Values extracted from checkpoint in processing_continue.py - if len(list(shift)) > 1: - use_species_wise_shift = True - use_species_wise_scale = True - _expand_shift = _expand_scale = False - else: - shift = shift.item() - scale = scale.item() - else: - # Case of modal wise shift scale - shift_cp = config[KEY.SHIFT + '_cp'] - scale_cp = config[KEY.SCALE + '_cp'] - if not use_modal_wise_shift: - shift = shift_cp - if not use_modal_wise_scale: - scale = scale_cp - modal_map = config[KEY.MODAL_MAP] - modal_map_cp = config[KEY.MODAL_MAP + '_cp'] - - # Extracting shift, scale for modal in checkpoint model. - if config[KEY.USE_MODALITY + '_cp']: # cp model is multimodal - for modal_key_cp, modal_idx_cp in modal_map_cp.items(): - modal_idx = modal_map[modal_key_cp] - if use_modal_wise_shift: - shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp]) - if use_modal_wise_scale: - scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp]) - - else: # cp model is single modal - try: - modal_idx = modal_map[config[KEY.DEFAULT_MODAL]] - except: - raise KeyError( - f'{config[KEY.DEFAULT_MODAL]} should be one of' - f' {modal_map.keys()}' - ) - if use_modal_wise_shift: - shift[modal_idx] = torch.tensor(shift_cp) - if use_modal_wise_scale: - scale[modal_idx] = torch.tensor(scale_cp) - - if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]: - # Also overwrite values of new modal to reference value - # For multimodal, set reference modal with KEY.DEFAULT_MODAL - shift_ref = shift_cp - scale_ref = scale_cp - if config[KEY.USE_MODALITY + '_cp']: - try: - modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]] - except: - raise KeyError( - f'{config[KEY.DEFAULT_MODAL]} should be one of' - f' {modal_map_cp.keys()}' - ) - shift_ref = shift_cp[modal_idx_cp] - scale_ref = scale_cp[modal_idx_cp] - - for modal_key, modal_idx in modal_map.items(): - if modal_key not in modal_map_cp.keys(): - if use_modal_wise_shift: - shift[modal_idx] = shift_ref - if use_modal_wise_scale: - scale[modal_idx] = scale_ref - - # overwrite shift scale anyway if defined in yaml. 
- if type(shift_given) in [list, float]: - log.writeline('Overwrite shift to value(s) given in yaml') - _expand_shift = isinstance(shift_given, float) - shift = shift_given - if type(scale_given) in [list, float]: - log.writeline('Overwrite scale to value(s) given in yaml') - _expand_scale = isinstance(scale_given, float) - scale = scale_given - - if isinstance(config[KEY.CONV_DENOMINATOR], float): - log.writeline('Overwrite conv_denominator to value given in yaml') - conv_denominator = config[KEY.CONV_DENOMINATOR] - - if isinstance(conv_denominator, float): - conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION] - - use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale - if use_species_wise_shift_scale: - chem_strs = onehot_to_chem(list(range(n_chem)), type_map) - if _expand_shift: - if use_modal_wise_shift: - shift = torch.full((n_modal, n_chem), shift) - else: - shift = [shift] * n_chem - if _expand_scale: - if use_modal_wise_scale: - scale = torch.full((n_modal, n_chem), scale) - else: - scale = [scale] * n_chem - - Logger().write('Use element-wise shift, scale\n') - if use_modal_wise_shift or use_modal_wise_scale: - for modal_key, modal_idx in modal_map.items(): - Logger().writeline(f'For modal = {modal_key}') - print_shift = shift[modal_idx] if use_modal_wise_shift else shift - print_scale = scale[modal_idx] if use_modal_wise_scale else scale - for cstr, sh, sc in zip(chem_strs, print_shift, print_scale): - Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) - else: - for cstr, sh, sc in zip(chem_strs, shift, scale): - Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) - else: - log.write('Use global shift, scale\n') - log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True) - - assert isinstance(conv_denominator, list) and all( - isinstance(deno, float) for deno in conv_denominator - ) - log.format_k_v( - '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True - ) - - config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale - return shift, scale, conv_denominator - - -# TODO: This is too long -def processing_dataset(config, working_dir): - log = Logger() - prefix = f'{os.path.abspath(working_dir)}/' - is_stress = config[KEY.IS_TRAIN_STRESS] - checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False - cutoff = config[KEY.CUTOFF] - - log.write('\nInitializing dataset...\n') - - dataset = AtomGraphDataset({}, cutoff) - load_dataset = config[KEY.LOAD_DATASET] - if type(load_dataset) is str: - load_dataset = [load_dataset] - for file in load_dataset: - dataset.augment(dataset_load(file, config)) - - dataset.group_by_key() # apply labels inside original datapoint - dataset.unify_dtypes() # unify dtypes of all data points - - # TODO: I think manual chemical species input is redundant - chem_in_db = dataset.get_species() - if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given: - log.writeline('Auto detect chemical species from dataset') - config.update(chemical_species_preprocess(chem_in_db)) - elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given: - pass # copied from checkpoint in processing_continue.py - elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given: - pass # processed in parse_input.py - else: # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given - log.writeline('Ignore chemical species in yaml, use checkpoint') - # already processed in processing_continue.py - - # basic dataset compatibility check with 
previous model - if checkpoint_given: - chem_from_cp = config[KEY.CHEMICAL_SPECIES] - if not all(chem in chem_from_cp for chem in chem_in_db): - raise ValueError('Chemical species in checkpoint is not compatible') - - # check what modalities are used in dataset - if config[KEY.USE_MODALITY]: - modalities = dataset.get_modalities() - num_modalities = len(modalities) - if num_modalities < 2: - Logger().writeline('Only one modal is given, ignore modality') - config.uptate({KEY.USE_MODALITY: False}) - - else: - modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {} - modal_map = modal_map_cp.copy() - current_idx = len(modal_map_cp) - for modal_key in modalities: - if modal_key not in modal_map.keys(): - modal_map[modal_key] = current_idx - current_idx += 1 - - if config[KEY.IS_DDP]: - # Synchronize modal_map - torch.cuda.set_device(config[KEY.LOCAL_RANK]) - modal_map_bcast = [modal_map] - dist.broadcast_object_list(modal_map_bcast, src=0) - modal_map = modal_map_bcast[0] - - config.update( - { - KEY.NUM_MODALITIES: len(modal_map), - KEY.MODAL_MAP: modal_map, - KEY.MODAL_LIST: list(modal_map.keys()), - } - ) - - dataset.write_modal_attr( - modal_map, - config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE], - ) - - # --------------- save dataset regardless of train/valid--------------# - save_dataset = config[KEY.SAVE_DATASET] - save_by_label = config[KEY.SAVE_BY_LABEL] - if save_dataset: - if save_dataset.endswith('.sevenn_data') is False: - save_dataset += '.sevenn_data' - if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False: - save_dataset = prefix + save_dataset # save_data set is plain file name - dataset.save(save_dataset) - log.format_k_v('Dataset saved to', save_dataset, write=True) - # log.write(f"Loaded full dataset saved to : {save_dataset}\n") - if save_by_label: - dataset.save(prefix, by_label=True) - log.format_k_v('Dataset saved by label', prefix, write=True) - # --------------------------------------------------------------------# - - # TODO: testset is not used - ignore_test = not config.get(KEY.USE_TESTSET, False) - if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]: - train_set = dataset - test_set = AtomGraphDataset([], config[KEY.CUTOFF]) - - log.write('Loading validset from load_validset\n') - valid_set = AtomGraphDataset({}, cutoff) - for file in config[KEY.LOAD_VALIDSET]: - valid_set.augment(dataset_load(file, config)) - valid_set.group_by_key() - valid_set.unify_dtypes() - - # condition: validset labels should be subset of trainset labels - valid_labels = valid_set.user_labels - train_labels = train_set.user_labels - if set(valid_labels).issubset(set(train_labels)) is False: - valid_set = AtomGraphDataset(valid_set.to_list(), cutoff) - valid_set.rewrite_labels_to_data() - train_set = AtomGraphDataset(train_set.to_list(), cutoff) - train_set.rewrite_labels_to_data() - Logger().write('WARNING! 
validset labels is not subset of trainset\n')
-            Logger().write('We overwrite all the train, valid labels to default.\n')
-            Logger().write('Please create validset by sevenn_graph_build with -l\n')
-
-        Logger().write('the validset loaded, load_dataset is now train_set\n')
-        Logger().write('the ratio will be ignored\n')
-
-        # condition: validset modalities should be subset of trainset modalities
-        if config[KEY.USE_MODALITY]:
-            config_modality = config[KEY.MODAL_LIST]
-            valid_modality = valid_set.get_modalities()
-
-            if set(valid_modality).issubset(set(config_modality)) is False:
-                raise ValueError('validset modality is not subset of trainset')
-
-            valid_set.write_modal_attr(
-                config[KEY.MODAL_MAP],
-                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
-            )
-    else:
-        train_set, valid_set, test_set = dataset.divide_dataset(
-            config[KEY.RATIO], ignore_test=ignore_test
-        )
-        log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n')
-
-    log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
-    log.format_k_v('\nloaded validset size is', valid_set.len(), write=True)
-
-    log.write('Dataset initialization was successful\n')
-
-    log.write('\nNumber of atoms in the train_set:\n')
-    log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))
-
-    log.bar()
-    log.write('Per atom energy(eV/atom) distribution:\n')
-    log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
-    log.bar()
-    log.write('Force(eV/Angstrom) distribution:\n')
-    log.statistic_write(train_set.get_statistics(KEY.FORCE))
-    log.bar()
-    log.write('Stress(eV/Angstrom^3) distribution:\n')
-    try:
-        log.statistic_write(train_set.get_statistics(KEY.STRESS))
-    except KeyError:
-        log.write('\n Stress is not included in the train_set\n')
-        if is_stress:
-            is_stress = False
-            log.write('Turn off stress training\n')
-    log.bar()
-
-    # saved data must have atomic numbers as X not one hot idx
-    if config[KEY.SAVE_BY_TRAIN_VALID]:
-        train_set.save(prefix + 'train')
-        valid_set.save(prefix + 'valid')
-        log.format_k_v('Dataset saved by train, valid', prefix, write=True)
-
-    # inconsistent .info dict give error when collate
-    _, _ = train_set.separate_info()
-    _, _ = valid_set.separate_info()
-
-    if train_set.x_is_one_hot_idx is False:
-        train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
-    if valid_set.x_is_one_hot_idx is False:
-        valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
-
-    log.format_k_v('training_set size', train_set.len(), write=True)
-    log.format_k_v('validation_set size', valid_set.len(), write=True)
-
-    shift, scale, conv_denominator = handle_shift_scale(
-        config, train_set, checkpoint_given
-    )
-    config.update(
-        {
-            KEY.SHIFT: shift,
-            KEY.SCALE: scale,
-            KEY.CONV_DENOMINATOR: conv_denominator,
-        }
-    )
-
-    data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())
-
-    return data_lists
+import os
+
+import torch
+import torch.distributed as dist
+
+import sevenn._const as CONST
+import sevenn._keys as KEY
+from sevenn.logger import Logger
+from sevenn.train.dataload import file_to_dataset, match_reader
+from sevenn.train.dataset import AtomGraphDataset
+from sevenn.util import chemical_species_preprocess, onehot_to_chem
+
+
+def dataset_load(file: str, config):
+    """
+    Wrapping of dataload.file_to_dataset to support
+    graph prebuilt sevenn_data
+    """
+    log = Logger()
+    log.write(f'Loading {file}\n')
+    log.timer_start('loading dataset')
+
+    if file.endswith('.sevenn_data'):
+        dataset = torch.load(file, map_location='cpu', weights_only=False)
+    else:
+        reader, _ = match_reader(
+            config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS]
+        )
+        dataset = file_to_dataset(
+            file,
+            config[KEY.CUTOFF],
+            config[KEY.PREPROCESS_NUM_CORES],
+            reader=reader,
+            use_modality=config[KEY.USE_MODALITY],
+            use_weight=config[KEY.USE_WEIGHT],
+        )
+    log.format_k_v('loaded dataset size is', dataset.len(), write=True)
+    log.timer_end('loading dataset', 'data set loading time')
+    return dataset
+
+
+def calculate_shift_or_scale_from_key(
+    train_set: AtomGraphDataset, key_given, n_chem
+):
+    _expand = True
+    use_species_wise_shift_scale = False
+    if key_given == 'per_atom_energy_mean':
+        shift_or_scale = train_set.get_per_atom_energy_mean()
+    elif key_given == 'elemwise_reference_energies':
+        shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem)
+        _expand = False
+        use_species_wise_shift_scale = True
+
+    elif key_given == 'force_rms':
+        shift_or_scale = train_set.get_force_rms()
+    elif key_given == 'per_atom_energy_std':
+        shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][
+            'std'
+        ]
+    elif key_given == 'elemwise_force_rms':
+        shift_or_scale = train_set.get_species_wise_force_rms(n_chem)
+        _expand = False
+        use_species_wise_shift_scale = True
+
+    return shift_or_scale, _expand, use_species_wise_shift_scale
+
+
+def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given):
+    """
+    Priority (first comes later to overwrite):
+    1. Float given in yaml
+    2. Use statistic values of checkpoint == True
+    3. Plain options (provided as string)
+    """
+    log = Logger()
+    shift, scale, conv_denominator = None, None, None
+    type_map = config[KEY.TYPE_MAP]
+    n_chem = len(type_map)
+    chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
+
+    log.writeline('\nCalculating statistic values from dataset')
+
+    shift_given = config[KEY.SHIFT]
+    scale_given = config[KEY.SCALE]
+    _expand_shift = True
+    _expand_scale = True
+    use_species_wise_shift = False
+    use_species_wise_scale = False
+
+    use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT]
+    use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE]
+
+    if shift_given in CONST.IMPLEMENTED_SHIFT:
+        shift, _expand_shift, use_species_wise_shift = (
+            calculate_shift_or_scale_from_key(train_set, shift_given, n_chem)
+        )
+
+    if scale_given in CONST.IMPLEMENTED_SCALE:
+        scale, _expand_scale, use_species_wise_scale = (
+            calculate_shift_or_scale_from_key(train_set, scale_given, n_chem)
+        )
+
+    if use_modal_wise_shift or use_modal_wise_scale:
+        atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality()
+        modal_map = config[KEY.MODAL_MAP]
+        n_modal = len(modal_map)
+        cutoff = config[KEY.CUTOFF]
+
+        if use_modal_wise_shift:
+            shift = torch.zeros((n_modal, n_chem))
+
+        if use_modal_wise_scale:
+            scale = torch.zeros((n_modal, n_chem))
+
+        for modal_key, data_list in atomdata_dict_sort_by_modal.items():
+            modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True)
+
+            if use_modal_wise_shift:
+                if shift_given == 'elemwise_reference_energies':
+                    modal_shift, _expand_shift, use_species_wise_shift = (
+                        calculate_shift_or_scale_from_key(
+                            modal_set, shift_given, n_chem
+                        )
+                    )
+                    shift[modal_map[modal_key]] = torch.tensor(
+                        modal_shift
+                    )  # this is np.array
+                elif shift_given in CONST.IMPLEMENTED_SHIFT:
+                    raise NotImplementedError(
+                        'Currently, modal-wise shift implemented for '
+                        'species-dependent case only.'
+                    )
+
+            if use_modal_wise_scale:
+                if scale_given == 'elemwise_force_rms':
+                    modal_scale, _expand_scale, use_species_wise_scale = (
+                        calculate_shift_or_scale_from_key(
+                            modal_set, scale_given, n_chem
+                        )
+                    )
+                    scale[modal_map[modal_key]] = modal_scale
+                elif scale_given in CONST.IMPLEMENTED_SCALE:
+                    raise NotImplementedError(
+                        'Currently, modal-wise scale implemented for '
+                        'species-dependent case only.'
+                    )
+
+    avg_num_neigh = train_set.get_avg_num_neigh()
+    log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True)
+
+    if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh':
+        conv_denominator = avg_num_neigh
+    elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh':
+        conv_denominator = avg_num_neigh ** (0.5)
+
+    if (
+        checkpoint_given
+        and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT]
+    ):
+        log.writeline(
+            'Overwrite shift, scale, conv_denominator from model checkpoint'
+        )
+        # TODO: This needs refactoring
+        conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp']
+        if not (use_modal_wise_shift or use_modal_wise_scale):
+            # Values extracted from checkpoint in processing_continue.py
+            if len(list(shift)) > 1:
+                use_species_wise_shift = True
+                use_species_wise_scale = True
+                _expand_shift = _expand_scale = False
+            else:
+                shift = shift.item()
+                scale = scale.item()
+        else:
+            # Case of modal wise shift scale
+            shift_cp = config[KEY.SHIFT + '_cp']
+            scale_cp = config[KEY.SCALE + '_cp']
+            if not use_modal_wise_shift:
+                shift = shift_cp
+            if not use_modal_wise_scale:
+                scale = scale_cp
+            modal_map = config[KEY.MODAL_MAP]
+            modal_map_cp = config[KEY.MODAL_MAP + '_cp']
+
+            # Extracting shift, scale for modal in checkpoint model.
+            if config[KEY.USE_MODALITY + '_cp']:  # cp model is multimodal
+                for modal_key_cp, modal_idx_cp in modal_map_cp.items():
+                    modal_idx = modal_map[modal_key_cp]
+                    if use_modal_wise_shift:
+                        shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp])
+                    if use_modal_wise_scale:
+                        scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp])
+
+            else:  # cp model is single modal
+                try:
+                    modal_idx = modal_map[config[KEY.DEFAULT_MODAL]]
+                except KeyError:
+                    raise KeyError(
+                        f'{config[KEY.DEFAULT_MODAL]} should be one of'
+                        f' {modal_map.keys()}'
+                    )
+                if use_modal_wise_shift:
+                    shift[modal_idx] = torch.tensor(shift_cp)
+                if use_modal_wise_scale:
+                    scale[modal_idx] = torch.tensor(scale_cp)
+
+            if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]:
+                # Also overwrite values of new modal to reference value
+                # For multimodal, set reference modal with KEY.DEFAULT_MODAL
+                shift_ref = shift_cp
+                scale_ref = scale_cp
+                if config[KEY.USE_MODALITY + '_cp']:
+                    try:
+                        modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]]
+                    except KeyError:
+                        raise KeyError(
+                            f'{config[KEY.DEFAULT_MODAL]} should be one of'
+                            f' {modal_map_cp.keys()}'
+                        )
+                    shift_ref = shift_cp[modal_idx_cp]
+                    scale_ref = scale_cp[modal_idx_cp]
+
+                for modal_key, modal_idx in modal_map.items():
+                    if modal_key not in modal_map_cp.keys():
+                        if use_modal_wise_shift:
+                            shift[modal_idx] = shift_ref
+                        if use_modal_wise_scale:
+                            scale[modal_idx] = scale_ref
+
+    # overwrite shift scale anyway if defined in yaml.
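
Stripped of the modal- and species-wise bookkeeping, the precedence implemented here is small: a statistic computed from the dataset is the default, checkpoint values overwrite it when `use_statistic_values_of_checkpoint` is set, and an explicit float in the input yaml wins over both. A simplified sketch of that ordering (function and argument names are illustrative only, not part of this module):

```python
def resolve(yaml_value, checkpoint_value, computed_statistic):
    value = computed_statistic            # 3. string option -> dataset statistic
    if checkpoint_value is not None:      # 2. statistic stored in checkpoint
        value = checkpoint_value
    if isinstance(yaml_value, float):     # 1. explicit float from yaml wins
        value = yaml_value
    return value

assert resolve(None, -3.1, -2.7) == -3.1   # checkpoint beats statistic
assert resolve(-1.5, -3.1, -2.7) == -1.5   # yaml float beats both
```
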
+
+
+# TODO: This is too long
+def processing_dataset(config, working_dir):
+    log = Logger()
+    prefix = f'{os.path.abspath(working_dir)}/'
+    is_stress = config[KEY.IS_TRAIN_STRESS]
+    checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False
+    cutoff = config[KEY.CUTOFF]
+
+    log.write('\nInitializing dataset...\n')
+
+    dataset = AtomGraphDataset({}, cutoff)
+    load_dataset = config[KEY.LOAD_DATASET]
+    if type(load_dataset) is str:
+        load_dataset = [load_dataset]
+    for file in load_dataset:
+        dataset.augment(dataset_load(file, config))
+
+    dataset.group_by_key()  # apply labels inside original datapoint
+    dataset.unify_dtypes()  # unify dtypes of all data points
+
+    # TODO: I think manual chemical species input is redundant
+    chem_in_db = dataset.get_species()
+    if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given:
+        log.writeline('Auto detect chemical species from dataset')
+        config.update(chemical_species_preprocess(chem_in_db))
+    elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given:
+        pass  # copied from checkpoint in processing_continue.py
+    elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given:
+        pass  # processed in parse_input.py
+    else:  # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given
+        log.writeline('Ignore chemical species in yaml, use checkpoint')
+        # already processed in processing_continue.py
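Aside: a toy illustration of what the 'auto' branch above has to produce — detected species mapped to contiguous one-hot indices. The helper and the key names are hypothetical, not SevenNet's actual `chemical_species_preprocess`; the alphabetical ordering mirrors the convention mentioned later in this diff:

```python
from ase.data import atomic_numbers


def toy_species_preprocess(species):
    # Sort alphabetically, then map atomic number -> one-hot index.
    species = sorted(species)
    type_map = {atomic_numbers[s]: i for i, s in enumerate(species)}
    return {
        'chemical_species': species,      # key names are illustrative only
        'number_of_species': len(species),
        'type_map': type_map,
    }


print(toy_species_preprocess(['O', 'H', 'C']))
# {'chemical_species': ['C', 'H', 'O'], 'number_of_species': 3,
#  'type_map': {6: 0, 1: 1, 8: 2}}
```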
+
+    # basic dataset compatibility check with the previous model
+    if checkpoint_given:
+        chem_from_cp = config[KEY.CHEMICAL_SPECIES]
+        if not all(chem in chem_from_cp for chem in chem_in_db):
+            raise ValueError('Chemical species in checkpoint are not compatible')
+
+    # check what modalities are used in dataset
+    if config[KEY.USE_MODALITY]:
+        modalities = dataset.get_modalities()
+        num_modalities = len(modalities)
+        if num_modalities < 2:
+            Logger().writeline('Only one modal is given, ignore modality')
+            config.update({KEY.USE_MODALITY: False})
+
+        else:
+            modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {}
+            modal_map = modal_map_cp.copy()
+            current_idx = len(modal_map_cp)
+            for modal_key in modalities:
+                if modal_key not in modal_map.keys():
+                    modal_map[modal_key] = current_idx
+                    current_idx += 1
+
+            if config[KEY.IS_DDP]:
+                # Synchronize modal_map
+                torch.cuda.set_device(config[KEY.LOCAL_RANK])
+                modal_map_bcast = [modal_map]
+                dist.broadcast_object_list(modal_map_bcast, src=0)
+                modal_map = modal_map_bcast[0]
+
+            config.update(
+                {
+                    KEY.NUM_MODALITIES: len(modal_map),
+                    KEY.MODAL_MAP: modal_map,
+                    KEY.MODAL_LIST: list(modal_map.keys()),
+                }
+            )
+
+            dataset.write_modal_attr(
+                modal_map,
+                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
+            )
+
+    # --------------- save dataset regardless of train/valid--------------#
+    save_dataset = config[KEY.SAVE_DATASET]
+    save_by_label = config[KEY.SAVE_BY_LABEL]
+    if save_dataset:
+        if save_dataset.endswith('.sevenn_data') is False:
+            save_dataset += '.sevenn_data'
+        if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False:
+            save_dataset = prefix + save_dataset  # save_dataset is a plain file name
+        dataset.save(save_dataset)
+        log.format_k_v('Dataset saved to', save_dataset, write=True)
+        # log.write(f"Loaded full dataset saved to : {save_dataset}\n")
+    if save_by_label:
+        dataset.save(prefix, by_label=True)
+        log.format_k_v('Dataset saved by label', prefix, write=True)
+    # --------------------------------------------------------------------#
+
+    # TODO: testset is not used
+    ignore_test = not config.get(KEY.USE_TESTSET, False)
+    if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]:
+        train_set = dataset
+        test_set = AtomGraphDataset([], config[KEY.CUTOFF])
+
+        log.write('Loading validset from load_validset\n')
+        valid_set = AtomGraphDataset({}, cutoff)
+        for file in config[KEY.LOAD_VALIDSET]:
+            valid_set.augment(dataset_load(file, config))
+        valid_set.group_by_key()
+        valid_set.unify_dtypes()
+
+        # condition: validset labels should be subset of trainset labels
+        valid_labels = valid_set.user_labels
+        train_labels = train_set.user_labels
+        if set(valid_labels).issubset(set(train_labels)) is False:
+            valid_set = AtomGraphDataset(valid_set.to_list(), cutoff)
+            valid_set.rewrite_labels_to_data()
+            train_set = AtomGraphDataset(train_set.to_list(), cutoff)
+            train_set.rewrite_labels_to_data()
+            Logger().write('WARNING! validset labels are not a subset of trainset labels\n')
+            Logger().write('We overwrite all the train, valid labels to default.\n')
+            Logger().write('Please create validset by sevenn_graph_build with -l\n')
+
+        Logger().write('validset loaded; load_dataset is now the train_set\n')
+        Logger().write('the ratio will be ignored\n')
+
+        # condition: validset modalities should be subset of trainset modalities
+        if config[KEY.USE_MODALITY]:
+            config_modality = config[KEY.MODAL_LIST]
+            valid_modality = valid_set.get_modalities()
+
+            if set(valid_modality).issubset(set(config_modality)) is False:
+                raise ValueError('validset modality is not subset of trainset')
+
+            valid_set.write_modal_attr(
+                config[KEY.MODAL_MAP],
+                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
+            )
+    else:
+        train_set, valid_set, test_set = dataset.divide_dataset(
+            config[KEY.RATIO], ignore_test=ignore_test
+        )
+        log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n')
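Aside: the distribution blocks logged just below reduce to a five-number summary per quantity. A small NumPy sketch with a made-up per-atom-energy array, mirroring the mean/std/median/max/min fields that `run_stat` computes later in this diff:

```python
import numpy as np

# Toy per-atom energies (eV/atom); the real values come from the train_set.
per_atom_energy = np.array([-3.1, -3.4, -2.9, -3.3, -3.0])

summary = {
    'mean': float(np.mean(per_atom_energy)),
    'std': float(np.std(per_atom_energy)),
    'median': float(np.quantile(per_atom_energy, q=0.5)),
    'max': float(np.max(per_atom_energy)),
    'min': float(np.min(per_atom_energy)),
}
print(summary)
```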
+
+    log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
+    log.format_k_v('\nloaded validset size is', valid_set.len(), write=True)
+
+    log.write('Dataset initialization was successful\n')
+
+    log.write('\nNumber of atoms in the train_set:\n')
+    log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))
+
+    log.bar()
+    log.write('Per atom energy(eV/atom) distribution:\n')
+    log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
+    log.bar()
+    log.write('Force(eV/Angstrom) distribution:\n')
+    log.statistic_write(train_set.get_statistics(KEY.FORCE))
+    log.bar()
+    log.write('Stress(eV/Angstrom^3) distribution:\n')
+    try:
+        log.statistic_write(train_set.get_statistics(KEY.STRESS))
+    except KeyError:
+        log.write('\n Stress is not included in the train_set\n')
+        if is_stress:
+            is_stress = False
+            log.write('Turn off stress training\n')
+    log.bar()
+
+    # saved data must have atomic numbers as X not one hot idx
+    if config[KEY.SAVE_BY_TRAIN_VALID]:
+        train_set.save(prefix + 'train')
+        valid_set.save(prefix + 'valid')
+        log.format_k_v('Dataset saved by train, valid', prefix, write=True)
+
+    # inconsistent .info dict give error when collate
+    _, _ = train_set.separate_info()
+    _, _ = valid_set.separate_info()
+
+    if train_set.x_is_one_hot_idx is False:
+        train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
+    if valid_set.x_is_one_hot_idx is False:
+        valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
+
+    log.format_k_v('training_set size', train_set.len(), write=True)
+    log.format_k_v('validation_set size', valid_set.len(), write=True)
+
+    shift, scale, conv_denominator = handle_shift_scale(
+        config, train_set, checkpoint_given
+    )
+    config.update(
+        {
+            KEY.SHIFT: shift,
+            KEY.SCALE: scale,
+            KEY.CONV_DENOMINATOR: conv_denominator,
+        }
+    )
+
+    data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())
+
+    return data_lists
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py
index 3a88669ff44002be88e7603d78d7a43cc822f6ee..ec2d7012301ea7a7f2ae9aadff51a616dbe4bcab 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py
@@ -1,182 +1,182 @@
-import os
-from copy import deepcopy
-from typing import Optional
-
-import torch
-from torch.utils.data.distributed import DistributedSampler
-
-import sevenn._keys as KEY
-from sevenn.error_recorder import ErrorRecorder
-from sevenn.logger import Logger
-from sevenn.train.trainer import
Trainer - - -def processing_epoch_v2( - config: dict, - trainer: Trainer, - loaders: dict, # dict[str, Dataset] - start_epoch: int = 1, - train_loader_key: str = 'trainset', - error_recorder: Optional[ErrorRecorder] = None, - total_epoch: Optional[int] = None, - per_epoch: Optional[int] = None, - best_metric_loader_key: str = 'validset', - best_metric: Optional[str] = None, - write_csv: bool = True, - working_dir: Optional[str] = None, -): - from sevenn.util import unique_filepath - - log = Logger() - write_csv = write_csv and log.rank == 0 - working_dir = working_dir or os.getcwd() - prefix = f'{os.path.abspath(working_dir)}/' - - total_epoch = total_epoch or config[KEY.EPOCH] - per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10) - best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss') - recorder = error_recorder or ErrorRecorder.from_config( - config, trainer.loss_functions - ) - recorders = {k: deepcopy(recorder) for k in loaders} - - best_val = float('inf') - best_key = None - if best_metric_loader_key in recorders: - best_key = recorders[best_metric_loader_key].get_key_str(best_metric) - if best_key is None: - log.writeline( - f'Failed to get error recorder key: {best_metric} or ' - + f'{best_metric_loader_key} is missing. There will be no best ' - + 'checkpoint.' - ) - - csv_path = unique_filepath(f'{prefix}/lc.csv') - if write_csv: - head = ['epoch', 'lr'] - for k, rec in recorders.items(): - head.extend(list(rec.get_dct(prefix=k))) - with open(csv_path, 'w') as f: - f.write(','.join(head) + '\n') - - if start_epoch == 1: - path = f'{prefix}/checkpoint_0.pth' # save first epoch - trainer.write_checkpoint(path, config=config, epoch=0) - - for epoch in range(start_epoch, total_epoch + 1): # one indexing - log.timer_start('epoch') - lr = trainer.get_lr() - log.bar() - log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n') - log.bar() - - csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'} - errors = {} - for k, loader in loaders.items(): - is_train = k == train_loader_key - if ( - trainer.distributed - and isinstance(loader.sampler, DistributedSampler) - and is_train - and config.get('train_shuffle', True) - ): - loader.sampler.set_epoch(epoch) - - rec = recorders[k] - trainer.run_one_epoch(loader, is_train, rec) - csv_dct.update(rec.get_dct(prefix=k)) - errors[k] = rec.epoch_forward() - log.write_full_table(list(errors.values()), list(errors)) - trainer.scheduler_step(best_val) - - if write_csv: - with open(csv_path, 'a') as f: - f.write(','.join(list(csv_dct.values())) + '\n') - - if best_key and errors[best_metric_loader_key][best_key] < best_val: - path = f'{prefix}/checkpoint_best.pth' - trainer.write_checkpoint(path, config=config, epoch=epoch) - best_val = errors[best_metric_loader_key][best_key] - log.writeline('Best checkpoint written') - - if epoch % per_epoch == 0: - path = f'{prefix}/checkpoint_{epoch}.pth' - trainer.write_checkpoint(path, config=config, epoch=epoch) - - log.timer_end('epoch', message=f'Epoch {epoch} elapsed') - return trainer - - -def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir): - log = Logger() - prefix = f'{os.path.abspath(working_dir)}/' - train_loader, valid_loader = loaders - - is_distributed = config[KEY.IS_DDP] - rank = config[KEY.RANK] - total_epoch = config[KEY.EPOCH] - per_epoch = config[KEY.PER_EPOCH] - train_recorder = ErrorRecorder.from_config(config) - valid_recorder = ErrorRecorder.from_config(config) - best_metric = config[KEY.BEST_METRIC] - csv_fname = f'{prefix}{config[KEY.CSV_LOG]}' - 
current_best = float('inf') - - if init_csv: - csv_header = ['Epoch', 'Learning_rate'] - # Assume train valid have the same metrics - for metric in train_recorder.get_metric_dict().keys(): - csv_header.append(f'Train_{metric}') - csv_header.append(f'Valid_{metric}') - log.init_csv(csv_fname, csv_header) - - def write_checkpoint(epoch, is_best=False): - if is_distributed and rank != 0: - return - suffix = '_best' if is_best else f'_{epoch}' - checkpoint = trainer.get_checkpoint_dict() - checkpoint.update({'config': config, 'epoch': epoch}) - torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth') - - fin_epoch = total_epoch + start_epoch - for epoch in range(start_epoch, fin_epoch): - lr = trainer.get_lr() - log.timer_start('epoch') - log.bar() - log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n') - log.bar() - - trainer.run_one_epoch( - train_loader, is_train=True, error_recorder=train_recorder - ) - train_err = train_recorder.epoch_forward() - - trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder) - valid_err = valid_recorder.epoch_forward() - - csv_values = [epoch, lr] - for metric in train_err: - csv_values.append(train_err[metric]) - csv_values.append(valid_err[metric]) - log.append_csv(csv_fname, csv_values) - - log.write_full_table([train_err, valid_err], ['Train', 'Valid']) - - val = None - for metric in valid_err: - # loose string comparison, - # e.g. "Energy" in "TotalEnergy" or "Energy_Loss" - if best_metric in metric: - val = valid_err[metric] - break - assert val is not None, f'Metric {best_metric} not found in {valid_err}' - trainer.scheduler_step(val) - - log.timer_end('epoch', message=f'Epoch {epoch} elapsed') - - if val < current_best: - current_best = val - write_checkpoint(epoch, is_best=True) - log.writeline('Best checkpoint written') - if epoch % per_epoch == 0: - write_checkpoint(epoch) +import os +from copy import deepcopy +from typing import Optional + +import torch +from torch.utils.data.distributed import DistributedSampler + +import sevenn._keys as KEY +from sevenn.error_recorder import ErrorRecorder +from sevenn.logger import Logger +from sevenn.train.trainer import Trainer + + +def processing_epoch_v2( + config: dict, + trainer: Trainer, + loaders: dict, # dict[str, Dataset] + start_epoch: int = 1, + train_loader_key: str = 'trainset', + error_recorder: Optional[ErrorRecorder] = None, + total_epoch: Optional[int] = None, + per_epoch: Optional[int] = None, + best_metric_loader_key: str = 'validset', + best_metric: Optional[str] = None, + write_csv: bool = True, + working_dir: Optional[str] = None, +): + from sevenn.util import unique_filepath + + log = Logger() + write_csv = write_csv and log.rank == 0 + working_dir = working_dir or os.getcwd() + prefix = f'{os.path.abspath(working_dir)}/' + + total_epoch = total_epoch or config[KEY.EPOCH] + per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10) + best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss') + recorder = error_recorder or ErrorRecorder.from_config( + config, trainer.loss_functions + ) + recorders = {k: deepcopy(recorder) for k in loaders} + + best_val = float('inf') + best_key = None + if best_metric_loader_key in recorders: + best_key = recorders[best_metric_loader_key].get_key_str(best_metric) + if best_key is None: + log.writeline( + f'Failed to get error recorder key: {best_metric} or ' + + f'{best_metric_loader_key} is missing. There will be no best ' + + 'checkpoint.' 
+ ) + + csv_path = unique_filepath(f'{prefix}/lc.csv') + if write_csv: + head = ['epoch', 'lr'] + for k, rec in recorders.items(): + head.extend(list(rec.get_dct(prefix=k))) + with open(csv_path, 'w') as f: + f.write(','.join(head) + '\n') + + if start_epoch == 1: + path = f'{prefix}/checkpoint_0.pth' # save first epoch + trainer.write_checkpoint(path, config=config, epoch=0) + + for epoch in range(start_epoch, total_epoch + 1): # one indexing + log.timer_start('epoch') + lr = trainer.get_lr() + log.bar() + log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n') + log.bar() + + csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'} + errors = {} + for k, loader in loaders.items(): + is_train = k == train_loader_key + if ( + trainer.distributed + and isinstance(loader.sampler, DistributedSampler) + and is_train + and config.get('train_shuffle', True) + ): + loader.sampler.set_epoch(epoch) + + rec = recorders[k] + trainer.run_one_epoch(loader, is_train, rec) + csv_dct.update(rec.get_dct(prefix=k)) + errors[k] = rec.epoch_forward() + log.write_full_table(list(errors.values()), list(errors)) + trainer.scheduler_step(best_val) + + if write_csv: + with open(csv_path, 'a') as f: + f.write(','.join(list(csv_dct.values())) + '\n') + + if best_key and errors[best_metric_loader_key][best_key] < best_val: + path = f'{prefix}/checkpoint_best.pth' + trainer.write_checkpoint(path, config=config, epoch=epoch) + best_val = errors[best_metric_loader_key][best_key] + log.writeline('Best checkpoint written') + + if epoch % per_epoch == 0: + path = f'{prefix}/checkpoint_{epoch}.pth' + trainer.write_checkpoint(path, config=config, epoch=epoch) + + log.timer_end('epoch', message=f'Epoch {epoch} elapsed') + return trainer + + +def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir): + log = Logger() + prefix = f'{os.path.abspath(working_dir)}/' + train_loader, valid_loader = loaders + + is_distributed = config[KEY.IS_DDP] + rank = config[KEY.RANK] + total_epoch = config[KEY.EPOCH] + per_epoch = config[KEY.PER_EPOCH] + train_recorder = ErrorRecorder.from_config(config) + valid_recorder = ErrorRecorder.from_config(config) + best_metric = config[KEY.BEST_METRIC] + csv_fname = f'{prefix}{config[KEY.CSV_LOG]}' + current_best = float('inf') + + if init_csv: + csv_header = ['Epoch', 'Learning_rate'] + # Assume train valid have the same metrics + for metric in train_recorder.get_metric_dict().keys(): + csv_header.append(f'Train_{metric}') + csv_header.append(f'Valid_{metric}') + log.init_csv(csv_fname, csv_header) + + def write_checkpoint(epoch, is_best=False): + if is_distributed and rank != 0: + return + suffix = '_best' if is_best else f'_{epoch}' + checkpoint = trainer.get_checkpoint_dict() + checkpoint.update({'config': config, 'epoch': epoch}) + torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth') + + fin_epoch = total_epoch + start_epoch + for epoch in range(start_epoch, fin_epoch): + lr = trainer.get_lr() + log.timer_start('epoch') + log.bar() + log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n') + log.bar() + + trainer.run_one_epoch( + train_loader, is_train=True, error_recorder=train_recorder + ) + train_err = train_recorder.epoch_forward() + + trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder) + valid_err = valid_recorder.epoch_forward() + + csv_values = [epoch, lr] + for metric in train_err: + csv_values.append(train_err[metric]) + csv_values.append(valid_err[metric]) + log.append_csv(csv_fname, csv_values) + + log.write_full_table([train_err, valid_err], 
['Train', 'Valid']) + + val = None + for metric in valid_err: + # loose string comparison, + # e.g. "Energy" in "TotalEnergy" or "Energy_Loss" + if best_metric in metric: + val = valid_err[metric] + break + assert val is not None, f'Metric {best_metric} not found in {valid_err}' + trainer.scheduler_step(val) + + log.timer_end('epoch', message=f'Epoch {epoch} elapsed') + + if val < current_best: + current_best = val + write_checkpoint(epoch, is_best=True) + log.writeline('Best checkpoint written') + if epoch % per_epoch == 0: + write_checkpoint(epoch) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py index 888469d461b1ac22bfd2f06360c7743646f07b95..b0dabab733444510085ef3d2b2c1a96d816ab1cc 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py @@ -1,139 +1,139 @@ -from typing import List, Optional - -import torch.distributed as dist -from torch.utils.data.distributed import DistributedSampler -from torch_geometric.loader import DataLoader - -import sevenn._keys as KEY -from sevenn.logger import Logger -from sevenn.model_build import build_E3_equivariant_model -from sevenn.scripts.processing_continue import ( - convert_modality_of_checkpoint_state_dct, -) -from sevenn.train.trainer import Trainer - - -def loader_from_config(config, dataset, is_train=False): - batch_size = config[KEY.BATCH_SIZE] - shuffle = is_train and config[KEY.TRAIN_SHUFFLE] - sampler = None - loader_args = { - 'dataset': dataset, - 'batch_size': batch_size, - 'shuffle': shuffle - } - if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0: - loader_args.update({'num_workers': config[KEY.NUM_WORKERS]}) - - if config[KEY.IS_DDP]: - dist.barrier() - sampler = DistributedSampler( - dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle - ) - loader_args.update({'sampler': sampler}) - loader_args.pop('shuffle') # sampler is mutually exclusive with shuffle - return DataLoader(**loader_args) - - -def train_v2(config, working_dir: str): - """ - Main program flow, since v0.9.6 - """ - import sevenn.train.atoms_dataset as atoms_dataset - import sevenn.train.graph_dataset as graph_dataset - import sevenn.train.modal_dataset as modal_dataset - - from .processing_continue import processing_continue_v2 - from .processing_epoch import processing_epoch_v2 - - log = Logger() - log.timer_start('total') - - if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config: - log.writeline('***************************************************') - log.writeline('For train_v2, please use load_trainset_path instead') - log.writeline('I will assign load_trainset as load_dataset') - log.writeline('***************************************************') - config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET) - - # config updated - start_epoch = 1 - state_dicts: Optional[List[dict]] = None - if config[KEY.CONTINUE][KEY.CHECKPOINT]: - state_dicts, start_epoch = processing_continue_v2(config) - - if config.get(KEY.USE_MODALITY, False): - datasets = modal_dataset.from_config(config, working_dir) - elif config[KEY.DATASET_TYPE] == 'graph': - datasets = graph_dataset.from_config(config, working_dir) - elif config[KEY.DATASET_TYPE] == 'atoms': - datasets = atoms_dataset.from_config(config, working_dir) - else: - raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}') - loaders = { - k: loader_from_config(config, v, is_train=(k == 'trainset')) - for k, v in datasets.items() - } - - 
log.write('\nModel building...\n') - model = build_E3_equivariant_model(config) - log.print_model_info(model, config) - - trainer = Trainer.from_config(model, config) - if state_dicts: - trainer.load_state_dicts(*state_dicts, strict=False) - - processing_epoch_v2( - config, trainer, loaders, start_epoch, working_dir=working_dir - ) - log.timer_end('total', message='Total wall time') - - -def train(config, working_dir: str): - """ - Main program flow, until v0.9.5 - """ - from .processing_continue import processing_continue - from .processing_dataset import processing_dataset - from .processing_epoch import processing_epoch - - log = Logger() - log.timer_start('total') - - # config updated - state_dicts: Optional[List[dict]] = None - if config[KEY.CONTINUE][KEY.CHECKPOINT]: - state_dicts, start_epoch, init_csv = processing_continue(config) - else: - start_epoch, init_csv = 1, True - - # config updated - train, valid, _ = processing_dataset(config, working_dir) - datasets = {'dataset': train, 'validset': valid} - loaders = { - k: loader_from_config(config, v, is_train=(k == 'dataset')) - for k, v in datasets.items() - } - loaders = list(loaders.values()) - - log.write('\nModel building...\n') - model = build_E3_equivariant_model(config) - - log.write('Model building was successful\n') - - trainer = Trainer.from_config(model, config) - if state_dicts: - state_dicts = convert_modality_of_checkpoint_state_dct( - config, state_dicts - ) - trainer.load_state_dicts(*state_dicts, strict=False) - - log.print_model_info(model, config) - - Logger().write('Trainer initialized, ready to training\n') - Logger().bar() - log.write('Trainer initialized, ready to training\n') - log.bar() - - processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir) - log.timer_end('total', message='Total wall time') +from typing import List, Optional + +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch_geometric.loader import DataLoader + +import sevenn._keys as KEY +from sevenn.logger import Logger +from sevenn.model_build import build_E3_equivariant_model +from sevenn.scripts.processing_continue import ( + convert_modality_of_checkpoint_state_dct, +) +from sevenn.train.trainer import Trainer + + +def loader_from_config(config, dataset, is_train=False): + batch_size = config[KEY.BATCH_SIZE] + shuffle = is_train and config[KEY.TRAIN_SHUFFLE] + sampler = None + loader_args = { + 'dataset': dataset, + 'batch_size': batch_size, + 'shuffle': shuffle + } + if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0: + loader_args.update({'num_workers': config[KEY.NUM_WORKERS]}) + + if config[KEY.IS_DDP]: + dist.barrier() + sampler = DistributedSampler( + dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle + ) + loader_args.update({'sampler': sampler}) + loader_args.pop('shuffle') # sampler is mutually exclusive with shuffle + return DataLoader(**loader_args) + + +def train_v2(config, working_dir: str): + """ + Main program flow, since v0.9.6 + """ + import sevenn.train.atoms_dataset as atoms_dataset + import sevenn.train.graph_dataset as graph_dataset + import sevenn.train.modal_dataset as modal_dataset + + from .processing_continue import processing_continue_v2 + from .processing_epoch import processing_epoch_v2 + + log = Logger() + log.timer_start('total') + + if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config: + log.writeline('***************************************************') + log.writeline('For train_v2, please use 
load_trainset_path instead') + log.writeline('I will assign load_trainset as load_dataset') + log.writeline('***************************************************') + config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET) + + # config updated + start_epoch = 1 + state_dicts: Optional[List[dict]] = None + if config[KEY.CONTINUE][KEY.CHECKPOINT]: + state_dicts, start_epoch = processing_continue_v2(config) + + if config.get(KEY.USE_MODALITY, False): + datasets = modal_dataset.from_config(config, working_dir) + elif config[KEY.DATASET_TYPE] == 'graph': + datasets = graph_dataset.from_config(config, working_dir) + elif config[KEY.DATASET_TYPE] == 'atoms': + datasets = atoms_dataset.from_config(config, working_dir) + else: + raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}') + loaders = { + k: loader_from_config(config, v, is_train=(k == 'trainset')) + for k, v in datasets.items() + } + + log.write('\nModel building...\n') + model = build_E3_equivariant_model(config) + log.print_model_info(model, config) + + trainer = Trainer.from_config(model, config) + if state_dicts: + trainer.load_state_dicts(*state_dicts, strict=False) + + processing_epoch_v2( + config, trainer, loaders, start_epoch, working_dir=working_dir + ) + log.timer_end('total', message='Total wall time') + + +def train(config, working_dir: str): + """ + Main program flow, until v0.9.5 + """ + from .processing_continue import processing_continue + from .processing_dataset import processing_dataset + from .processing_epoch import processing_epoch + + log = Logger() + log.timer_start('total') + + # config updated + state_dicts: Optional[List[dict]] = None + if config[KEY.CONTINUE][KEY.CHECKPOINT]: + state_dicts, start_epoch, init_csv = processing_continue(config) + else: + start_epoch, init_csv = 1, True + + # config updated + train, valid, _ = processing_dataset(config, working_dir) + datasets = {'dataset': train, 'validset': valid} + loaders = { + k: loader_from_config(config, v, is_train=(k == 'dataset')) + for k, v in datasets.items() + } + loaders = list(loaders.values()) + + log.write('\nModel building...\n') + model = build_E3_equivariant_model(config) + + log.write('Model building was successful\n') + + trainer = Trainer.from_config(model, config) + if state_dicts: + state_dicts = convert_modality_of_checkpoint_state_dct( + config, state_dicts + ) + trainer.load_state_dicts(*state_dicts, strict=False) + + log.print_model_info(model, config) + + Logger().write('Trainer initialized, ready to training\n') + Logger().bar() + log.write('Trainer initialized, ready to training\n') + log.bar() + + processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir) + log.timer_end('total', message='Total wall time') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py b/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py index e3ee4904ff5448f495daf877c809b75f2f19e368..7efea4af1484224287a11625cfdbf7a8f2a7b762 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py @@ -1,6 +1,6 @@ -import warnings - -from .logger import * # noqa: F403 - -warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger', - DeprecationWarning, stacklevel=2) +import warnings + +from .logger import * # noqa: F403 + +warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger', + DeprecationWarning, stacklevel=2) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py 
b/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py index f7e67252b6689ad032359b55901fe85e950b556c..ba3c78706eab27ac3925620f325fab25e4bdb09d 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py @@ -1,6 +1,6 @@ -import warnings - -from .calculator import * # noqa: F403 - -warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator', - DeprecationWarning, stacklevel=2) +import warnings + +from .calculator import * # noqa: F403 + +warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator', + DeprecationWarning, stacklevel=2) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 06affaeb2937478a01a3398a013f17cb8fcbecf9..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc deleted file mode 100644 index d03c532352c5693d473df9bf8e4cff8a3a9bdbec..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index 55d4fb5c33d33ea6f7e3b2dbb20068e7847bc6a8..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py index b12564af64ccc0eb30be606312b11c6e0808123e..178e8909628274e9a909b1147e75c03eb100a63f 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py @@ -1,314 +1,314 @@ -import os -import random -import warnings -from collections import Counter -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch.utils.data -from ase.atoms import Atoms -from ase.data import chemical_symbols -from ase.io import write -from tqdm import tqdm - -import sevenn._keys as KEY -import sevenn.train.dataload as dataload -import sevenn.util as util -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.atom_graph_data import AtomGraphData - -_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide correct avg_num_neigh -as it does not build graph. We will compute only random 10000 structures graph to -approximate this value. If you want more precise avg_num_neigh, -use SevenNetGraphDataset. If it is not viable due to memory limit, you -need online algorithm to do this , which is not yet implemented in the SevenNet""" - - -class SevenNetAtomsDataset(torch.utils.data.Dataset): - """ - Args: - cutoff: edge cutoff of given AtomGraphData - files: list of filenames or dict describing how to parse the file - ASE readable (with proper extension), structure_list, .sevenn_data, - dict containing file_list (see dict_reader of train/dataload.py) - info_dict_copy_keys: patch these keys from KEY.INFO to graph when accessing. 
- default is KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may accessed - while training. - **process_kwargs: keyword arguments that will be passed into ase.io.read - """ - - def __init__( - self, - cutoff: float, - files: Union[str, List[str]], - atoms_filter: Optional[Callable] = None, - atoms_transform: Optional[Callable] = None, - transform: Optional[Callable] = None, - use_data_weight: bool = False, - **process_kwargs, - ): - self.cutoff = cutoff - if isinstance(files, str): - files = [files] # user convenience - files = [os.path.abspath(file) for file in files] - self._files = files - self.atoms_filter = atoms_filter - self.atoms_transform = atoms_transform - self.transform = transform - self.use_data_weight = use_data_weight - self._scanned = False - self._avg_num_neigh_approx = None - self.statistics = {} - - atoms_list = [] - for file in files: - atoms_list.extend( - SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs) - ) - self._atoms_list = atoms_list - - super().__init__() - - @staticmethod - def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]: - if isinstance(file, dict): - atoms_list = dataload.dict_reader(file) - elif 'structure_list' in file: - atoms_dct = dataload.structure_list_reader(file) - atoms_list = [] - for lst in atoms_dct.values(): - atoms_list.extend(lst) - else: - atoms_list = dataload.ase_reader(file, **kwargs) - return atoms_list - - def save(self, path): - # Save atoms list as extxyz - write(path, self._atoms_list, format='extxyz') - - def _graph_build(self, atoms): - return dataload.atoms_to_graph( - atoms, self.cutoff, transfer_info=False, y_from_calc=False - ) - - def __len__(self): - return len(self._atoms_list) - - def __getitem__(self, index): - atoms = self._atoms_list[index] - if self.atoms_transform is not None: - atoms = self.atoms_transform(atoms) - - graph = self._graph_build(atoms) - if self.transform is not None: - graph = self.transform(graph) - - if self.use_data_weight: - weight = graph[KEY.INFO].pop( - KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - ) - graph[KEY.DATA_WEIGHT] = weight - - return AtomGraphData.from_numpy_dict(graph) - - @property - def species(self): - self.run_stat() - return [z for z in self.statistics['_natoms'].keys() if z != 'total'] - - @property - def natoms(self): - self.run_stat() - return self.statistics['_natoms'] - - @property - def per_atom_energy_mean(self): - self.run_stat() - return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] - - @property - def elemwise_reference_energies(self): - from sklearn.linear_model import Ridge - - c = self.statistics['_composition'] - y = self.statistics[KEY.ENERGY]['_array'] - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - # will not 100% reproduce, as it is sorted by Z - # train/dataset.py was sorted by alphabets of chemical species - coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - full_coeff = np.zeros(NUM_UNIV_ELEMENT) - full_coeff[~zero_indices] = coef_reduced - return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy - - @property - def force_rms(self): - self.run_stat() - mean = self.statistics[KEY.FORCE]['mean'] - std = self.statistics[KEY.FORCE]['std'] - return float((mean**2 + std**2) ** (0.5)) - - @property - def per_atom_energy_std(self): - self.run_stat() - return self.statistics['per_atom_energy']['std'] - - @property - def avg_num_neigh(self, n_sample=10000): - if self._avg_num_neigh_approx is None: - if len(self) > n_sample: - 
warnings.warn(_warn_avg_num_neigh) - n_sample = min(len(self), n_sample) - indices = random.sample(range(len(self)), n_sample) - n_neigh = [] - for i in indices: - graph = self[i] - _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True) - n_neigh.append(nn) - n_neigh = np.concatenate(n_neigh) - self._avg_num_neigh_approx = np.mean(n_neigh) - return self._avg_num_neigh_approx - - @property - def sqrt_avg_num_neigh(self): - self.run_stat() - return self.avg_num_neigh**0.5 - - def run_stat(self): - """ - Loop over dataset and init any statistics might need - Unlink SevenNetGraphDataset, neighbors count is not computed as - it requires to build graph - """ - if self._scanned is True: - return # statistics already computed - y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] - natoms_counter = Counter() - composition = np.zeros((len(self), NUM_UNIV_ELEMENT)) - stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys} - - for i, atoms in tqdm( - enumerate(self._atoms_list), desc='run_stat', total=len(self) - ): - z = atoms.get_atomic_numbers() - natoms_counter.update(z.tolist()) - composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT) - for y, dct in stats.items(): - if y == KEY.ENERGY: - dct['_array'].append(atoms.info['y_energy']) - elif y == KEY.PER_ATOM_ENERGY: - dct['_array'].append(atoms.info['y_energy'] / len(atoms)) - elif y == KEY.FORCE: - dct['_array'].append(atoms.arrays['y_force'].reshape(-1)) - elif y == KEY.STRESS: - dct['_array'].append(atoms.info['y_stress'].reshape(-1)) - - for y, dct in stats.items(): - if y == KEY.FORCE: - array = np.concatenate(dct['_array']) - else: - array = np.array(dct['_array']).reshape(-1) - dct.update( - { - 'mean': float(np.mean(array)), - 'std': float(np.std(array)), - 'median': float(np.quantile(array, q=0.5)), - 'max': float(np.max(array)), - 'min': float(np.min(array)), - '_array': array, - } - ) - - natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} - natoms['total'] = sum(list(natoms.values())) - self.statistics.update( - { - '_composition': composition, - '_natoms': natoms, - **stats, - } - ) - self._scanned = True - - -# script, return dict of SevenNetAtomsDataset -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - from sevenn.logger import Logger - - log = Logger() - if dataset_keys is None: - dataset_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - dataset_keys.append(k) - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - # initialize arguments for loading dataset - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config[KEY.DATA_FORMAT_ARGS], - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - dataset_args.update({'files': paths}) - datasets[name] = SevenNetAtomsDataset(**dataset_args) - - if not config[KEY.COMPUTE_STATISTICS]: - log.writeline( - ( - 'Computing statistics is skipped, note that if any of other' - 'configurations requires statistics (shift, scale, avg_num_neigh,' - 'chemical_species as auto), SevenNet eventually raise an error!' 
- ) - ) - return datasets - - train_set = datasets['trainset'] - - chem_species = set(train_set.species) - # print statistics of each dataset - for name, dataset in datasets.items(): - dataset.run_stat() - log.bar() - log.writeline(f'{name} distribution:') - log.statistic_write(dataset.statistics) - log.format_k_v('# atoms (node)', dataset.natoms, write=True) - log.format_k_v('# structures (graph)', len(dataset), write=True) - - chem_species.update(dataset.species) - log.bar() - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - for k in init_from_stats: - input = config[k] # statistic key or numbers - # If it is not 'str', 1: It is 'continue' training - # 2: User manually inserted numbers - if isinstance(input, str) and hasattr(train_set, input): - var = getattr(train_set, input) - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - elif isinstance(input, str) and not hasattr(train_set, input): - raise NotImplementedError(input) - - return datasets +import os +import random +import warnings +from collections import Counter +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch.utils.data +from ase.atoms import Atoms +from ase.data import chemical_symbols +from ase.io import write +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.train.dataload as dataload +import sevenn.util as util +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData + +_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide correct avg_num_neigh +as it does not build graph. We will compute only random 10000 structures graph to +approximate this value. If you want more precise avg_num_neigh, +use SevenNetGraphDataset. If it is not viable due to memory limit, you +need online algorithm to do this , which is not yet implemented in the SevenNet""" + + +class SevenNetAtomsDataset(torch.utils.data.Dataset): + """ + Args: + cutoff: edge cutoff of given AtomGraphData + files: list of filenames or dict describing how to parse the file + ASE readable (with proper extension), structure_list, .sevenn_data, + dict containing file_list (see dict_reader of train/dataload.py) + info_dict_copy_keys: patch these keys from KEY.INFO to graph when accessing. + default is KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may accessed + while training. 
+ **process_kwargs: keyword arguments that will be passed into ase.io.read + """ + + def __init__( + self, + cutoff: float, + files: Union[str, List[str]], + atoms_filter: Optional[Callable] = None, + atoms_transform: Optional[Callable] = None, + transform: Optional[Callable] = None, + use_data_weight: bool = False, + **process_kwargs, + ): + self.cutoff = cutoff + if isinstance(files, str): + files = [files] # user convenience + files = [os.path.abspath(file) for file in files] + self._files = files + self.atoms_filter = atoms_filter + self.atoms_transform = atoms_transform + self.transform = transform + self.use_data_weight = use_data_weight + self._scanned = False + self._avg_num_neigh_approx = None + self.statistics = {} + + atoms_list = [] + for file in files: + atoms_list.extend( + SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs) + ) + self._atoms_list = atoms_list + + super().__init__() + + @staticmethod + def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]: + if isinstance(file, dict): + atoms_list = dataload.dict_reader(file) + elif 'structure_list' in file: + atoms_dct = dataload.structure_list_reader(file) + atoms_list = [] + for lst in atoms_dct.values(): + atoms_list.extend(lst) + else: + atoms_list = dataload.ase_reader(file, **kwargs) + return atoms_list + + def save(self, path): + # Save atoms list as extxyz + write(path, self._atoms_list, format='extxyz') + + def _graph_build(self, atoms): + return dataload.atoms_to_graph( + atoms, self.cutoff, transfer_info=False, y_from_calc=False + ) + + def __len__(self): + return len(self._atoms_list) + + def __getitem__(self, index): + atoms = self._atoms_list[index] + if self.atoms_transform is not None: + atoms = self.atoms_transform(atoms) + + graph = self._graph_build(atoms) + if self.transform is not None: + graph = self.transform(graph) + + if self.use_data_weight: + weight = graph[KEY.INFO].pop( + KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + ) + graph[KEY.DATA_WEIGHT] = weight + + return AtomGraphData.from_numpy_dict(graph) + + @property + def species(self): + self.run_stat() + return [z for z in self.statistics['_natoms'].keys() if z != 'total'] + + @property + def natoms(self): + self.run_stat() + return self.statistics['_natoms'] + + @property + def per_atom_energy_mean(self): + self.run_stat() + return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] + + @property + def elemwise_reference_energies(self): + from sklearn.linear_model import Ridge + + c = self.statistics['_composition'] + y = self.statistics[KEY.ENERGY]['_array'] + zero_indices = np.all(c == 0, axis=0) + c_reduced = c[:, ~zero_indices] + # will not 100% reproduce, as it is sorted by Z + # train/dataset.py was sorted by alphabets of chemical species + coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + full_coeff = np.zeros(NUM_UNIV_ELEMENT) + full_coeff[~zero_indices] = coef_reduced + return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy + + @property + def force_rms(self): + self.run_stat() + mean = self.statistics[KEY.FORCE]['mean'] + std = self.statistics[KEY.FORCE]['std'] + return float((mean**2 + std**2) ** (0.5)) + + @property + def per_atom_energy_std(self): + self.run_stat() + return self.statistics['per_atom_energy']['std'] + + @property + def avg_num_neigh(self, n_sample=10000): + if self._avg_num_neigh_approx is None: + if len(self) > n_sample: + warnings.warn(_warn_avg_num_neigh) + n_sample = min(len(self), n_sample) + indices = 
random.sample(range(len(self)), n_sample) + n_neigh = [] + for i in indices: + graph = self[i] + _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True) + n_neigh.append(nn) + n_neigh = np.concatenate(n_neigh) + self._avg_num_neigh_approx = np.mean(n_neigh) + return self._avg_num_neigh_approx + + @property + def sqrt_avg_num_neigh(self): + self.run_stat() + return self.avg_num_neigh**0.5 + + def run_stat(self): + """ + Loop over dataset and init any statistics might need + Unlink SevenNetGraphDataset, neighbors count is not computed as + it requires to build graph + """ + if self._scanned is True: + return # statistics already computed + y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] + natoms_counter = Counter() + composition = np.zeros((len(self), NUM_UNIV_ELEMENT)) + stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys} + + for i, atoms in tqdm( + enumerate(self._atoms_list), desc='run_stat', total=len(self) + ): + z = atoms.get_atomic_numbers() + natoms_counter.update(z.tolist()) + composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT) + for y, dct in stats.items(): + if y == KEY.ENERGY: + dct['_array'].append(atoms.info['y_energy']) + elif y == KEY.PER_ATOM_ENERGY: + dct['_array'].append(atoms.info['y_energy'] / len(atoms)) + elif y == KEY.FORCE: + dct['_array'].append(atoms.arrays['y_force'].reshape(-1)) + elif y == KEY.STRESS: + dct['_array'].append(atoms.info['y_stress'].reshape(-1)) + + for y, dct in stats.items(): + if y == KEY.FORCE: + array = np.concatenate(dct['_array']) + else: + array = np.array(dct['_array']).reshape(-1) + dct.update( + { + 'mean': float(np.mean(array)), + 'std': float(np.std(array)), + 'median': float(np.quantile(array, q=0.5)), + 'max': float(np.max(array)), + 'min': float(np.min(array)), + '_array': array, + } + ) + + natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} + natoms['total'] = sum(list(natoms.values())) + self.statistics.update( + { + '_composition': composition, + '_natoms': natoms, + **stats, + } + ) + self._scanned = True + + +# script, return dict of SevenNetAtomsDataset +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + from sevenn.logger import Logger + + log = Logger() + if dataset_keys is None: + dataset_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + dataset_keys.append(k) + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + # initialize arguments for loading dataset + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config[KEY.DATA_FORMAT_ARGS], + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + dataset_args.update({'files': paths}) + datasets[name] = SevenNetAtomsDataset(**dataset_args) + + if not config[KEY.COMPUTE_STATISTICS]: + log.writeline( + ( + 'Computing statistics is skipped, note that if any of other' + 'configurations requires statistics (shift, scale, avg_num_neigh,' + 'chemical_species as auto), SevenNet eventually raise an error!' 
+ ) + ) + return datasets + + train_set = datasets['trainset'] + + chem_species = set(train_set.species) + # print statistics of each dataset + for name, dataset in datasets.items(): + dataset.run_stat() + log.bar() + log.writeline(f'{name} distribution:') + log.statistic_write(dataset.statistics) + log.format_k_v('# atoms (node)', dataset.natoms, write=True) + log.format_k_v('# structures (graph)', len(dataset), write=True) + + chem_species.update(dataset.species) + log.bar() + + # initialize known species from dataset if 'auto' + # sorted to alphabetical order (which is same as before) + chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] + if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py + log.writeline('Known species are obtained from the dataset') + config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) + + # retrieve shift, scale, conv_denominaotrs from user input (keyword) + init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] + for k in init_from_stats: + input = config[k] # statistic key or numbers + # If it is not 'str', 1: It is 'continue' training + # 2: User manually inserted numbers + if isinstance(input, str) and hasattr(train_set, input): + var = getattr(train_set, input) + config.update({k: var}) + log.writeline(f'{k} is obtained from statistics') + elif isinstance(input, str) and not hasattr(train_set, input): + raise NotImplementedError(input) + + return datasets diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py b/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py index e3c902a4cd6be2dbdca4b422ce58ed916d539763..3e1ede9ea229f3f0f8ddafcfca43a0f4fbc89787 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py @@ -1,41 +1,41 @@ -from typing import Any, List, Optional, Sequence - -from ase.atoms import Atoms -from torch_geometric.loader.dataloader import Collater - -from sevenn.atom_graph_data import AtomGraphData - -from .dataload import atoms_to_graph - - -class AtomsToGraphCollater(Collater): - - def __init__( - self, - dataset: Sequence[Atoms], - cutoff: float, - transfer_info: bool = False, - follow_batch: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - y_from_calc: bool = True, - ): - # quite original collator's type mismatch with [] - super().__init__([], follow_batch, exclude_keys) - self.dataset = dataset - self.cutoff = cutoff - self.transfer_info = transfer_info - self.y_from_calc = y_from_calc - - def __call__(self, batch: List[Any]) -> Any: - # build list of graph - graph_list = [] - for stct in batch: - graph = atoms_to_graph( - stct, - self.cutoff, - transfer_info=self.transfer_info, - y_from_calc=self.y_from_calc, - ) - graph = AtomGraphData.from_numpy_dict(graph) - graph_list.append(graph) - return super().__call__(graph_list) +from typing import Any, List, Optional, Sequence + +from ase.atoms import Atoms +from torch_geometric.loader.dataloader import Collater + +from sevenn.atom_graph_data import AtomGraphData + +from .dataload import atoms_to_graph + + +class AtomsToGraphCollater(Collater): + + def __init__( + self, + dataset: Sequence[Atoms], + cutoff: float, + transfer_info: bool = False, + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + y_from_calc: bool = True, + ): + # quite original collator's type mismatch with [] + super().__init__([], follow_batch, exclude_keys) + self.dataset = dataset + self.cutoff = cutoff + self.transfer_info 
= transfer_info + self.y_from_calc = y_from_calc + + def __call__(self, batch: List[Any]) -> Any: + # build list of graph + graph_list = [] + for stct in batch: + graph = atoms_to_graph( + stct, + self.cutoff, + transfer_info=self.transfer_info, + y_from_calc=self.y_from_calc, + ) + graph = AtomGraphData.from_numpy_dict(graph) + graph_list.append(graph) + return super().__call__(graph_list) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py b/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py index cffa221a9480ea24f2d9ae5bd67ddd7b3d087b2f..48cdc8b3a667719fa0a235023bd64bb472fcca14 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py @@ -1,609 +1,609 @@ -import copy -import os.path -from functools import partial -from itertools import chain, islice -from typing import Callable, Dict, List, Optional - -import ase -import ase.io -import numpy as np -import torch.multiprocessing as mp -from ase.io.vasp_parsers.vasp_outcar_parsers import ( - Cell, - DefaultParsersContainer, - Energy, - OutcarChunkParser, - PositionsAndForces, - Stress, - outcarchunks, -) -from ase.neighborlist import primitive_neighbor_list -from ase.utils import string2index -from braceexpand import braceexpand -from tqdm import tqdm - -import sevenn._keys as KEY -from sevenn._const import LossType -from sevenn.atom_graph_data import AtomGraphData - -from .dataset import AtomGraphDataset - - -def _graph_build_matscipy(cutoff: float, pbc, cell, pos): - pbc_x = pbc[0] - pbc_y = pbc[1] - pbc_z = pbc[2] - - identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(pos)) + 1 - - # Extend cell in non-periodic directions - # For models with more than 5 layers, - # the multiplicative constant needs to be increased. 
- if not pbc_x: - cell[0, :] = max_positions * 5 * cutoff * identity[0, :] - if not pbc_y: - cell[1, :] = max_positions * 5 * cutoff * identity[1, :] - if not pbc_z: - cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - # it does not have self-interaction - edge_src, edge_dst, edge_vec, shifts = neighbour_list( - quantities='ijDS', - pbc=pbc, - cell=cell, - positions=pos, - cutoff=cutoff, - ) - # dtype issue - edge_src = edge_src.astype(np.int64) - edge_dst = edge_dst.astype(np.int64) - - return edge_src, edge_dst, edge_vec, shifts - - -def _graph_build_ase(cutoff: float, pbc, cell, pos): - # building neighbor list - edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list( - 'ijDS', pbc, cell, pos, cutoff, self_interaction=True - ) - - is_zero_idx = np.all(edge_vec == 0, axis=1) - is_self_idx = edge_src == edge_dst - non_trivials = ~(is_zero_idx & is_self_idx) - shifts = np.array(shifts[non_trivials]) - - edge_vec = edge_vec[non_trivials] - edge_src = edge_src[non_trivials] - edge_dst = edge_dst[non_trivials] - - return edge_src, edge_dst, edge_vec, shifts - - -_graph_build_f = _graph_build_ase -try: - from matscipy.neighbours import neighbour_list - - _graph_build_f = _graph_build_matscipy -except ImportError: - pass - - -def _correct_scalar(v): - if isinstance(v, np.ndarray): - v = v.squeeze() - assert v.ndim == 0, f'given {v} is not a scalar' - return v - elif isinstance(v, (int, float, np.integer, np.floating)): - return np.array(v) - else: - assert False, f'{type(v)} is not expected' - - -def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float): - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) - - edge_idx = np.array([edge_src, edge_dst]) - - atomic_numbers = atoms.get_atomic_numbers() - - cell = np.array(cell) - vol = _correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx, - KEY.EDGE_VEC: edge_vec, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), - } - data[KEY.INFO] = {} - return data - - -def atoms_to_graph( - atoms: ase.Atoms, - cutoff: float, - transfer_info: bool = True, - y_from_calc: bool = False, - allow_unlabeled: bool = False, -): - """ - From ase atoms, return AtomGraphData as graph based on cutoff radius - Except for energy, force and stress labels must be numpy array type - as other cases are not tested. - Returns 'np.nan' with consistent shape for unlabeled data - (ex. stress of non-pbc system) - - Args: - atoms (Atoms): ase atoms - cutoff (float): cutoff radius - transfer_info (bool): if True, transfer ".info" from atoms to graph, - defaults to True - y_from_calc: if True, get ref values from calculator, defaults to False - Returns: - numpy dict that can be used to initialize AtomGraphData - by AtomGraphData(**atoms_to_graph(atoms, cutoff)) - , for scalar, its shape is (), and types are np.ndarray - Requires grad is handled by 'dataset' not here. 
- """ - if not y_from_calc: - y_energy = atoms.info['y_energy'] - y_force = atoms.arrays['y_force'] - y_stress = atoms.info.get('y_stress', np.full((6,), np.nan)) - if y_stress.shape == (3, 3): - y_stress = np.array( - [ - y_stress[0][0], - y_stress[1][1], - y_stress[2][2], - y_stress[0][1], - y_stress[1][2], - y_stress[2][0], - ] - ) - else: - y_stress = y_stress.squeeze() - else: - from_calc = _y_from_calc(atoms) - y_energy = from_calc['energy'] - y_force = from_calc['force'] - y_stress = from_calc['stress'] - assert y_stress.shape == (6,), 'If you see this, please raise a issue' - - if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()): - raise ValueError('Unlabeled E or F found, set allow_unlabeled True') - - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) - - edge_idx = np.array([edge_src, edge_dst]) - atomic_numbers = atoms.get_atomic_numbers() - - cell = np.array(cell) - vol = _correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx, - KEY.EDGE_VEC: edge_vec, - KEY.ENERGY: _correct_scalar(y_energy), - KEY.FORCE: y_force, - KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6) - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), - KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), - } - - if transfer_info and atoms.info is not None: - info = copy.deepcopy(atoms.info) - # save only metadata - info.pop('y_energy', None) - info.pop('y_force', None) - info.pop('y_stress', None) - data[KEY.INFO] = info - else: - data[KEY.INFO] = {} - - return data - - -def graph_build( - atoms_list: List, - cutoff: float, - num_cores: int = 1, - transfer_info: bool = True, - y_from_calc: bool = False, - allow_unlabeled: bool = False, -) -> List[AtomGraphData]: - """ - parallel version of graph_build - build graph from atoms_list and return list of AtomGraphData - Args: - atoms_list (List): list of ASE atoms - cutoff (float): cutoff radius of graph - num_cores (int): number of cores to use - transfer_info (bool): if True, copy info from atoms to graph, - defaults to True - y_from_calc (bool): Get reference y labels from calculator, defaults to False - Returns: - List[AtomGraphData]: list of AtomGraphData - """ - serial = num_cores == 1 - inputs = [ - (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled) - for atoms in atoms_list - ] - - if not serial: - pool = mp.Pool(num_cores) - graph_list = pool.starmap( - atoms_to_graph, - tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'), - ) - pool.close() - pool.join() - else: - graph_list = [ - atoms_to_graph(*input_) - for input_ in tqdm(inputs, desc='graph_build (1)') - ] - - graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list] - - return graph_list - - -def _y_from_calc(atoms: ase.Atoms): - ret = { - 'energy': np.nan, - 'force': np.full((len(atoms), 3), np.nan), - 'stress': np.full((6,), np.nan), - } - - if atoms.calc is None: - return ret - - try: - ret['energy'] = atoms.get_potential_energy(force_consistent=True) - except NotImplementedError: - ret['energy'] = atoms.get_potential_energy() - - try: - ret['force'] = atoms.get_forces(apply_constraint=False) - except NotImplementedError: - pass - - try: - y_stress = -1 * 
atoms.get_stress() # it ensures correct shape - ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) - except RuntimeError: - pass - return ret - - -def _set_atoms_y( - atoms_list: List[ase.Atoms], - energy_key: Optional[str] = None, - force_key: Optional[str] = None, - stress_key: Optional[str] = None, -) -> List[ase.Atoms]: - """ - Define how SevenNet reads ASE.atoms object for its y label - If energy_key, force_key, or stress_key is given, the corresponding - label is obtained from .info dict of Atoms object. These values should - have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress, - respectively. (stress in Voigt notation) - - Args: - atoms_list (list[ase.Atoms]): target atoms to set y_labels - energy_key (str, optional): key to get energy. Defaults to None. - force_key (str, optional): key to get force. Defaults to None. - stress_key (str, optional): key to get stress. Defaults to None. - - Returns: - list[ase.Atoms]: list of ase.Atoms - - Raises: - RuntimeError: if ase atoms are somewhat imperfect - - Use free_energy: atoms.get_potential_energy(force_consistent=True) - If it is not available, use atoms.get_potential_energy() - If stress is available, initialize stress tensor - Ignore constraints like selective dynamics - """ - for atoms in atoms_list: - from_calc = _y_from_calc(atoms) - if energy_key is not None: - atoms.info['y_energy'] = atoms.info.pop(energy_key) - else: - atoms.info['y_energy'] = from_calc['energy'] - - if force_key is not None: - atoms.arrays['y_force'] = atoms.arrays.pop(force_key) - else: - atoms.arrays['y_force'] = from_calc['force'] - - if stress_key is not None: - y_stress = -1 * atoms.info.pop(stress_key) - atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) - else: - atoms.info['y_stress'] = from_calc['stress'] - - return atoms_list - - -def ase_reader( - filename: str, - energy_key: Optional[str] = None, - force_key: Optional[str] = None, - stress_key: Optional[str] = None, - index: str = ':', - **kwargs, -) -> List[ase.Atoms]: - """ - Wrapper of ase.io.read - """ - atoms_list = ase.io.read(filename, index=index, **kwargs) - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - return _set_atoms_y(atoms_list, energy_key, force_key, stress_key) - - -# Reader -def structure_list_reader(filename: str, format_outputs: Optional[str] = None): - """ - Read from structure_list using braceexpand and ASE - - Args: - fname : filename of structure_list - - Returns: - dictionary of lists of ASE structures. 
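A structure_list entry pairs a bracketed title with a brace-expandable file pattern and an optional ASE index expression (defaulting to ':'). A small sketch of just the expansion step; the path and slice are hypothetical:

```python
from braceexpand import braceexpand

line = '/data/liquid/OUTCAR_{1..3} ::10'  # hypothetical structure_list line
files_expr, index_expr = line.split()
print(list(braceexpand(files_expr)))
# ['/data/liquid/OUTCAR_1', '/data/liquid/OUTCAR_2', '/data/liquid/OUTCAR_3']
print(index_expr)  # ASE index expression: every 10th ionic step
```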
- key is title of training data (user-define) - """ - parsers = DefaultParsersContainer( - PositionsAndForces, Stress, Energy, Cell - ).make_parsers() - ocp = OutcarChunkParser(parsers=parsers) - - def parse_label(line): - line = line.strip() - if line.startswith('[') is False: - return False - elif line.endswith(']') is False: - raise ValueError('wrong structure_list title format') - return line[1:-1] - - def parse_fileline(line): - line = line.strip().split() - if len(line) == 1: - line.append(':') - elif len(line) != 2: - raise ValueError('wrong structure_list format') - return line[0], line[1] - - structure_list_file = open(filename, 'r') - lines = structure_list_file.readlines() - - raw_str_dict = {} - label = 'Default' - for line in lines: - if line.strip() == '': - continue - tmp_label = parse_label(line) - if tmp_label: - label = tmp_label - raw_str_dict[label] = [] - continue - elif label in raw_str_dict: - files_expr, index_expr = parse_fileline(line) - raw_str_dict[label].append((files_expr, index_expr)) - else: - raise ValueError('wrong structure_list format') - structure_list_file.close() - - structures_dict = {} - info_dct = {'data_from': 'user_OUTCAR'} - for title, file_lines in raw_str_dict.items(): - stct_lists = [] - for file_line in file_lines: - files_expr, index_expr = file_line - index = string2index(index_expr) - for expanded_filename in list(braceexpand(files_expr)): - f_stream = open(expanded_filename, 'r') - # generator of all outcar ionic steps - gen_all = outcarchunks(f_stream, ocp) - try: # TODO: index may not slice, it can be integer - it_atoms = islice(gen_all, index.start, index.stop, index.step) - except ValueError: - # TODO: support - # negative index - raise ValueError('Negative index is not supported yet') - - info_dct_f = { - **info_dct, - 'file': os.path.abspath(expanded_filename), - } - for idx, o in enumerate(it_atoms): - try: - it_atoms = islice( - gen_all, index.start, index.stop, index.step - ) - except ValueError: - # TODO: support - # negative index - raise ValueError('Negative index is not supported yet') - - info_dct_f = { - **info_dct, - 'file': os.path.abspath(expanded_filename), - } - for idx, o in enumerate(it_atoms): - try: - istep = index.start + idx * index.step # type: ignore - atoms = o.build() - atoms.info = {**info_dct_f, 'ionic_step': istep}.copy() - except TypeError: # it is not slice of ionic steps - atoms = o.build() - atoms.info = info_dct_f.copy() - stct_lists.append(atoms) - f_stream.close() - else: - stct_lists += ase.io.read( - expanded_filename, - index=index_expr, - parallel=False, - ) - structures_dict[title] = stct_lists - return {k: _set_atoms_y(v) for k, v in structures_dict.items()} - - -def dict_reader(data_dict: Dict): - data_dict_cp = copy.deepcopy(data_dict) - - ret = [] - file_list = data_dict_cp.pop('file_list', None) - if file_list is None: - raise KeyError('file_list is not found') - - data_weight_default = { - 'energy': 1.0, - 'force': 1.0, - 'stress': 1.0, - } - data_weight = data_weight_default.copy() - data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) - - for file_dct in file_list: - ftype = file_dct.pop('data_format', 'ase') - files = list(braceexpand(file_dct.pop('file'))) - if ftype == 'ase': - ret.extend(chain(*[ase_reader(f, **file_dct) for f in files])) - elif ftype == 'graph': - continue - else: - raise ValueError(f'{ftype} yet') - - for atoms in ret: - atoms.info.update(data_dict_cp) - atoms.info.update({KEY.DATA_WEIGHT: data_weight}) - return _set_atoms_y(ret) - - -def 
match_reader(reader_name: str, **kwargs): - reader = None - metadata = {} - if reader_name == 'structure_list': - reader = partial(structure_list_reader, **kwargs) - metadata.update({'origin': 'structure_list'}) - else: - reader = partial(ase_reader, **kwargs) - metadata.update({'origin': 'ase_reader'}) - return reader, metadata - - -def file_to_dataset( - file: str, - cutoff: float, - cores: int = 1, - reader: Callable = ase_reader, - label: Optional[str] = None, - transfer_info: bool = True, - use_weight: bool = False, - use_modality: bool = False, -): - """ - Deprecated - Read file by reader > get list of atoms or dict of atoms - """ - - # expect label: atoms_list dct or atoms or list of atoms - atoms = reader(file) - - if type(atoms) is list: - if label is None: - label = KEY.LABEL_NONE - atoms_dct = {label: atoms} - elif isinstance(atoms, ase.Atoms): - if label is None: - label = KEY.LABEL_NONE - atoms_dct = {label: [atoms]} - elif isinstance(atoms, dict): - atoms_dct = atoms - else: - raise TypeError('The return of reader is not list or dict') - - graph_dct = {} - for label, atoms_list in atoms_dct.items(): - graph_list = graph_build( - atoms_list=atoms_list, - cutoff=cutoff, - num_cores=cores, - transfer_info=transfer_info, - y_from_calc=False, - ) - - label_info = label.split(':') - for graph in graph_list: - graph[KEY.USER_LABEL] = label_info[0].strip() - if use_weight: - find_weight = False - for info in label_info[1:]: - if 'w=' in info.lower(): - weights = info.split('=')[1] - try: - if ',' in weights: - weight_list = list(map(float, weights.split(','))) - else: - weight_list = [float(weights)] * 3 - weight_dict = {} - for idx, loss_type in enumerate(LossType): - weight_dict[loss_type.value] = ( - weight_list[idx] if idx < len(weight_list) else 1 - ) - graph[KEY.DATA_WEIGHT] = weight_dict - find_weight = True - break - except: - raise ValueError( - 'Weight must be a real number, but' - f' {weights} is given for {label}' - ) - if not find_weight: - weight_dict = {} - for loss_type in LossType: - weight_dict[loss_type.value] = 1 - graph[KEY.DATA_WEIGHT] = weight_dict - if use_modality: - find_modality = False - for info in label_info[1:]: - if 'm=' in info.lower(): - graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip() - find_modality = True - break - if not find_modality: - raise ValueError(f'Modality not given for {label}') - - graph_dct[label_info[0].strip()] = graph_list - db = AtomGraphDataset(graph_dct, cutoff) - return db +import copy +import os.path +from functools import partial +from itertools import chain, islice +from typing import Callable, Dict, List, Optional + +import ase +import ase.io +import numpy as np +import torch.multiprocessing as mp +from ase.io.vasp_parsers.vasp_outcar_parsers import ( + Cell, + DefaultParsersContainer, + Energy, + OutcarChunkParser, + PositionsAndForces, + Stress, + outcarchunks, +) +from ase.neighborlist import primitive_neighbor_list +from ase.utils import string2index +from braceexpand import braceexpand +from tqdm import tqdm + +import sevenn._keys as KEY +from sevenn._const import LossType +from sevenn.atom_graph_data import AtomGraphData + +from .dataset import AtomGraphDataset + + +def _graph_build_matscipy(cutoff: float, pbc, cell, pos): + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(pos)) + 1 + + # Extend cell in non-periodic directions + # For models with more than 5 layers, + # the multiplicative constant needs to be increased. 
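file_to_dataset above parses optional ':w=' (loss weights) and ':m=' (modality) suffixes out of a dataset label. A standalone sketch of that convention; the helper is hypothetical, and the real code maps the three weights onto sevenn's LossType order:

```python
def parse_dataset_label(label: str):
    parts = label.split(':')
    out = {'title': parts[0].strip(), 'weights': None, 'modality': None}
    for part in parts[1:]:
        if part.lower().startswith('w='):
            ws = part.split('=')[1]
            out['weights'] = (
                [float(w) for w in ws.split(',')] if ',' in ws else [float(ws)] * 3
            )
        elif part.lower().startswith('m='):
            out['modality'] = part.split('=')[1].strip()
    return out

print(parse_dataset_label('bulk_data:w=1.0,0.5,0.01:m=pbe'))
# {'title': 'bulk_data', 'weights': [1.0, 0.5, 0.01], 'modality': 'pbe'}
```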
+ if not pbc_x: + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + # it does not have self-interaction + edge_src, edge_dst, edge_vec, shifts = neighbour_list( + quantities='ijDS', + pbc=pbc, + cell=cell, + positions=pos, + cutoff=cutoff, + ) + # dtype issue + edge_src = edge_src.astype(np.int64) + edge_dst = edge_dst.astype(np.int64) + + return edge_src, edge_dst, edge_vec, shifts + + +def _graph_build_ase(cutoff: float, pbc, cell, pos): + # building neighbor list + edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list( + 'ijDS', pbc, cell, pos, cutoff, self_interaction=True + ) + + is_zero_idx = np.all(edge_vec == 0, axis=1) + is_self_idx = edge_src == edge_dst + non_trivials = ~(is_zero_idx & is_self_idx) + shifts = np.array(shifts[non_trivials]) + + edge_vec = edge_vec[non_trivials] + edge_src = edge_src[non_trivials] + edge_dst = edge_dst[non_trivials] + + return edge_src, edge_dst, edge_vec, shifts + + +_graph_build_f = _graph_build_ase +try: + from matscipy.neighbours import neighbour_list + + _graph_build_f = _graph_build_matscipy +except ImportError: + pass + + +def _correct_scalar(v): + if isinstance(v, np.ndarray): + v = v.squeeze() + assert v.ndim == 0, f'given {v} is not a scalar' + return v + elif isinstance(v, (int, float, np.integer, np.floating)): + return np.array(v) + else: + assert False, f'{type(v)} is not expected' + + +def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float): + pos = atoms.get_positions() + cell = np.array(atoms.get_cell()) + pbc = atoms.get_pbc() + + edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) + + edge_idx = np.array([edge_src, edge_dst]) + + atomic_numbers = atoms.get_atomic_numbers() + + cell = np.array(cell) + vol = _correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx, + KEY.EDGE_VEC: edge_vec, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), + } + data[KEY.INFO] = {} + return data + + +def atoms_to_graph( + atoms: ase.Atoms, + cutoff: float, + transfer_info: bool = True, + y_from_calc: bool = False, + allow_unlabeled: bool = False, +): + """ + From ase atoms, return AtomGraphData as graph based on cutoff radius + Except for energy, force and stress labels must be numpy array type + as other cases are not tested. + Returns 'np.nan' with consistent shape for unlabeled data + (ex. stress of non-pbc system) + + Args: + atoms (Atoms): ase atoms + cutoff (float): cutoff radius + transfer_info (bool): if True, transfer ".info" from atoms to graph, + defaults to True + y_from_calc: if True, get ref values from calculator, defaults to False + Returns: + numpy dict that can be used to initialize AtomGraphData + by AtomGraphData(**atoms_to_graph(atoms, cutoff)) + , for scalar, its shape is (), and types are np.ndarray + Requires grad is handled by 'dataset' not here. 
+ """ + if not y_from_calc: + y_energy = atoms.info['y_energy'] + y_force = atoms.arrays['y_force'] + y_stress = atoms.info.get('y_stress', np.full((6,), np.nan)) + if y_stress.shape == (3, 3): + y_stress = np.array( + [ + y_stress[0][0], + y_stress[1][1], + y_stress[2][2], + y_stress[0][1], + y_stress[1][2], + y_stress[2][0], + ] + ) + else: + y_stress = y_stress.squeeze() + else: + from_calc = _y_from_calc(atoms) + y_energy = from_calc['energy'] + y_force = from_calc['force'] + y_stress = from_calc['stress'] + assert y_stress.shape == (6,), 'If you see this, please raise a issue' + + if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()): + raise ValueError('Unlabeled E or F found, set allow_unlabeled True') + + pos = atoms.get_positions() + cell = np.array(atoms.get_cell()) + pbc = atoms.get_pbc() + + edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) + + edge_idx = np.array([edge_src, edge_dst]) + atomic_numbers = atoms.get_atomic_numbers() + + cell = np.array(cell) + vol = _correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx, + KEY.EDGE_VEC: edge_vec, + KEY.ENERGY: _correct_scalar(y_energy), + KEY.FORCE: y_force, + KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6) + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), + KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), + } + + if transfer_info and atoms.info is not None: + info = copy.deepcopy(atoms.info) + # save only metadata + info.pop('y_energy', None) + info.pop('y_force', None) + info.pop('y_stress', None) + data[KEY.INFO] = info + else: + data[KEY.INFO] = {} + + return data + + +def graph_build( + atoms_list: List, + cutoff: float, + num_cores: int = 1, + transfer_info: bool = True, + y_from_calc: bool = False, + allow_unlabeled: bool = False, +) -> List[AtomGraphData]: + """ + parallel version of graph_build + build graph from atoms_list and return list of AtomGraphData + Args: + atoms_list (List): list of ASE atoms + cutoff (float): cutoff radius of graph + num_cores (int): number of cores to use + transfer_info (bool): if True, copy info from atoms to graph, + defaults to True + y_from_calc (bool): Get reference y labels from calculator, defaults to False + Returns: + List[AtomGraphData]: list of AtomGraphData + """ + serial = num_cores == 1 + inputs = [ + (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled) + for atoms in atoms_list + ] + + if not serial: + pool = mp.Pool(num_cores) + graph_list = pool.starmap( + atoms_to_graph, + tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'), + ) + pool.close() + pool.join() + else: + graph_list = [ + atoms_to_graph(*input_) + for input_ in tqdm(inputs, desc='graph_build (1)') + ] + + graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list] + + return graph_list + + +def _y_from_calc(atoms: ase.Atoms): + ret = { + 'energy': np.nan, + 'force': np.full((len(atoms), 3), np.nan), + 'stress': np.full((6,), np.nan), + } + + if atoms.calc is None: + return ret + + try: + ret['energy'] = atoms.get_potential_energy(force_consistent=True) + except NotImplementedError: + ret['energy'] = atoms.get_potential_energy() + + try: + ret['force'] = atoms.get_forces(apply_constraint=False) + except NotImplementedError: + pass + + try: + y_stress = -1 * 
atoms.get_stress() # it ensures correct shape + ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) + except RuntimeError: + pass + return ret + + +def _set_atoms_y( + atoms_list: List[ase.Atoms], + energy_key: Optional[str] = None, + force_key: Optional[str] = None, + stress_key: Optional[str] = None, +) -> List[ase.Atoms]: + """ + Define how SevenNet reads ASE.atoms object for its y label + If energy_key, force_key, or stress_key is given, the corresponding + label is obtained from .info dict of Atoms object. These values should + have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress, + respectively. (stress in Voigt notation) + + Args: + atoms_list (list[ase.Atoms]): target atoms to set y_labels + energy_key (str, optional): key to get energy. Defaults to None. + force_key (str, optional): key to get force. Defaults to None. + stress_key (str, optional): key to get stress. Defaults to None. + + Returns: + list[ase.Atoms]: list of ase.Atoms + + Raises: + RuntimeError: if ase atoms are somewhat imperfect + + Use free_energy: atoms.get_potential_energy(force_consistent=True) + If it is not available, use atoms.get_potential_energy() + If stress is available, initialize stress tensor + Ignore constraints like selective dynamics + """ + for atoms in atoms_list: + from_calc = _y_from_calc(atoms) + if energy_key is not None: + atoms.info['y_energy'] = atoms.info.pop(energy_key) + else: + atoms.info['y_energy'] = from_calc['energy'] + + if force_key is not None: + atoms.arrays['y_force'] = atoms.arrays.pop(force_key) + else: + atoms.arrays['y_force'] = from_calc['force'] + + if stress_key is not None: + y_stress = -1 * atoms.info.pop(stress_key) + atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) + else: + atoms.info['y_stress'] = from_calc['stress'] + + return atoms_list + + +def ase_reader( + filename: str, + energy_key: Optional[str] = None, + force_key: Optional[str] = None, + stress_key: Optional[str] = None, + index: str = ':', + **kwargs, +) -> List[ase.Atoms]: + """ + Wrapper of ase.io.read + """ + atoms_list = ase.io.read(filename, index=index, **kwargs) + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + return _set_atoms_y(atoms_list, energy_key, force_key, stress_key) + + +# Reader +def structure_list_reader(filename: str, format_outputs: Optional[str] = None): + """ + Read from structure_list using braceexpand and ASE + + Args: + fname : filename of structure_list + + Returns: + dictionary of lists of ASE structures. 
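A usage sketch for ase_reader above; the file name and the info/arrays keys are placeholders for whatever an ASE-readable file actually carries. Note that _set_atoms_y flips the sign of the stress and reorders its Voigt components internally:

```python
from sevenn.train.dataload import ase_reader

atoms_list = ase_reader(
    'train.extxyz',           # placeholder: any ASE-readable file
    energy_key='dft_energy',  # eV, taken from atoms.info
    force_key='dft_forces',   # eV/Angstrom, taken from atoms.arrays
    stress_key='dft_stress',  # eV/Angstrom^3, Voigt notation
)
# Each Atoms now carries y_energy / y_force / y_stress reference labels.
```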
+    key is title of training data (user-defined)
+    """
+    parsers = DefaultParsersContainer(
+        PositionsAndForces, Stress, Energy, Cell
+    ).make_parsers()
+    ocp = OutcarChunkParser(parsers=parsers)
+
+    def parse_label(line):
+        line = line.strip()
+        if line.startswith('[') is False:
+            return False
+        elif line.endswith(']') is False:
+            raise ValueError('wrong structure_list title format')
+        return line[1:-1]
+
+    def parse_fileline(line):
+        line = line.strip().split()
+        if len(line) == 1:
+            line.append(':')
+        elif len(line) != 2:
+            raise ValueError('wrong structure_list format')
+        return line[0], line[1]
+
+    structure_list_file = open(filename, 'r')
+    lines = structure_list_file.readlines()
+
+    raw_str_dict = {}
+    label = 'Default'
+    for line in lines:
+        if line.strip() == '':
+            continue
+        tmp_label = parse_label(line)
+        if tmp_label:
+            label = tmp_label
+            raw_str_dict[label] = []
+            continue
+        elif label in raw_str_dict:
+            files_expr, index_expr = parse_fileline(line)
+            raw_str_dict[label].append((files_expr, index_expr))
+        else:
+            raise ValueError('wrong structure_list format')
+    structure_list_file.close()
+
+    structures_dict = {}
+    info_dct = {'data_from': 'user_OUTCAR'}
+    for title, file_lines in raw_str_dict.items():
+        stct_lists = []
+        for file_line in file_lines:
+            files_expr, index_expr = file_line
+            index = string2index(index_expr)
+            for expanded_filename in list(braceexpand(files_expr)):
+                if format_outputs == 'vasp-out':
+                    f_stream = open(expanded_filename, 'r')
+                    # generator of all outcar ionic steps
+                    gen_all = outcarchunks(f_stream, ocp)
+                    try:  # TODO: index may not slice, it can be integer
+                        it_atoms = islice(
+                            gen_all, index.start, index.stop, index.step
+                        )
+                    except ValueError:
+                        # TODO: support
+                        # negative index
+                        raise ValueError('Negative index is not supported yet')
+
+                    info_dct_f = {
+                        **info_dct,
+                        'file': os.path.abspath(expanded_filename),
+                    }
+                    for idx, o in enumerate(it_atoms):
+                        try:
+                            istep = index.start + idx * index.step  # type: ignore
+                            atoms = o.build()
+                            atoms.info = {**info_dct_f, 'ionic_step': istep}.copy()
+                        except TypeError:  # it is not slice of ionic steps
+                            atoms = o.build()
+                            atoms.info = info_dct_f.copy()
+                        stct_lists.append(atoms)
+                    f_stream.close()
+                else:
+                    stct_lists += ase.io.read(
+                        expanded_filename,
+                        index=index_expr,
+                        parallel=False,
+                    )
+            structures_dict[title] = stct_lists
+    return {k: _set_atoms_y(v) for k, v in structures_dict.items()}
+
+
+def dict_reader(data_dict: Dict):
+    data_dict_cp = copy.deepcopy(data_dict)
+
+    ret = []
+    file_list = data_dict_cp.pop('file_list', None)
+    if file_list is None:
+        raise KeyError('file_list is not found')
+
+    data_weight_default = {
+        'energy': 1.0,
+        'force': 1.0,
+        'stress': 1.0,
+    }
+    data_weight = data_weight_default.copy()
+    data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))
+
+    for file_dct in file_list:
+        ftype = file_dct.pop('data_format', 'ase')
+        files = list(braceexpand(file_dct.pop('file')))
+        if ftype == 'ase':
+            ret.extend(chain(*[ase_reader(f, **file_dct) for f in files]))
+        elif ftype == 'graph':
+            continue
+        else:
+            raise ValueError(f'{ftype} is not a supported data_format yet')
+
+    for atoms in ret:
+        atoms.info.update(data_dict_cp)
+        atoms.info.update({KEY.DATA_WEIGHT: data_weight})
+    return _set_atoms_y(ret)
+
+
+def 
match_reader(reader_name: str, **kwargs): + reader = None + metadata = {} + if reader_name == 'structure_list': + reader = partial(structure_list_reader, **kwargs) + metadata.update({'origin': 'structure_list'}) + else: + reader = partial(ase_reader, **kwargs) + metadata.update({'origin': 'ase_reader'}) + return reader, metadata + + +def file_to_dataset( + file: str, + cutoff: float, + cores: int = 1, + reader: Callable = ase_reader, + label: Optional[str] = None, + transfer_info: bool = True, + use_weight: bool = False, + use_modality: bool = False, +): + """ + Deprecated + Read file by reader > get list of atoms or dict of atoms + """ + + # expect label: atoms_list dct or atoms or list of atoms + atoms = reader(file) + + if type(atoms) is list: + if label is None: + label = KEY.LABEL_NONE + atoms_dct = {label: atoms} + elif isinstance(atoms, ase.Atoms): + if label is None: + label = KEY.LABEL_NONE + atoms_dct = {label: [atoms]} + elif isinstance(atoms, dict): + atoms_dct = atoms + else: + raise TypeError('The return of reader is not list or dict') + + graph_dct = {} + for label, atoms_list in atoms_dct.items(): + graph_list = graph_build( + atoms_list=atoms_list, + cutoff=cutoff, + num_cores=cores, + transfer_info=transfer_info, + y_from_calc=False, + ) + + label_info = label.split(':') + for graph in graph_list: + graph[KEY.USER_LABEL] = label_info[0].strip() + if use_weight: + find_weight = False + for info in label_info[1:]: + if 'w=' in info.lower(): + weights = info.split('=')[1] + try: + if ',' in weights: + weight_list = list(map(float, weights.split(','))) + else: + weight_list = [float(weights)] * 3 + weight_dict = {} + for idx, loss_type in enumerate(LossType): + weight_dict[loss_type.value] = ( + weight_list[idx] if idx < len(weight_list) else 1 + ) + graph[KEY.DATA_WEIGHT] = weight_dict + find_weight = True + break + except: + raise ValueError( + 'Weight must be a real number, but' + f' {weights} is given for {label}' + ) + if not find_weight: + weight_dict = {} + for loss_type in LossType: + weight_dict[loss_type.value] = 1 + graph[KEY.DATA_WEIGHT] = weight_dict + if use_modality: + find_modality = False + for info in label_info[1:]: + if 'm=' in info.lower(): + graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip() + find_modality = True + break + if not find_modality: + raise ValueError(f'Modality not given for {label}') + + graph_dct[label_info[0].strip()] = graph_list + db = AtomGraphDataset(graph_dct, cutoff) + return db diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py index 3c31c55530ae4a37ab0bfd3cb699fce6e38ba3af..ccf04df9da9733a0fccfd62f98f3d6c07a18f431 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py @@ -1,496 +1,496 @@ -import itertools -import random -from collections import Counter -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from ase.data import chemical_symbols -from sklearn.linear_model import Ridge - -import sevenn._keys as KEY -import sevenn.util as util - - -class AtomGraphDataset: - """ - Deprecated - - class representing dataset of AtomGraphData - the dataset is handled as dict, {label: data} - if given data is List, it stores data as {KEY_DEFAULT: data} - - cutoff is for metadata of the graphs not used for some calc - Every data expected to have one unique cutoff - No validity or check of the condition is done inside the object - - attribute: - dataset 
(Dict[str, List]): key is data label(str), value is list of data - user_labels (List[str]): list of user labels same as dataset.keys() - meta (Dict, Optional): metadata of dataset - for now, metadata 'might' have following keys: - KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict) - """ - - DATA_KEY_X = ( - KEY.NODE_FEATURE - ) # atomic_number > one_hot_idx > one_hot_vector - DATA_KEY_ENERGY = KEY.ENERGY - DATA_KEY_FORCE = KEY.FORCE - KEY_DEFAULT = KEY.LABEL_NONE - - def __init__( - self, - dataset: Union[Dict[str, List], List], - cutoff: float, - metadata: Optional[Dict] = None, - x_is_one_hot_idx: bool = False, - ): - """ - Default constructor of AtomGraphDataset - Args: - dataset (Union[Dict[str, List], List]: dataset as dict or pure list - metadata (Dict, Optional): metadata of data - cutoff (float): cutoff radius of graphs inside the dataset - x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z' - - 'x' (node feature) of dataset can have 3 states, atomic_numbers, - one_hot_idx, or one_hot_vector. - - atomic_numbers is general but cannot directly used for input - one_hot_idx is can be input of the model but requires 'type_map' - """ - self.cutoff = cutoff - self.x_is_one_hot_idx = x_is_one_hot_idx - if metadata is None: - metadata = {KEY.CUTOFF: cutoff} - self.meta = metadata - if type(dataset) is list: - self.dataset = {self.KEY_DEFAULT: dataset} - else: - self.dataset = dataset - self.user_labels = list(self.dataset.keys()) - # group_by_key here? or not? - - def rewrite_labels_to_data(self): - """ - Based on self.dataset dict's keys - write data[KEY.USER_LABEL] to correspond to dict's keys - Most of times, it is already correctly written - But required to rewrite if someone rearrange dataset by their own way - """ - for label, data_list in self.dataset.items(): - for data in data_list: - data[KEY.USER_LABEL] = label - - def group_by_key(self, data_key: str = KEY.USER_LABEL): - """ - group dataset list by given key and save it as dict - and change in-place - Args: - data_key (str): data key to group by - - original use is USER_LABEL, but it can be used for other keys - if someone established it from data[KEY.INFO] - """ - data_list = self.to_list() - self.dataset = {} - for datum in data_list: - key = datum[data_key] - if key not in self.dataset: - self.dataset[key] = [] - self.dataset[key].append(datum) - self.user_labels = list(self.dataset.keys()) - - def separate_info(self, data_key: str = KEY.INFO): - """ - Separate info from data and save it as list of dict - to make it compatible with torch_geometric and later training - """ - data_list = self.to_list() - info_list = [] - for datum in data_list: - if data_key in datum is False: - continue - info_list.append(datum[data_key]) - del datum[data_key] # It does change the self.dataset - datum[data_key] = len(info_list) - 1 - self.info_list = info_list - - return (data_list, info_list) - - def get_species(self): - """ - You can also use get_natoms and extract keys from there instead of this - (And it is more efficient) - get chemical species of dataset - return list of SORTED chemical species (as str) - """ - if hasattr(self, 'type_map'): - natoms = self.get_natoms(self.type_map) - else: - natoms = self.get_natoms() - species = set() - for natom_dct in natoms.values(): - species.update(natom_dct.keys()) - species = sorted(list(species)) - return species - - def get_modalities(self): - modalities = set() - for data_list in self.dataset.values(): - datum = data_list[0].to_dict() - if KEY.DATA_MODALITY in datum.keys(): - 
modalities.add(datum[KEY.DATA_MODALITY]) - else: - return [] - return list(modalities) - - def write_modal_attr( - self, modal_type_mapper: dict, write_modal_type: bool = False - ): - num_modalities = len(modal_type_mapper) - for data_list in self.dataset.values(): - for data in data_list: - tmp_tensor = torch.zeros(num_modalities) - if data[KEY.DATA_MODALITY] != 'common': - modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]] - tmp_tensor[modal_idx] = 1.0 - if write_modal_type: - data[KEY.MODAL_TYPE] = modal_idx - data[KEY.MODAL_ATTR] = tmp_tensor - - def get_dict_sort_by_modality(self): - dict_sort_by_modality = {} - for data_list in self.dataset.values(): - try: - modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY] - except: # Dataset is not modal - raise ValueError('This dataset has no modality.') - - if modal_key not in dict_sort_by_modality.keys(): - dict_sort_by_modality[modal_key] = [] - dict_sort_by_modality[modal_key].extend(data_list) - - return dict_sort_by_modality - - def len(self): - if ( - len(self.dataset.keys()) == 1 - and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT - ): - return len(self.dataset[AtomGraphDataset.KEY_DEFAULT]) - else: - return {k: len(v) for k, v in self.dataset.items()} - - def get(self, idx: int, key: Optional[str] = None): - if key is None: - key = self.KEY_DEFAULT - return self.dataset[key][idx] - - def items(self): - return self.dataset.items() - - def to_dict(self): - dct_dataset = {} - for label, data_list in self.dataset.items(): - dct_dataset[label] = [datum.to_dict() for datum in data_list] - self.dataset = dct_dataset - return self - - def x_to_one_hot_idx(self, type_map: Dict[int, int]): - """ - type_map is dict of {atomic_number: one_hot_idx} - after this process, the dataset has dependency on type_map - or chemical species user want to consider - """ - assert self.x_is_one_hot_idx is False - for data_list in self.dataset.values(): - for datum in data_list: - datum[self.DATA_KEY_X] = torch.LongTensor( - [type_map[z.item()] for z in datum[self.DATA_KEY_X]] - ) - self.type_map = type_map - self.x_is_one_hot_idx = True - - def toggle_requires_grad_of_data( - self, key: str, requires_grad_value: bool - ): - """ - set requires_grad of specific key of data(pos, edge_vec, ...) 
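A usage sketch for this deprecated class, assuming graph_list holds AtomGraphData instances built with the same cutoff:

```python
from sevenn.train.dataset import AtomGraphDataset

db = AtomGraphDataset(graph_list, cutoff=5.0)  # stored under the default label
db.rewrite_labels_to_data()        # stamp KEY.USER_LABEL so regrouping works
train, valid, _ = db.divide_dataset(0.1)  # 0.9 : 0.1; test stays empty by default
```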
- """ - for data_list in self.dataset.values(): - for datum in data_list: - datum[key].requires_grad_(requires_grad_value) - - def divide_dataset( - self, - ratio: float, - constant_ratio_btw_labels: bool = True, - ignore_test: bool = True - ): - """ - divide dataset into 1-2*ratio : ratio : ratio - return divided AtomGraphDataset - returned value lost its dict key and became {KEY_DEFAULT: datalist} - but KEY.USER_LABEL of each data is preserved - """ - - def divide(ratio: float, data_list: List, ignore_test=True): - if ratio > 0.5: - raise ValueError('Ratio must not exceed 0.5') - data_len = len(data_list) - random.shuffle(data_list) - n_validation = int(data_len * ratio) - if n_validation == 0: - raise ValueError( - '# of validation set is 0, increase your dataset' - ) - - if ignore_test: - test_list = [] - n_train = data_len - n_validation - train_list = data_list[0:n_train] - valid_list = data_list[n_train:] - else: - n_train = data_len - 2 * n_validation - train_list = data_list[0:n_train] - valid_list = data_list[n_train : n_train + n_validation] - test_list = data_list[n_train + n_validation : data_len] - return train_list, valid_list, test_list - - lists = ([], [], []) # train, valid, test - if constant_ratio_btw_labels: - for data_list in self.dataset.values(): - for store, divided in zip(lists, divide(ratio, data_list)): - store.extend(divided) - else: - lists = divide(ratio, self.to_list()) - - dbs = tuple( - AtomGraphDataset(data, self.cutoff, self.meta) for data in lists - ) - for db in dbs: - db.group_by_key() - return dbs - - def to_list(self): - return list(itertools.chain(*self.dataset.values())) - - def get_natoms(self, type_map: Optional[Dict[int, int]] = None): - """ - if x_is_one_hot_idx, type_map is required - type_map: Z->one_hot_index(node_feature) - return Dict{label: {symbol, natom}]} - """ - assert not (self.x_is_one_hot_idx is True and type_map is None) - natoms = {} - for label, data in self.dataset.items(): - natoms[label] = Counter() - for datum in data: - if self.x_is_one_hot_idx and type_map is not None: - Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map) - else: - Zs = [ - chemical_symbols[z] - for z in datum[self.DATA_KEY_X].tolist() - ] - cnt = Counter(Zs) - natoms[label] += cnt - natoms[label] = dict(natoms[label]) - return natoms - - def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS): - """ - return per_atom mean of given data key - """ - eng_list = torch.Tensor( - [x[key] / x[key_num_atoms] for x in self.to_list()] - ) - return float(torch.mean(eng_list)) - - def get_per_atom_energy_mean(self): - """ - alias for get_per_atom_mean(KEY.ENERGY) - """ - return self.get_per_atom_mean(self.DATA_KEY_ENERGY) - - def get_species_ref_energy_by_linear_comb(self, num_chem_species: int): - """ - Total energy as y, composition as c_i, - solve linear regression of y = c_i*X - sklearn LinearRegression as solver - - x should be one-hot-indexed - give num_chem_species if possible - """ - assert self.x_is_one_hot_idx is True - data_list = self.to_list() - - c = torch.zeros((len(data_list), num_chem_species)) - for idx, datum in enumerate(data_list): - c[idx] = torch.bincount( - datum[self.DATA_KEY_X], minlength=num_chem_species - ) - y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list]) - c = c.numpy() - y = y.numpy() - - # tweak to fine tune training from many-element to small element - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - full_coeff = np.zeros(num_chem_species) - coef_reduced = ( - 
Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - ) - full_coeff[~zero_indices] = coef_reduced - - return full_coeff - - def get_force_rms(self): - force_list = [] - for x in self.to_list(): - force_list.extend( - x[self.DATA_KEY_FORCE] - .reshape( - -1, - ) - .tolist() - ) - force_list = torch.Tensor(force_list) - return float(torch.sqrt(torch.mean(torch.pow(force_list, 2)))) - - def get_species_wise_force_rms(self, num_chem_species: int): - """ - Return force rms for each species - Averaged by each components (x, y, z) - """ - assert self.x_is_one_hot_idx is True - data_list = self.to_list() - - atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list]) - force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list]) - - index = atomx.repeat_interleave(3, 0).reshape(force.shape) - rms = torch.zeros( - (num_chem_species, 3), - dtype=force.dtype, - device=force.device - ) - rms.scatter_reduce_( - 0, index, force.square(), - reduce='mean', include_self=False - ) - return torch.sqrt(rms.mean(dim=1)) - - def get_avg_num_neigh(self): - n_neigh = [] - for _, data_list in self.dataset.items(): - for data in data_list: - n_neigh.extend( - np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1] - ) - - avg_num_neigh = np.average(n_neigh) - return avg_num_neigh - - def get_statistics(self, key: str): - """ - return dict of statistics of given key (energy, force, stress) - key of dict is its label and _total for total statistics - value of dict is dict of statistics (mean, std, median, max, min) - """ - - def _get_statistic_dict(tensor_list): - data_list = torch.cat( - [ - tensor.reshape( - -1, - ) - for tensor in tensor_list - ] - ) - data_list = data_list[~torch.isnan(data_list)] - return { - 'mean': float(torch.mean(data_list)), - 'std': float(torch.std(data_list)), - 'median': float(torch.median(data_list)), - 'max': ( - torch.nan - if data_list.numel() == 0 - else float(torch.max(data_list)) - ), - 'min': ( - torch.nan - if data_list.numel() == 0 - else float(torch.min(data_list)) - ), - } - - res = {} - for label, values in self.dataset.items(): - # flatten list of torch.Tensor (values) - tensor_list = [x[key] for x in values] - res[label] = _get_statistic_dict(tensor_list) - tensor_list = [x[key] for x in self.to_list()] - res['Total'] = _get_statistic_dict(tensor_list) - return res - - def augment(self, dataset, validator: Optional[Callable] = None): - """check meta compatibility here - dataset(AtomGraphDataset): data to augment - validator(Callable, Optional): function(self, dataset) -> bool - - if validator is None, by default it checks - whether cutoff & chemical_species are same before augment - - check consistent data type, float, double, long integer etc - """ - - def default_validator(db1, db2): - cut_consis = db1.cutoff == db2.cutoff - # compare unordered lists - x_is_not_onehot = (not db1.x_is_one_hot_idx) and ( - not db2.x_is_one_hot_idx - ) - return cut_consis and x_is_not_onehot - - if validator is None: - validator = default_validator - if not validator(self, dataset): - raise ValueError('given datasets are not compatible check cutoffs') - for key, val in dataset.items(): - if key in self.dataset: - self.dataset[key].extend(val) - else: - self.dataset.update({key: val}) - self.user_labels = list(self.dataset.keys()) - - def unify_dtypes( - self, - float_dtype: torch.dtype = torch.float32, - int_dtype: torch.dtype = torch.int64 - ): - data_list = self.to_list() - for datum in data_list: - for k, v in list(datum.items()): - datum[k] = util.dtype_correct(v, 
float_dtype, int_dtype) - - def delete_data_key(self, key: str): - for data in self.to_list(): - del data[key] - - # TODO: this by_label is not straightforward - def save(self, path: str, by_label: bool = False): - if by_label: - for label, data in self.dataset.items(): - torch.save( - AtomGraphDataset( - {label: data}, self.cutoff, metadata=self.meta - ), - f'{path}/{label}.sevenn_data', - ) - else: - if path.endswith('.sevenn_data') is False: - path += '.sevenn_data' - torch.save(self, path) +import itertools +import random +from collections import Counter +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from ase.data import chemical_symbols +from sklearn.linear_model import Ridge + +import sevenn._keys as KEY +import sevenn.util as util + + +class AtomGraphDataset: + """ + Deprecated + + class representing dataset of AtomGraphData + the dataset is handled as dict, {label: data} + if given data is List, it stores data as {KEY_DEFAULT: data} + + cutoff is for metadata of the graphs not used for some calc + Every data expected to have one unique cutoff + No validity or check of the condition is done inside the object + + attribute: + dataset (Dict[str, List]): key is data label(str), value is list of data + user_labels (List[str]): list of user labels same as dataset.keys() + meta (Dict, Optional): metadata of dataset + for now, metadata 'might' have following keys: + KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict) + """ + + DATA_KEY_X = ( + KEY.NODE_FEATURE + ) # atomic_number > one_hot_idx > one_hot_vector + DATA_KEY_ENERGY = KEY.ENERGY + DATA_KEY_FORCE = KEY.FORCE + KEY_DEFAULT = KEY.LABEL_NONE + + def __init__( + self, + dataset: Union[Dict[str, List], List], + cutoff: float, + metadata: Optional[Dict] = None, + x_is_one_hot_idx: bool = False, + ): + """ + Default constructor of AtomGraphDataset + Args: + dataset (Union[Dict[str, List], List]: dataset as dict or pure list + metadata (Dict, Optional): metadata of data + cutoff (float): cutoff radius of graphs inside the dataset + x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z' + + 'x' (node feature) of dataset can have 3 states, atomic_numbers, + one_hot_idx, or one_hot_vector. + + atomic_numbers is general but cannot directly used for input + one_hot_idx is can be input of the model but requires 'type_map' + """ + self.cutoff = cutoff + self.x_is_one_hot_idx = x_is_one_hot_idx + if metadata is None: + metadata = {KEY.CUTOFF: cutoff} + self.meta = metadata + if type(dataset) is list: + self.dataset = {self.KEY_DEFAULT: dataset} + else: + self.dataset = dataset + self.user_labels = list(self.dataset.keys()) + # group_by_key here? or not? 
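The constructor docstring above names three possible node-feature states. A torch-only sketch of the Z-to-one-hot-index conversion that x_to_one_hot_idx below applies to every graph; the type_map here is illustrative:

```python
import torch

zs = torch.tensor([8, 1, 1])  # H2O as atomic numbers (the 'Z' state)
type_map = {1: 0, 8: 1}       # atomic_number -> one_hot_idx, per chosen species
one_hot_idx = torch.LongTensor([type_map[int(z)] for z in zs])
print(one_hot_idx)            # tensor([1, 0, 0])
```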
+ + def rewrite_labels_to_data(self): + """ + Based on self.dataset dict's keys + write data[KEY.USER_LABEL] to correspond to dict's keys + Most of times, it is already correctly written + But required to rewrite if someone rearrange dataset by their own way + """ + for label, data_list in self.dataset.items(): + for data in data_list: + data[KEY.USER_LABEL] = label + + def group_by_key(self, data_key: str = KEY.USER_LABEL): + """ + group dataset list by given key and save it as dict + and change in-place + Args: + data_key (str): data key to group by + + original use is USER_LABEL, but it can be used for other keys + if someone established it from data[KEY.INFO] + """ + data_list = self.to_list() + self.dataset = {} + for datum in data_list: + key = datum[data_key] + if key not in self.dataset: + self.dataset[key] = [] + self.dataset[key].append(datum) + self.user_labels = list(self.dataset.keys()) + + def separate_info(self, data_key: str = KEY.INFO): + """ + Separate info from data and save it as list of dict + to make it compatible with torch_geometric and later training + """ + data_list = self.to_list() + info_list = [] + for datum in data_list: + if data_key in datum is False: + continue + info_list.append(datum[data_key]) + del datum[data_key] # It does change the self.dataset + datum[data_key] = len(info_list) - 1 + self.info_list = info_list + + return (data_list, info_list) + + def get_species(self): + """ + You can also use get_natoms and extract keys from there instead of this + (And it is more efficient) + get chemical species of dataset + return list of SORTED chemical species (as str) + """ + if hasattr(self, 'type_map'): + natoms = self.get_natoms(self.type_map) + else: + natoms = self.get_natoms() + species = set() + for natom_dct in natoms.values(): + species.update(natom_dct.keys()) + species = sorted(list(species)) + return species + + def get_modalities(self): + modalities = set() + for data_list in self.dataset.values(): + datum = data_list[0].to_dict() + if KEY.DATA_MODALITY in datum.keys(): + modalities.add(datum[KEY.DATA_MODALITY]) + else: + return [] + return list(modalities) + + def write_modal_attr( + self, modal_type_mapper: dict, write_modal_type: bool = False + ): + num_modalities = len(modal_type_mapper) + for data_list in self.dataset.values(): + for data in data_list: + tmp_tensor = torch.zeros(num_modalities) + if data[KEY.DATA_MODALITY] != 'common': + modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]] + tmp_tensor[modal_idx] = 1.0 + if write_modal_type: + data[KEY.MODAL_TYPE] = modal_idx + data[KEY.MODAL_ATTR] = tmp_tensor + + def get_dict_sort_by_modality(self): + dict_sort_by_modality = {} + for data_list in self.dataset.values(): + try: + modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY] + except: # Dataset is not modal + raise ValueError('This dataset has no modality.') + + if modal_key not in dict_sort_by_modality.keys(): + dict_sort_by_modality[modal_key] = [] + dict_sort_by_modality[modal_key].extend(data_list) + + return dict_sort_by_modality + + def len(self): + if ( + len(self.dataset.keys()) == 1 + and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT + ): + return len(self.dataset[AtomGraphDataset.KEY_DEFAULT]) + else: + return {k: len(v) for k, v in self.dataset.items()} + + def get(self, idx: int, key: Optional[str] = None): + if key is None: + key = self.KEY_DEFAULT + return self.dataset[key][idx] + + def items(self): + return self.dataset.items() + + def to_dict(self): + dct_dataset = {} + for label, data_list in 
self.dataset.items(): + dct_dataset[label] = [datum.to_dict() for datum in data_list] + self.dataset = dct_dataset + return self + + def x_to_one_hot_idx(self, type_map: Dict[int, int]): + """ + type_map is dict of {atomic_number: one_hot_idx} + after this process, the dataset has dependency on type_map + or chemical species user want to consider + """ + assert self.x_is_one_hot_idx is False + for data_list in self.dataset.values(): + for datum in data_list: + datum[self.DATA_KEY_X] = torch.LongTensor( + [type_map[z.item()] for z in datum[self.DATA_KEY_X]] + ) + self.type_map = type_map + self.x_is_one_hot_idx = True + + def toggle_requires_grad_of_data( + self, key: str, requires_grad_value: bool + ): + """ + set requires_grad of specific key of data(pos, edge_vec, ...) + """ + for data_list in self.dataset.values(): + for datum in data_list: + datum[key].requires_grad_(requires_grad_value) + + def divide_dataset( + self, + ratio: float, + constant_ratio_btw_labels: bool = True, + ignore_test: bool = True + ): + """ + divide dataset into 1-2*ratio : ratio : ratio + return divided AtomGraphDataset + returned value lost its dict key and became {KEY_DEFAULT: datalist} + but KEY.USER_LABEL of each data is preserved + """ + + def divide(ratio: float, data_list: List, ignore_test=True): + if ratio > 0.5: + raise ValueError('Ratio must not exceed 0.5') + data_len = len(data_list) + random.shuffle(data_list) + n_validation = int(data_len * ratio) + if n_validation == 0: + raise ValueError( + '# of validation set is 0, increase your dataset' + ) + + if ignore_test: + test_list = [] + n_train = data_len - n_validation + train_list = data_list[0:n_train] + valid_list = data_list[n_train:] + else: + n_train = data_len - 2 * n_validation + train_list = data_list[0:n_train] + valid_list = data_list[n_train : n_train + n_validation] + test_list = data_list[n_train + n_validation : data_len] + return train_list, valid_list, test_list + + lists = ([], [], []) # train, valid, test + if constant_ratio_btw_labels: + for data_list in self.dataset.values(): + for store, divided in zip(lists, divide(ratio, data_list)): + store.extend(divided) + else: + lists = divide(ratio, self.to_list()) + + dbs = tuple( + AtomGraphDataset(data, self.cutoff, self.meta) for data in lists + ) + for db in dbs: + db.group_by_key() + return dbs + + def to_list(self): + return list(itertools.chain(*self.dataset.values())) + + def get_natoms(self, type_map: Optional[Dict[int, int]] = None): + """ + if x_is_one_hot_idx, type_map is required + type_map: Z->one_hot_index(node_feature) + return Dict{label: {symbol, natom}]} + """ + assert not (self.x_is_one_hot_idx is True and type_map is None) + natoms = {} + for label, data in self.dataset.items(): + natoms[label] = Counter() + for datum in data: + if self.x_is_one_hot_idx and type_map is not None: + Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map) + else: + Zs = [ + chemical_symbols[z] + for z in datum[self.DATA_KEY_X].tolist() + ] + cnt = Counter(Zs) + natoms[label] += cnt + natoms[label] = dict(natoms[label]) + return natoms + + def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS): + """ + return per_atom mean of given data key + """ + eng_list = torch.Tensor( + [x[key] / x[key_num_atoms] for x in self.to_list()] + ) + return float(torch.mean(eng_list)) + + def get_per_atom_energy_mean(self): + """ + alias for get_per_atom_mean(KEY.ENERGY) + """ + return self.get_per_atom_mean(self.DATA_KEY_ENERGY) + + def get_species_ref_energy_by_linear_comb(self, 
num_chem_species: int): + """ + Total energy as y, composition as c_i, + solve linear regression of y = c_i*X + sklearn LinearRegression as solver + + x should be one-hot-indexed + give num_chem_species if possible + """ + assert self.x_is_one_hot_idx is True + data_list = self.to_list() + + c = torch.zeros((len(data_list), num_chem_species)) + for idx, datum in enumerate(data_list): + c[idx] = torch.bincount( + datum[self.DATA_KEY_X], minlength=num_chem_species + ) + y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list]) + c = c.numpy() + y = y.numpy() + + # tweak to fine tune training from many-element to small element + zero_indices = np.all(c == 0, axis=0) + c_reduced = c[:, ~zero_indices] + full_coeff = np.zeros(num_chem_species) + coef_reduced = ( + Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + ) + full_coeff[~zero_indices] = coef_reduced + + return full_coeff + + def get_force_rms(self): + force_list = [] + for x in self.to_list(): + force_list.extend( + x[self.DATA_KEY_FORCE] + .reshape( + -1, + ) + .tolist() + ) + force_list = torch.Tensor(force_list) + return float(torch.sqrt(torch.mean(torch.pow(force_list, 2)))) + + def get_species_wise_force_rms(self, num_chem_species: int): + """ + Return force rms for each species + Averaged by each components (x, y, z) + """ + assert self.x_is_one_hot_idx is True + data_list = self.to_list() + + atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list]) + force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list]) + + index = atomx.repeat_interleave(3, 0).reshape(force.shape) + rms = torch.zeros( + (num_chem_species, 3), + dtype=force.dtype, + device=force.device + ) + rms.scatter_reduce_( + 0, index, force.square(), + reduce='mean', include_self=False + ) + return torch.sqrt(rms.mean(dim=1)) + + def get_avg_num_neigh(self): + n_neigh = [] + for _, data_list in self.dataset.items(): + for data in data_list: + n_neigh.extend( + np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1] + ) + + avg_num_neigh = np.average(n_neigh) + return avg_num_neigh + + def get_statistics(self, key: str): + """ + return dict of statistics of given key (energy, force, stress) + key of dict is its label and _total for total statistics + value of dict is dict of statistics (mean, std, median, max, min) + """ + + def _get_statistic_dict(tensor_list): + data_list = torch.cat( + [ + tensor.reshape( + -1, + ) + for tensor in tensor_list + ] + ) + data_list = data_list[~torch.isnan(data_list)] + return { + 'mean': float(torch.mean(data_list)), + 'std': float(torch.std(data_list)), + 'median': float(torch.median(data_list)), + 'max': ( + torch.nan + if data_list.numel() == 0 + else float(torch.max(data_list)) + ), + 'min': ( + torch.nan + if data_list.numel() == 0 + else float(torch.min(data_list)) + ), + } + + res = {} + for label, values in self.dataset.items(): + # flatten list of torch.Tensor (values) + tensor_list = [x[key] for x in values] + res[label] = _get_statistic_dict(tensor_list) + tensor_list = [x[key] for x in self.to_list()] + res['Total'] = _get_statistic_dict(tensor_list) + return res + + def augment(self, dataset, validator: Optional[Callable] = None): + """check meta compatibility here + dataset(AtomGraphDataset): data to augment + validator(Callable, Optional): function(self, dataset) -> bool + + if validator is None, by default it checks + whether cutoff & chemical_species are same before augment + + check consistent data type, float, double, long integer etc + """ + + def default_validator(db1, db2): + 
cut_consis = db1.cutoff == db2.cutoff + # compare unordered lists + x_is_not_onehot = (not db1.x_is_one_hot_idx) and ( + not db2.x_is_one_hot_idx + ) + return cut_consis and x_is_not_onehot + + if validator is None: + validator = default_validator + if not validator(self, dataset): + raise ValueError('given datasets are not compatible check cutoffs') + for key, val in dataset.items(): + if key in self.dataset: + self.dataset[key].extend(val) + else: + self.dataset.update({key: val}) + self.user_labels = list(self.dataset.keys()) + + def unify_dtypes( + self, + float_dtype: torch.dtype = torch.float32, + int_dtype: torch.dtype = torch.int64 + ): + data_list = self.to_list() + for datum in data_list: + for k, v in list(datum.items()): + datum[k] = util.dtype_correct(v, float_dtype, int_dtype) + + def delete_data_key(self, key: str): + for data in self.to_list(): + del data[key] + + # TODO: this by_label is not straightforward + def save(self, path: str, by_label: bool = False): + if by_label: + for label, data in self.dataset.items(): + torch.save( + AtomGraphDataset( + {label: data}, self.cutoff, metadata=self.meta + ), + f'{path}/{label}.sevenn_data', + ) + else: + if path.endswith('.sevenn_data') is False: + path += '.sevenn_data' + torch.save(self, path) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py index fd8d395787ee06efc4a58c315e3cf06b7a24997a..fc32d142de1d2a743a775b2679c0f6040052ffb8 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py @@ -1,707 +1,707 @@ -import os -import warnings -from collections import Counter -from copy import deepcopy -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.serialization -import torch.utils.data -import yaml -from ase.data import chemical_symbols -from torch_geometric.data import Data -from torch_geometric.data.in_memory_dataset import InMemoryDataset -from tqdm import tqdm - -import sevenn._keys as KEY -import sevenn.train.dataload as dataload -import sevenn.util as util -from sevenn import __version__ -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.atom_graph_data import AtomGraphData -from sevenn.logger import Logger - -if torch.__version__.split()[0] >= '2.4.0': - # load graph without error - torch.serialization.add_safe_globals([AtomGraphData]) - -# warning from PyG, for later torch versions -warnings.filterwarnings( - 'ignore', - message='You are using `torch.load` with `weights_only=False`', -) - - -def _tag_graphs(graph_list: List[AtomGraphData], tag: str): - """ - WIP: To be used - """ - for g in graph_list: - g[KEY.TAG] = tag - return graph_list - - -def pt_to_args(pt_filename: str): - """ - Return arg dict of root and processed_name from path to .pt - Usage: - dataset = SevenNetGraphDataset( - **pt_to_args({path}/sevenn_data/dataset.pt) - ) - """ - processed_dir, basename = os.path.split(pt_filename) - return { - 'root': os.path.dirname(processed_dir), - 'processed_name': os.path.basename(basename), - } - - -def _run_stat( - graph_list, - y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS], -) -> Dict[str, Any]: - """ - Loop over dataset and init any statistics might need - """ - n_neigh = [] - natoms_counter = Counter() - composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) - stats: Dict[str, Any] = {y: {'_array': []} for y in 
y_keys} - - for i, graph in tqdm( - enumerate(graph_list), desc='run_stat', total=len(graph_list) - ): - z_tensor = graph[KEY.ATOMIC_NUMBERS] - natoms_counter.update(z_tensor.tolist()) - composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) - n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) - for y, dct in stats.items(): - dct['_array'].append( - graph[y].reshape( - -1, - ) - ) - - stats.update({'num_neighbor': {'_array': n_neigh}}) - for y, dct in stats.items(): - array = torch.cat(dct['_array']) - if array.dtype == torch.int64: # because of n_neigh - array = array.to(torch.float) - try: - median = torch.quantile(array, q=0.5) - except RuntimeError: - warnings.warn(f'skip median due to too large tensor size: {y}') - median = torch.nan - dct.update( - { - 'mean': float(torch.mean(array)), - 'std': float(torch.std(array, correction=0)), - 'median': float(median), - 'max': float(torch.max(array)), - 'min': float(torch.min(array)), - 'count': array.numel(), - '_array': array, - } - ) - - natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} - natoms['total'] = sum(list(natoms.values())) - stats.update({'_composition': composition, 'natoms': natoms}) - return stats - - -def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray): - from sklearn.linear_model import Ridge - - c = composition - y = energies - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - # will not 100% reproduce, as it is sorted by Z - # train/dataset.py was sorted by alphabets of chemical species - coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - full_coeff = np.zeros(NUM_UNIV_ELEMENT) - full_coeff[~zero_indices] = coef_reduced - return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy - - -class SevenNetGraphDataset(InMemoryDataset): - """ - Replacement of AtomGraphDataset. (and .sevenn_data) - Extends InMemoryDataset of PyG. From given 'files', and 'cutoff', - build graphs for training SevenNet model. 
Preprocessed graphs are saved to - f'{root}/sevenn_data/{processed_name}.pt - - TODO: Save meta info (cutoff) by overriding .save and .load - TODO: 'tag' is not used yet, but initialized - 'tag' is replacement for 'label', and each datapoint has it as integer - 'tag' is usually parsed from if the structure_list of load_dataset - - Args: - root: path to save/load processed PyG dataset - cutoff: edge cutoff of given AtomGraphData - files: list of filenames or dict describing how to parse the file - ASE readable (with proper extension), structure_list, .sevenn_data, - dict containing file_list (see dict_reader of train/dataload.py) - process_num_cores: # of cpu cores to build graph - processed_name: save as {root}/sevenn_data/{processed_name}.pt - pre_transfrom: optional transform for each graph: def (graph) -> graph - pre_filter: optional filtering function for each graph: def (graph) -> graph - force_reload: if True, reload dataset from files even if there exist - {root}/sevenn_data/{processed_name} - **process_kwargs: keyword arguments that will be passed into ase.io.read - """ - - def __init__( - self, - cutoff: float, - root: Optional[str] = None, - files: Optional[Union[str, List[Any]]] = None, - process_num_cores: int = 1, - processed_name: str = 'graph.pt', - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - use_data_weight: bool = False, - log: bool = True, - force_reload: bool = False, - drop_info: bool = True, - **process_kwargs, - ): - self.cutoff = cutoff - if files is None: - files = [] - elif isinstance(files, str): - files = [files] # user convenience - - _files = [] - for f in files: - if isinstance(f, str): - f = os.path.abspath(f) - _files.append(f) - self._files = _files - - self._full_file_list = [] - if not processed_name.endswith('.pt'): - processed_name += '.pt' - self._processed_names = [ - processed_name, # {root}/sevenn_data/{name}.pt - processed_name.replace('.pt', '.yaml'), - ] - - root = root or './' - _pdir = os.path.join(root, 'sevenn_data') - _pt = os.path.join(_pdir, self._processed_names[0]) - if not os.path.exists(_pt) and len(self._files) == 0: - raise ValueError( - ( - f'{_pt} not found and no files to process. ' - + 'If you copied only .pt file, please copy ' - + 'whole sevenn_data dir without changing its name.' - + ' They all work together.' - ) - ) - - _yam = os.path.join(_pdir, self._processed_names[1]) - if not os.path.exists(_yam) and len(self._files) == 0: - raise ValueError(f'{_yam} not found and no files to process') - - self.process_num_cores = process_num_cores - self.process_kwargs = process_kwargs - self.use_data_weight = use_data_weight - self.drop_info = drop_info - - self.tag_map = {} - self.statistics = {} - self.finalized = False - - super().__init__( - root, - transform, - pre_transform, - pre_filter, - log=log, - force_reload=force_reload, - ) # Internally calls 'process' - self.load(self.processed_paths[0]) # load pt, saved after process - - def load(self, path: str, data_cls=Data) -> None: - super().load(path, data_cls) - - if len(self) == 0: - warnings.warn(f'No graphs found {self.processed_paths[0]}') - if len(self.statistics) == 0: - # dataset is loaded from existing pt file. 
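For reference, the `_elemwise_reference_energies` helper above amounts to a small ridge regression from per-structure element counts to total energies, with absent elements dropped before the fit. Below is a minimal standalone sketch with toy numbers; the `NUM_UNIV_ELEMENT = 119` constant is an assumption standing in for `sevenn._const.NUM_UNIV_ELEMENT`.

```python
import numpy as np
from sklearn.linear_model import Ridge

NUM_UNIV_ELEMENT = 119  # assumption: atomic numbers 0..118

# toy composition matrix: 3 structures, element counts per atomic number
c = np.zeros((3, NUM_UNIV_ELEMENT))
c[0, 1], c[0, 8] = 2, 1   # H2O
c[1, 1], c[1, 8] = 4, 2   # 2x H2O
c[2, 8] = 2               # O2
y = np.array([-14.2, -28.4, -9.8])  # toy total energies (eV)

zero_cols = np.all(c == 0, axis=0)  # elements absent from the data
coef = Ridge(alpha=0.1, fit_intercept=False).fit(c[:, ~zero_cols], y).coef_
full = np.zeros(NUM_UNIV_ELEMENT)
full[~zero_cols] = coef             # e.g. full[1] ~ H reference energy
print(full[1], full[8])
```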
- self._load_meta() - - def _load_meta(self) -> None: - with open(self.processed_paths[1], 'r') as f: - meta = yaml.safe_load(f) - - if meta['sevennet_version'] == '0.10.0': - self._save_meta(list(self)) - with open(self.processed_paths[1], 'r') as f: - meta = yaml.safe_load(f) - - cutoff = float(meta['cutoff']) - if float(meta['cutoff']) != self.cutoff: - warnings.warn( - ( - 'Loaded dataset is built with different cutoff length: ' - + f'{cutoff} != {self.cutoff}, dataset cutoff will be' - + f' overwritten to {cutoff}' - ) - ) - self.cutoff = cutoff - self._files = meta['files'] - self.statistics = meta['statistics'] - - def __getitem__(self, idx): - graph = super().__getitem__(idx) - if self.drop_info: - graph.pop(KEY.INFO, None) # type: ignore - return graph - - @property - def raw_file_names(self) -> List[Any]: - return self._files - - @property - def processed_file_names(self) -> List[str]: - return self._processed_names - - @property - def processed_dir(self) -> str: - return os.path.join(self.root, 'sevenn_data') - - @property - def full_file_list(self) -> Union[List[str], None]: - return self._full_file_list - - def process(self): - graph_list: List[AtomGraphData] = [] - for file in self.raw_file_names: - tmplist = SevenNetGraphDataset.file_to_graph_list( - file=file, - cutoff=self.cutoff, - num_cores=self.process_num_cores, - **self.process_kwargs, - ) - if isinstance(file, str) and self._full_file_list is not None: - self._full_file_list.extend([os.path.abspath(file)] * len(tmplist)) - else: - self._full_file_list = None - graph_list.extend(tmplist) - - processed_graph_list = [] - for data in graph_list: - if self.pre_filter is not None and not self.pre_filter(data): - continue - if self.pre_transform is not None: - data = self.pre_transform(data) - if self.use_data_weight: - # pop data weight from info, and assign to graph - weight = data[KEY.INFO].pop( - KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - ) - data[KEY.DATA_WEIGHT] = weight - processed_graph_list.append(data) - - if len(processed_graph_list) == 0: - # Can not save at all if there is no graph (error in PyG), raise an error - raise ValueError('Zero graph found after filtering') - - # save graphs, handled by torch_geometrics - self.save(processed_graph_list, self.processed_paths[0]) - self._save_meta(processed_graph_list) - if self.log: - Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}') - - def _save_meta(self, graph_list) -> None: - stats = _run_stat(graph_list) - stats['elemwise_reference_energies'] = _elemwise_reference_energies( - stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy() - ) - self.statistics = stats - - stats_save = {} - for label, dct in self.statistics.items(): - if label.startswith('_'): - continue - stats_save[label] = {} - if not isinstance(dct, dict): - stats_save[label] = dct - else: - for k, v in dct.items(): - if k.startswith('_'): - continue - stats_save[label][k] = v - - meta = { - 'sevennet_version': __version__, - 'cutoff': self.cutoff, - 'when': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'files': self._files, - 'statistics': stats_save, - 'species': self.species, - 'num_graphs': self.statistics[KEY.ENERGY]['count'], - 'per_atom_energy_mean': self.per_atom_energy_mean, - 'force_rms': self.force_rms, - 'per_atom_energy_std': self.per_atom_energy_std, - 'avg_num_neigh': self.avg_num_neigh, - 'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh, - } - - with open(self.processed_paths[1], 'w') as f: - yaml.dump(meta, f, default_flow_style=False) - - 
@property - def species(self): - return [z for z in self.statistics['natoms'].keys() if z != 'total'] - - @property - def natoms(self): - return self.statistics['natoms'] - - @property - def per_atom_energy_mean(self): - return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] - - @property - def elemwise_reference_energies(self): - return self.statistics['elemwise_reference_energies'] - - @property - def force_rms(self): - mean = self.statistics[KEY.FORCE]['mean'] - std = self.statistics[KEY.FORCE]['std'] - return float((mean**2 + std**2) ** (0.5)) - - @property - def per_atom_energy_std(self): - return self.statistics['per_atom_energy']['std'] - - @property - def avg_num_neigh(self): - return self.statistics['num_neighbor']['mean'] - - @property - def sqrt_avg_num_neigh(self): - return self.avg_num_neigh**0.5 - - @staticmethod - def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]: - # backward compatibility - from sevenn.train.dataset import AtomGraphDataset - - dataset = torch.load(filename, map_location='cpu', weights_only=False) - if isinstance(dataset, AtomGraphDataset): - graph_list = [] - for _, graphs in dataset.dataset.items(): # type: ignore - # TODO: transfer label to tag (who gonna need this?) - graph_list.extend(graphs) - return graph_list, dataset.cutoff - else: - raise ValueError(f'Not sevenn_data type: {type(dataset)}') - - @staticmethod - def _read_structure_list( - filename: str, cutoff: float, num_cores: int = 1 - ) -> List[AtomGraphData]: - datadct = dataload.structure_list_reader(filename) - graph_list = [] - for tag, atoms_list in datadct.items(): - tmp = dataload.graph_build(atoms_list, cutoff, num_cores) - graph_list.extend(_tag_graphs(tmp, tag)) - return graph_list - - @staticmethod - def _read_ase_readable( - filename: str, - cutoff: float, - num_cores: int = 1, - tag: str = '', - transfer_info: bool = True, - allow_unlabeled: bool = False, - **ase_kwargs, - ) -> List[AtomGraphData]: - pbc_override = ase_kwargs.pop('pbc', None) - atoms_list = dataload.ase_reader(filename, **ase_kwargs) - for atoms in atoms_list: - if pbc_override is not None: - atoms.pbc = pbc_override - graph_list = dataload.graph_build( - atoms_list, - cutoff, - num_cores, - transfer_info=transfer_info, - allow_unlabeled=allow_unlabeled, - ) - if tag != '': - graph_list = _tag_graphs(graph_list, tag) - return graph_list - - @staticmethod - def _read_graph_dataset( - filename: str, cutoff: float, **kwargs - ) -> List[AtomGraphData]: - meta_f = filename.replace('.pt', '.yaml') - orig_cutoff = cutoff - if not os.path.exists(filename): - raise FileNotFoundError(f'No such file: {filename}') - if not os.path.exists(meta_f): - warnings.warn('No meta info found, beware of cutoff...') - else: - with open(meta_f, 'r') as f: - meta = yaml.safe_load(f) - orig_cutoff = float(meta['cutoff']) - if orig_cutoff != cutoff: - warnings.warn( - f'{filename} has different cutoff length: ' - + f'{cutoff} != {orig_cutoff}' - ) - ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff}) - ds_args.update(pt_to_args(filename)) - ds_args.update(kwargs) - dataset = SevenNetGraphDataset(**ds_args) - # TODO: hard coded. 
consult with inference.py - glist = [g.fit_dimension() for g in dataset] # type: ignore - for g in glist: - if KEY.STRESS in g: - # (1, 6) is what we want - g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0) - return glist - - @staticmethod - def _read_dict( - data_dict: dict, - cutoff: float, - num_cores: int = 1, - ): - # logic same as the dataload dict_reader, but handles graphs - data_dict_cp = deepcopy(data_dict) - file_list = data_dict_cp.get('file_list', None) - if file_list is None: - raise KeyError('file_list is not found') - - data_weight_default = { - 'energy': 1.0, - 'force': 1.0, - 'stress': 1.0, - } - data_weight = data_weight_default.copy() - data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) - - graph_list = [] - for file_dct in file_list: - ftype = file_dct.pop('data_format', 'ase') - if ftype != 'graph': - continue - graph_list.extend( - SevenNetGraphDataset._read_graph_dataset( - file_dct.get('file'), cutoff=cutoff - ) - ) - for graph in graph_list: - if KEY.INFO not in graph: - graph[KEY.INFO] = {} - graph[KEY.INFO].update(data_dict_cp) - graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight}) - - atoms_list = dataload.dict_reader(data_dict) - graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores)) - return graph_list - - @staticmethod - def file_to_graph_list( - file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs - ) -> List[AtomGraphData]: - """ - kwargs: if file is ase readable, passed to ase.io.read - """ - if isinstance(file, str) and not os.path.isfile(file): - raise ValueError(f'No such file: {file}') - graph_list: List[AtomGraphData] - if isinstance(file, dict): - graph_list = SevenNetGraphDataset._read_dict( - file, cutoff, num_cores, **kwargs - ) - elif file.endswith('.pt'): - graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff) - elif file.endswith('.sevenn_data'): - graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file) - if cutoff_other != cutoff: - warnings.warn(f'Given {file} has different {cutoff_other}!') - cutoff = cutoff_other - elif 'structure_list' in file: - graph_list = SevenNetGraphDataset._read_structure_list( - file, cutoff, num_cores - ) - else: - graph_list = SevenNetGraphDataset._read_ase_readable( - file, cutoff, num_cores, **kwargs - ) - return graph_list - - -def from_single_path( - path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs -) -> Union[SevenNetGraphDataset, None]: - """ - Convenient routine for loading a single .pt dataset. 
- If given dict and it has data_weight, apply it using transform - """ - data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - spath = _extract_single_path(path) - if spath is None: - return None - - if isinstance(spath, str): - if not spath.endswith('.pt'): - return None - dataset_kwargs.update(pt_to_args(spath)) - elif isinstance(spath, dict): - file = _extract_file_from_dict(spath) - if file is None or not file.endswith('.pt'): - return None - dataset_kwargs.update(pt_to_args(file)) - data_weight_user = spath.get(KEY.DATA_WEIGHT, None) - if data_weight_user is not None: - data_weight.update(data_weight_user) - else: - return None - - if override_data_weight: - dataset_kwargs['transform'] = _chain_data_weight_override( - dataset_kwargs.get('transform'), data_weight - ) - - return SevenNetGraphDataset(**dataset_kwargs) - - -def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]: - """Extracts a single path from the input, - ensuring it's either a single string or list with one item.""" - if isinstance(path, list): - return path[0] if len(path) == 1 else None - return path if isinstance(path, (str, dict)) else None - - -def _extract_file_from_dict(path_dict: dict) -> Union[str, None]: - """Extracts a single file path from the dictionary, ensuring it's valid.""" - file_list = path_dict.get('file_list', None) - if file_list and len(file_list) == 1: - file = file_list[0].get('file', None) - return file if isinstance(file, str) else None - return None - - -def _chain_data_weight_override(transform_func, data_weight): - """Creates a transform function that overrides the data weight.""" - - def chained_transform(graph): - graph = transform_func(graph) if transform_func is not None else graph - graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None) - graph[KEY.DATA_WEIGHT] = data_weight - return graph - - return chained_transform - - -# script, return dict of SevenNetGraphDataset -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - log = Logger() - if dataset_keys is None: - dataset_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - dataset_keys.append(k) - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - # initialize arguments for loading dataset - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'root': working_dir, - 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config.get(KEY.DATA_FORMAT_ARGS, {}), - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - if (dataset := from_single_path(paths, **dataset_args)) is not None: - datasets[name] = dataset - else: - dataset_args.update({'files': paths, 'processed_name': name}) - dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt') - if os.path.exists(dataset_path) and 'force_reload' not in dataset_args: - log.writeline( - f'Dataset will be loaded from {dataset_path}, without update. 
' - + 'If you have changed your files to read, put force_reload=True' - + ' under the data_format_args key' - ) - datasets[name] = SevenNetGraphDataset(**dataset_args) - - train_set = datasets['trainset'] - - chem_species = set(train_set.species) - # print statistics of each dataset - for name, dataset in datasets.items(): - log.bar() - log.writeline(f'{name} distribution:') - log.statistic_write(dataset.statistics) - log.format_k_v('# structures (graph)', len(dataset), write=True) - - chem_species.update(dataset.species) - log.bar() - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - for k in init_from_stats: - input = config[k] # statistic key or numbers - # If it is not 'str', 1: It is 'continue' training - # 2: User manually inserted numbers - if isinstance(input, str) and hasattr(train_set, input): - var = getattr(train_set, input) - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - elif isinstance(input, str) and not hasattr(train_set, input): - raise NotImplementedError(input) - - if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0: - log.writeline('Use validation set as random split from the training set') - log.writeline( - 'Note that statistics, shift, scale, and conv_denominator are ' - + 'computed before random split.\n If you want these after random ' - + 'split, please preprocess dataset and set it as load_trainset_path ' - + 'and load_validset_path explicitly.' 
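The validation-split branch above reduces to a single `random_split` call with fractional lengths. A minimal sketch on a toy dataset (the names here are illustrative, not SevenNet API):

```python
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(100).unsqueeze(1))
ratio = 0.1  # plays the role of config[KEY.RATIO]
train, valid = torch.utils.data.random_split(dataset, (1.0 - ratio, ratio))
print(len(train), len(valid))  # 90 10
```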
- ) - - ratio = float(config[KEY.RATIO]) - train, valid = torch.utils.data.random_split( - datasets['trainset'], (1.0 - ratio, ratio) - ) - datasets['trainset'] = train - datasets['validset'] = valid - - return datasets +import os +import warnings +from collections import Counter +from copy import deepcopy +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.serialization +import torch.utils.data +import yaml +from ase.data import chemical_symbols +from torch_geometric.data import Data +from torch_geometric.data.in_memory_dataset import InMemoryDataset +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.train.dataload as dataload +import sevenn.util as util +from sevenn import __version__ +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData +from sevenn.logger import Logger + +if torch.__version__.split()[0] >= '2.4.0': + # load graph without error + torch.serialization.add_safe_globals([AtomGraphData]) + +# warning from PyG, for later torch versions +warnings.filterwarnings( + 'ignore', + message='You are using `torch.load` with `weights_only=False`', +) + + +def _tag_graphs(graph_list: List[AtomGraphData], tag: str): + """ + WIP: To be used + """ + for g in graph_list: + g[KEY.TAG] = tag + return graph_list + + +def pt_to_args(pt_filename: str): + """ + Return arg dict of root and processed_name from path to .pt + Usage: + dataset = SevenNetGraphDataset( + **pt_to_args({path}/sevenn_data/dataset.pt) + ) + """ + processed_dir, basename = os.path.split(pt_filename) + return { + 'root': os.path.dirname(processed_dir), + 'processed_name': os.path.basename(basename), + } + + +def _run_stat( + graph_list, + y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS], +) -> Dict[str, Any]: + """ + Loop over dataset and init any statistics might need + """ + n_neigh = [] + natoms_counter = Counter() + composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) + stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys} + + for i, graph in tqdm( + enumerate(graph_list), desc='run_stat', total=len(graph_list) + ): + z_tensor = graph[KEY.ATOMIC_NUMBERS] + natoms_counter.update(z_tensor.tolist()) + composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) + n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) + for y, dct in stats.items(): + dct['_array'].append( + graph[y].reshape( + -1, + ) + ) + + stats.update({'num_neighbor': {'_array': n_neigh}}) + for y, dct in stats.items(): + array = torch.cat(dct['_array']) + if array.dtype == torch.int64: # because of n_neigh + array = array.to(torch.float) + try: + median = torch.quantile(array, q=0.5) + except RuntimeError: + warnings.warn(f'skip median due to too large tensor size: {y}') + median = torch.nan + dct.update( + { + 'mean': float(torch.mean(array)), + 'std': float(torch.std(array, correction=0)), + 'median': float(median), + 'max': float(torch.max(array)), + 'min': float(torch.min(array)), + 'count': array.numel(), + '_array': array, + } + ) + + natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} + natoms['total'] = sum(list(natoms.values())) + stats.update({'_composition': composition, 'natoms': natoms}) + return stats + + +def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray): + from sklearn.linear_model import Ridge + + c = composition + y = energies + zero_indices = np.all(c == 
0, axis=0) + c_reduced = c[:, ~zero_indices] + # will not 100% reproduce, as it is sorted by Z + # train/dataset.py was sorted by alphabets of chemical species + coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + full_coeff = np.zeros(NUM_UNIV_ELEMENT) + full_coeff[~zero_indices] = coef_reduced + return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy + + +class SevenNetGraphDataset(InMemoryDataset): + """ + Replacement of AtomGraphDataset. (and .sevenn_data) + Extends InMemoryDataset of PyG. From given 'files', and 'cutoff', + build graphs for training SevenNet model. Preprocessed graphs are saved to + f'{root}/sevenn_data/{processed_name}.pt + + TODO: Save meta info (cutoff) by overriding .save and .load + TODO: 'tag' is not used yet, but initialized + 'tag' is replacement for 'label', and each datapoint has it as integer + 'tag' is usually parsed from if the structure_list of load_dataset + + Args: + root: path to save/load processed PyG dataset + cutoff: edge cutoff of given AtomGraphData + files: list of filenames or dict describing how to parse the file + ASE readable (with proper extension), structure_list, .sevenn_data, + dict containing file_list (see dict_reader of train/dataload.py) + process_num_cores: # of cpu cores to build graph + processed_name: save as {root}/sevenn_data/{processed_name}.pt + pre_transfrom: optional transform for each graph: def (graph) -> graph + pre_filter: optional filtering function for each graph: def (graph) -> graph + force_reload: if True, reload dataset from files even if there exist + {root}/sevenn_data/{processed_name} + **process_kwargs: keyword arguments that will be passed into ase.io.read + """ + + def __init__( + self, + cutoff: float, + root: Optional[str] = None, + files: Optional[Union[str, List[Any]]] = None, + process_num_cores: int = 1, + processed_name: str = 'graph.pt', + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + use_data_weight: bool = False, + log: bool = True, + force_reload: bool = False, + drop_info: bool = True, + **process_kwargs, + ): + self.cutoff = cutoff + if files is None: + files = [] + elif isinstance(files, str): + files = [files] # user convenience + + _files = [] + for f in files: + if isinstance(f, str): + f = os.path.abspath(f) + _files.append(f) + self._files = _files + + self._full_file_list = [] + if not processed_name.endswith('.pt'): + processed_name += '.pt' + self._processed_names = [ + processed_name, # {root}/sevenn_data/{name}.pt + processed_name.replace('.pt', '.yaml'), + ] + + root = root or './' + _pdir = os.path.join(root, 'sevenn_data') + _pt = os.path.join(_pdir, self._processed_names[0]) + if not os.path.exists(_pt) and len(self._files) == 0: + raise ValueError( + ( + f'{_pt} not found and no files to process. ' + + 'If you copied only .pt file, please copy ' + + 'whole sevenn_data dir without changing its name.' + + ' They all work together.' 
+ ) + ) + + _yam = os.path.join(_pdir, self._processed_names[1]) + if not os.path.exists(_yam) and len(self._files) == 0: + raise ValueError(f'{_yam} not found and no files to process') + + self.process_num_cores = process_num_cores + self.process_kwargs = process_kwargs + self.use_data_weight = use_data_weight + self.drop_info = drop_info + + self.tag_map = {} + self.statistics = {} + self.finalized = False + + super().__init__( + root, + transform, + pre_transform, + pre_filter, + log=log, + force_reload=force_reload, + ) # Internally calls 'process' + self.load(self.processed_paths[0]) # load pt, saved after process + + def load(self, path: str, data_cls=Data) -> None: + super().load(path, data_cls) + + if len(self) == 0: + warnings.warn(f'No graphs found {self.processed_paths[0]}') + if len(self.statistics) == 0: + # dataset is loaded from existing pt file. + self._load_meta() + + def _load_meta(self) -> None: + with open(self.processed_paths[1], 'r') as f: + meta = yaml.safe_load(f) + + if meta['sevennet_version'] == '0.10.0': + self._save_meta(list(self)) + with open(self.processed_paths[1], 'r') as f: + meta = yaml.safe_load(f) + + cutoff = float(meta['cutoff']) + if float(meta['cutoff']) != self.cutoff: + warnings.warn( + ( + 'Loaded dataset is built with different cutoff length: ' + + f'{cutoff} != {self.cutoff}, dataset cutoff will be' + + f' overwritten to {cutoff}' + ) + ) + self.cutoff = cutoff + self._files = meta['files'] + self.statistics = meta['statistics'] + + def __getitem__(self, idx): + graph = super().__getitem__(idx) + if self.drop_info: + graph.pop(KEY.INFO, None) # type: ignore + return graph + + @property + def raw_file_names(self) -> List[Any]: + return self._files + + @property + def processed_file_names(self) -> List[str]: + return self._processed_names + + @property + def processed_dir(self) -> str: + return os.path.join(self.root, 'sevenn_data') + + @property + def full_file_list(self) -> Union[List[str], None]: + return self._full_file_list + + def process(self): + graph_list: List[AtomGraphData] = [] + for file in self.raw_file_names: + tmplist = SevenNetGraphDataset.file_to_graph_list( + file=file, + cutoff=self.cutoff, + num_cores=self.process_num_cores, + **self.process_kwargs, + ) + if isinstance(file, str) and self._full_file_list is not None: + self._full_file_list.extend([os.path.abspath(file)] * len(tmplist)) + else: + self._full_file_list = None + graph_list.extend(tmplist) + + processed_graph_list = [] + for data in graph_list: + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + if self.use_data_weight: + # pop data weight from info, and assign to graph + weight = data[KEY.INFO].pop( + KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + ) + data[KEY.DATA_WEIGHT] = weight + processed_graph_list.append(data) + + if len(processed_graph_list) == 0: + # Can not save at all if there is no graph (error in PyG), raise an error + raise ValueError('Zero graph found after filtering') + + # save graphs, handled by torch_geometrics + self.save(processed_graph_list, self.processed_paths[0]) + self._save_meta(processed_graph_list) + if self.log: + Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}') + + def _save_meta(self, graph_list) -> None: + stats = _run_stat(graph_list) + stats['elemwise_reference_energies'] = _elemwise_reference_energies( + stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy() + ) + self.statistics = 
stats + + stats_save = {} + for label, dct in self.statistics.items(): + if label.startswith('_'): + continue + stats_save[label] = {} + if not isinstance(dct, dict): + stats_save[label] = dct + else: + for k, v in dct.items(): + if k.startswith('_'): + continue + stats_save[label][k] = v + + meta = { + 'sevennet_version': __version__, + 'cutoff': self.cutoff, + 'when': datetime.now().strftime('%Y-%m-%d %H:%M'), + 'files': self._files, + 'statistics': stats_save, + 'species': self.species, + 'num_graphs': self.statistics[KEY.ENERGY]['count'], + 'per_atom_energy_mean': self.per_atom_energy_mean, + 'force_rms': self.force_rms, + 'per_atom_energy_std': self.per_atom_energy_std, + 'avg_num_neigh': self.avg_num_neigh, + 'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh, + } + + with open(self.processed_paths[1], 'w') as f: + yaml.dump(meta, f, default_flow_style=False) + + @property + def species(self): + return [z for z in self.statistics['natoms'].keys() if z != 'total'] + + @property + def natoms(self): + return self.statistics['natoms'] + + @property + def per_atom_energy_mean(self): + return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] + + @property + def elemwise_reference_energies(self): + return self.statistics['elemwise_reference_energies'] + + @property + def force_rms(self): + mean = self.statistics[KEY.FORCE]['mean'] + std = self.statistics[KEY.FORCE]['std'] + return float((mean**2 + std**2) ** (0.5)) + + @property + def per_atom_energy_std(self): + return self.statistics['per_atom_energy']['std'] + + @property + def avg_num_neigh(self): + return self.statistics['num_neighbor']['mean'] + + @property + def sqrt_avg_num_neigh(self): + return self.avg_num_neigh**0.5 + + @staticmethod + def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]: + # backward compatibility + from sevenn.train.dataset import AtomGraphDataset + + dataset = torch.load(filename, map_location='cpu', weights_only=False) + if isinstance(dataset, AtomGraphDataset): + graph_list = [] + for _, graphs in dataset.dataset.items(): # type: ignore + # TODO: transfer label to tag (who gonna need this?) 
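The `force_rms` property above recombines the stored mean and std instead of re-reading forces; this is valid because `_run_stat` stores the uncorrected (population) std, so E[x²] = mean(x)² + std(x)². A quick standalone check of that identity:

```python
import torch

x = torch.randn(1000) * 2.0 + 0.5
rms_direct = torch.sqrt(torch.mean(x**2))
mean, std = torch.mean(x), torch.std(x, correction=0)  # population std
rms_from_stats = (mean**2 + std**2) ** 0.5
print(float(rms_direct), float(rms_from_stats))  # agree to float precision
```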
+ graph_list.extend(graphs) + return graph_list, dataset.cutoff + else: + raise ValueError(f'Not sevenn_data type: {type(dataset)}') + + @staticmethod + def _read_structure_list( + filename: str, cutoff: float, num_cores: int = 1 + ) -> List[AtomGraphData]: + datadct = dataload.structure_list_reader(filename) + graph_list = [] + for tag, atoms_list in datadct.items(): + tmp = dataload.graph_build(atoms_list, cutoff, num_cores) + graph_list.extend(_tag_graphs(tmp, tag)) + return graph_list + + @staticmethod + def _read_ase_readable( + filename: str, + cutoff: float, + num_cores: int = 1, + tag: str = '', + transfer_info: bool = True, + allow_unlabeled: bool = False, + **ase_kwargs, + ) -> List[AtomGraphData]: + pbc_override = ase_kwargs.pop('pbc', None) + atoms_list = dataload.ase_reader(filename, **ase_kwargs) + for atoms in atoms_list: + if pbc_override is not None: + atoms.pbc = pbc_override + graph_list = dataload.graph_build( + atoms_list, + cutoff, + num_cores, + transfer_info=transfer_info, + allow_unlabeled=allow_unlabeled, + ) + if tag != '': + graph_list = _tag_graphs(graph_list, tag) + return graph_list + + @staticmethod + def _read_graph_dataset( + filename: str, cutoff: float, **kwargs + ) -> List[AtomGraphData]: + meta_f = filename.replace('.pt', '.yaml') + orig_cutoff = cutoff + if not os.path.exists(filename): + raise FileNotFoundError(f'No such file: {filename}') + if not os.path.exists(meta_f): + warnings.warn('No meta info found, beware of cutoff...') + else: + with open(meta_f, 'r') as f: + meta = yaml.safe_load(f) + orig_cutoff = float(meta['cutoff']) + if orig_cutoff != cutoff: + warnings.warn( + f'{filename} has different cutoff length: ' + + f'{cutoff} != {orig_cutoff}' + ) + ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff}) + ds_args.update(pt_to_args(filename)) + ds_args.update(kwargs) + dataset = SevenNetGraphDataset(**ds_args) + # TODO: hard coded. 
consult with inference.py + glist = [g.fit_dimension() for g in dataset] # type: ignore + for g in glist: + if KEY.STRESS in g: + # (1, 6) is what we want + g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0) + return glist + + @staticmethod + def _read_dict( + data_dict: dict, + cutoff: float, + num_cores: int = 1, + ): + # logic same as the dataload dict_reader, but handles graphs + data_dict_cp = deepcopy(data_dict) + file_list = data_dict_cp.get('file_list', None) + if file_list is None: + raise KeyError('file_list is not found') + + data_weight_default = { + 'energy': 1.0, + 'force': 1.0, + 'stress': 1.0, + } + data_weight = data_weight_default.copy() + data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) + + graph_list = [] + for file_dct in file_list: + ftype = file_dct.pop('data_format', 'ase') + if ftype != 'graph': + continue + graph_list.extend( + SevenNetGraphDataset._read_graph_dataset( + file_dct.get('file'), cutoff=cutoff + ) + ) + for graph in graph_list: + if KEY.INFO not in graph: + graph[KEY.INFO] = {} + graph[KEY.INFO].update(data_dict_cp) + graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight}) + + atoms_list = dataload.dict_reader(data_dict) + graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores)) + return graph_list + + @staticmethod + def file_to_graph_list( + file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs + ) -> List[AtomGraphData]: + """ + kwargs: if file is ase readable, passed to ase.io.read + """ + if isinstance(file, str) and not os.path.isfile(file): + raise ValueError(f'No such file: {file}') + graph_list: List[AtomGraphData] + if isinstance(file, dict): + graph_list = SevenNetGraphDataset._read_dict( + file, cutoff, num_cores, **kwargs + ) + elif file.endswith('.pt'): + graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff) + elif file.endswith('.sevenn_data'): + graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file) + if cutoff_other != cutoff: + warnings.warn(f'Given {file} has different {cutoff_other}!') + cutoff = cutoff_other + elif 'structure_list' in file: + graph_list = SevenNetGraphDataset._read_structure_list( + file, cutoff, num_cores + ) + else: + graph_list = SevenNetGraphDataset._read_ase_readable( + file, cutoff, num_cores, **kwargs + ) + return graph_list + + +def from_single_path( + path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs +) -> Union[SevenNetGraphDataset, None]: + """ + Convenient routine for loading a single .pt dataset. 
+ If given dict and it has data_weight, apply it using transform + """ + data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + spath = _extract_single_path(path) + if spath is None: + return None + + if isinstance(spath, str): + if not spath.endswith('.pt'): + return None + dataset_kwargs.update(pt_to_args(spath)) + elif isinstance(spath, dict): + file = _extract_file_from_dict(spath) + if file is None or not file.endswith('.pt'): + return None + dataset_kwargs.update(pt_to_args(file)) + data_weight_user = spath.get(KEY.DATA_WEIGHT, None) + if data_weight_user is not None: + data_weight.update(data_weight_user) + else: + return None + + if override_data_weight: + dataset_kwargs['transform'] = _chain_data_weight_override( + dataset_kwargs.get('transform'), data_weight + ) + + return SevenNetGraphDataset(**dataset_kwargs) + + +def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]: + """Extracts a single path from the input, + ensuring it's either a single string or list with one item.""" + if isinstance(path, list): + return path[0] if len(path) == 1 else None + return path if isinstance(path, (str, dict)) else None + + +def _extract_file_from_dict(path_dict: dict) -> Union[str, None]: + """Extracts a single file path from the dictionary, ensuring it's valid.""" + file_list = path_dict.get('file_list', None) + if file_list and len(file_list) == 1: + file = file_list[0].get('file', None) + return file if isinstance(file, str) else None + return None + + +def _chain_data_weight_override(transform_func, data_weight): + """Creates a transform function that overrides the data weight.""" + + def chained_transform(graph): + graph = transform_func(graph) if transform_func is not None else graph + graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None) + graph[KEY.DATA_WEIGHT] = data_weight + return graph + + return chained_transform + + +# script, return dict of SevenNetGraphDataset +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + log = Logger() + if dataset_keys is None: + dataset_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + dataset_keys.append(k) + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + # initialize arguments for loading dataset + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'root': working_dir, + 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config.get(KEY.DATA_FORMAT_ARGS, {}), + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + if (dataset := from_single_path(paths, **dataset_args)) is not None: + datasets[name] = dataset + else: + dataset_args.update({'files': paths, 'processed_name': name}) + dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt') + if os.path.exists(dataset_path) and 'force_reload' not in dataset_args: + log.writeline( + f'Dataset will be loaded from {dataset_path}, without update. 
'
+                    + 'If you have changed your files to read, put force_reload=True'
+                    + ' under the data_format_args key'
+                )
+            datasets[name] = SevenNetGraphDataset(**dataset_args)
+
+    train_set = datasets['trainset']
+
+    chem_species = set(train_set.species)
+    # print statistics of each dataset
+    for name, dataset in datasets.items():
+        log.bar()
+        log.writeline(f'{name} distribution:')
+        log.statistic_write(dataset.statistics)
+        log.format_k_v('# structures (graph)', len(dataset), write=True)
+
+        chem_species.update(dataset.species)
+    log.bar()
+
+    # initialize known species from dataset if 'auto'
+    # sorted to alphabetical order (which is same as before)
+    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
+    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
+        log.writeline('Known species are obtained from the dataset')
+        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))
+
+    # retrieve shift, scale, conv_denominators from user input (keyword)
+    init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
+    for k in init_from_stats:
+        input = config[k]  # statistic key or numbers
+        # If it is not a str: 1. it is 'continue' training, or
+        # 2. the user manually inserted numbers
+        if isinstance(input, str) and hasattr(train_set, input):
+            var = getattr(train_set, input)
+            config.update({k: var})
+            log.writeline(f'{k} is obtained from statistics')
+        elif isinstance(input, str) and not hasattr(train_set, input):
+            raise NotImplementedError(input)
+
+    if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0:
+        log.writeline('Using a validation set randomly split from the training set')
+        log.writeline(
+            'Note that statistics, shift, scale, and conv_denominator are '
+            + 'computed before the random split.\n If you want these after the '
+            + 'random split, please preprocess the dataset and set it as '
+            + 'load_trainset_path and load_validset_path explicitly.'
+ ) + + ratio = float(config[KEY.RATIO]) + train, valid = torch.utils.data.random_split( + datasets['trainset'], (1.0 - ratio, ratio) + ) + datasets['trainset'] = train + datasets['validset'] = valid + + return datasets diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py b/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py index a223c26894d5e587aaa085ee153244b90f0c9a5e..7aae162c916480b31d7347e8e141fcc67e3c21db 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py @@ -1,223 +1,223 @@ -from typing import Any, Callable, Dict, Optional, Tuple - -import torch - -import sevenn._keys as KEY - - -class LossDefinition: - """ - Base class for loss definition - weights are defined in outside of the class - """ - - def __init__( - self, - name: str, - unit: Optional[str] = None, - criterion: Optional[Callable] = None, - ref_key: Optional[str] = None, - pred_key: Optional[str] = None, - use_weight: bool = False, - ignore_unlabeled: bool = True, - ): - self.name = name - self.unit = unit - self.criterion = criterion - self.ref_key = ref_key - self.pred_key = pred_key - self.use_weight = use_weight - self.ignore_unlabeled = ignore_unlabeled - - def __repr__(self): - return self.name - - def assign_criteria(self, criterion: Callable): - if self.criterion is not None: - raise ValueError('Loss uses its own criterion.') - self.criterion = criterion - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - if self.pred_key is None or self.ref_key is None: - raise NotImplementedError('LossDefinition is not implemented.') - pred = torch.reshape(batch_data[self.pred_key], (-1,)) - ref = torch.reshape(batch_data[self.ref_key], (-1,)) - return pred, ref, None - - def _ignore_unlabeled(self, pred, ref, data_weights=None): - unlabeled = torch.isnan(ref) - pred = pred[~unlabeled] - ref = ref[~unlabeled] - if data_weights is not None: - data_weights = data_weights[~unlabeled] - return pred, ref, data_weights - - def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None): - """ - Function that return scalar - """ - if self.criterion is None: - raise NotImplementedError('LossDefinition has no criterion.') - pred, ref, w_tensor = self._preprocess(batch_data, model) - - if self.ignore_unlabeled: - pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor) - - if len(pred) == 0: - assert self.ref_key is not None - return torch.zeros(1, device=batch_data[self.ref_key].device) - - loss = self.criterion(pred, ref) - if self.use_weight: - loss = torch.mean(loss * w_tensor) - return loss - - -class PerAtomEnergyLoss(LossDefinition): - """ - Loss for per atom energy - """ - - def __init__( - self, - name: str = 'Energy', - unit: str = 'eV/atom', - criterion: Optional[Callable] = None, - ref_key: str = KEY.ENERGY, - pred_key: str = KEY.PRED_TOTAL_ENERGY, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - num_atoms = batch_data[KEY.NUM_ATOMS] - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - pred = batch_data[self.pred_key] / num_atoms - ref = batch_data[self.ref_key] / num_atoms - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = 
torch.repeat_interleave(weight, 1) - - return pred, ref, w_tensor - - -class ForceLoss(LossDefinition): - """ - Loss for force - """ - - def __init__( - self, - name: str = 'Force', - unit: str = 'eV/A', - criterion: Optional[Callable] = None, - ref_key: str = KEY.FORCE, - pred_key: str = KEY.PRED_FORCE, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - pred = torch.reshape(batch_data[self.pred_key], (-1,)) - ref = torch.reshape(batch_data[self.ref_key], (-1,)) - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = weight[batch_data[KEY.BATCH]] - w_tensor = torch.repeat_interleave(w_tensor, 3) - - return pred, ref, w_tensor - - -class StressLoss(LossDefinition): - """ - Loss for stress this is kbar - """ - - def __init__( - self, - name: str = 'Stress', - unit: str = 'kbar', - criterion: Optional[Callable] = None, - ref_key: str = KEY.STRESS, - pred_key: str = KEY.PRED_STRESS, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - self.TO_KB = 1602.1766208 # eV/A^3 to kbar - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - - pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,)) - ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,)) - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = torch.repeat_interleave(weight, 6) - - return pred, ref, w_tensor - - -def get_loss_functions_from_config(config: Dict[str, Any]): - from sevenn.train.optim import loss_dict - - loss_functions = [] # list of tuples (loss_definition, weight) - - loss = loss_dict[config[KEY.LOSS].lower()] - loss_param = config.get(KEY.LOSS_PARAM, {}) - - use_weight = config.get(KEY.USE_WEIGHT, False) - if use_weight: - loss_param['reduction'] = 'none' - criterion = loss(**loss_param) - - commons = {'use_weight': use_weight} - - loss_functions.append((PerAtomEnergyLoss(**commons), 1.0)) - loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT])) - if config[KEY.IS_TRAIN_STRESS]: - loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT])) - - for loss_function, _ in loss_functions: # why do these? 
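For intuition, the three `_preprocess` variants above broadcast one per-structure weight to per-component targets: once per structure for energy, three per atom for forces (via the batch index), and six per structure for stress. A toy sketch of that broadcasting (tensor contents and shapes are illustrative assumptions):

```python
import torch

weight = torch.tensor([1.0, 0.5])           # one weight per structure in batch
batch = torch.tensor([0, 0, 0, 1, 1])       # atom -> structure index
force_w = weight[batch].repeat_interleave(3)   # 3 components per atom -> 15
stress_w = weight.repeat_interleave(6)         # 6 Voigt components -> 12
energy_w = weight                              # 1 value per structure
print(force_w.shape, stress_w.shape, energy_w.shape)
```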
- if loss_function.criterion is None: - loss_function.assign_criteria(criterion) - - return loss_functions +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +import sevenn._keys as KEY + + +class LossDefinition: + """ + Base class for loss definition + weights are defined in outside of the class + """ + + def __init__( + self, + name: str, + unit: Optional[str] = None, + criterion: Optional[Callable] = None, + ref_key: Optional[str] = None, + pred_key: Optional[str] = None, + use_weight: bool = False, + ignore_unlabeled: bool = True, + ): + self.name = name + self.unit = unit + self.criterion = criterion + self.ref_key = ref_key + self.pred_key = pred_key + self.use_weight = use_weight + self.ignore_unlabeled = ignore_unlabeled + + def __repr__(self): + return self.name + + def assign_criteria(self, criterion: Callable): + if self.criterion is not None: + raise ValueError('Loss uses its own criterion.') + self.criterion = criterion + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if self.pred_key is None or self.ref_key is None: + raise NotImplementedError('LossDefinition is not implemented.') + pred = torch.reshape(batch_data[self.pred_key], (-1,)) + ref = torch.reshape(batch_data[self.ref_key], (-1,)) + return pred, ref, None + + def _ignore_unlabeled(self, pred, ref, data_weights=None): + unlabeled = torch.isnan(ref) + pred = pred[~unlabeled] + ref = ref[~unlabeled] + if data_weights is not None: + data_weights = data_weights[~unlabeled] + return pred, ref, data_weights + + def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None): + """ + Function that return scalar + """ + if self.criterion is None: + raise NotImplementedError('LossDefinition has no criterion.') + pred, ref, w_tensor = self._preprocess(batch_data, model) + + if self.ignore_unlabeled: + pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor) + + if len(pred) == 0: + assert self.ref_key is not None + return torch.zeros(1, device=batch_data[self.ref_key].device) + + loss = self.criterion(pred, ref) + if self.use_weight: + loss = torch.mean(loss * w_tensor) + return loss + + +class PerAtomEnergyLoss(LossDefinition): + """ + Loss for per atom energy + """ + + def __init__( + self, + name: str = 'Energy', + unit: str = 'eV/atom', + criterion: Optional[Callable] = None, + ref_key: str = KEY.ENERGY, + pred_key: str = KEY.PRED_TOTAL_ENERGY, + **kwargs, + ): + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ): + num_atoms = batch_data[KEY.NUM_ATOMS] + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + pred = batch_data[self.pred_key] / num_atoms + ref = batch_data[self.ref_key] / num_atoms + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = torch.repeat_interleave(weight, 1) + + return pred, ref, w_tensor + + +class ForceLoss(LossDefinition): + """ + Loss for force + """ + + def __init__( + self, + name: str = 'Force', + unit: str = 'eV/A', + criterion: Optional[Callable] = None, + ref_key: str = KEY.FORCE, + pred_key: str = KEY.PRED_FORCE, + **kwargs, + ): + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def 
_preprocess(
+        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
+    ):
+        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
+        pred = torch.reshape(batch_data[self.pred_key], (-1,))
+        ref = torch.reshape(batch_data[self.ref_key], (-1,))
+        w_tensor = None
+
+        if self.use_weight:
+            loss_type = self.name.lower()
+            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
+            w_tensor = weight[batch_data[KEY.BATCH]]
+            w_tensor = torch.repeat_interleave(w_tensor, 3)
+
+        return pred, ref, w_tensor
+
+
+class StressLoss(LossDefinition):
+    """
+    Loss for stress; units are kbar
+    """
+
+    def __init__(
+        self,
+        name: str = 'Stress',
+        unit: str = 'kbar',
+        criterion: Optional[Callable] = None,
+        ref_key: str = KEY.STRESS,
+        pred_key: str = KEY.PRED_STRESS,
+        **kwargs,
+    ):
+        super().__init__(
+            name=name,
+            unit=unit,
+            criterion=criterion,
+            ref_key=ref_key,
+            pred_key=pred_key,
+            **kwargs,
+        )
+        self.TO_KB = 1602.1766208  # eV/A^3 to kbar
+
+    def _preprocess(
+        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
+    ):
+        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
+
+        pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,))
+        ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,))
+        w_tensor = None
+
+        if self.use_weight:
+            loss_type = self.name.lower()
+            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
+            w_tensor = torch.repeat_interleave(weight, 6)
+
+        return pred, ref, w_tensor
+
+
+def get_loss_functions_from_config(config: Dict[str, Any]):
+    from sevenn.train.optim import loss_dict
+
+    loss_functions = []  # list of tuples (loss_definition, weight)
+
+    loss = loss_dict[config[KEY.LOSS].lower()]
+    loss_param = config.get(KEY.LOSS_PARAM, {})
+
+    use_weight = config.get(KEY.USE_WEIGHT, False)
+    if use_weight:
+        loss_param['reduction'] = 'none'
+    criterion = loss(**loss_param)
+
+    commons = {'use_weight': use_weight}
+
+    loss_functions.append((PerAtomEnergyLoss(**commons), 1.0))
+    loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT]))
+    if config[KEY.IS_TRAIN_STRESS]:
+        loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT]))
+
+    for loss_function, _ in loss_functions:  # assign shared criterion if unset
+ if loss_function.criterion is None: + loss_function.assign_criteria(criterion) + + return loss_functions diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py index 8b749d2b2725c43c6d5be5b9a97349c5be1779a5..c5bcd91747a62f7d8e369b92c48d7363cef9fa25 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py @@ -1,365 +1,365 @@ -import bisect -import os -from copy import deepcopy -from typing import Any, Dict, List, Optional - -import numpy as np -from torch.utils.data import ConcatDataset, Dataset - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.logger import Logger - - -def _arrange_paths_by_modality(paths: List[dict]): - modal_dct = {} - for path in paths: - if isinstance(path, dict): - if KEY.DATA_MODALITY not in path: - raise ValueError(f'{KEY.DATA_MODALITY} is missing') - modal = path.pop(KEY.DATA_MODALITY) - else: - raise TypeError(f'{path} is not dict or str') - if modal not in modal_dct: - modal_dct[modal] = [] - modal_dct[modal].append(path) - return modal_dct - - -def combined_variance( - means: np.ndarray, stds: np.ndarray, sample_sizes: np.ndarray, ddof: int = 0 -) -> float: - """ - Calculate the combined variance for multiple datasets. - """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - # Total number of samples - total_samples = np.sum(sample_sizes) - - # Combined mean - combined_mean = np.sum(sample_sizes * means) / total_samples - - # Combined variance calculation - variance_terms = (sample_sizes - ddof) * (stds**2) - mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2) - combined_variance = (np.sum(variance_terms) + np.sum(mean_diff_terms)) / ( - total_samples - ddof - ) - - return combined_variance - - -def combined_std( - means: List[float], stds: List[float], sample_sizes: List[int] -) -> float: - """ - Calculate the combined std for multiple datasets. - """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - means_arr = np.array(means) - stds_arr = np.array(stds) - sample_sizes_arr = np.array(sample_sizes) - - cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) - return np.sqrt(cv) - - -def combined_mean(means: List[float], sample_sizes: List[int]) -> float: - """ - Calculate the combined mean for multiple datasets. - """ - assert len(means) == len(sample_sizes) - means_arr = np.array(means) - sample_sizes_arr = np.array(sample_sizes) - - return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr) - - -def combined_rms( - means: List[float], stds: List[float], sample_sizes: List[int] -) -> float: - """ - Calculate the combined RMS for multiple datasets. 
- """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - means_arr = np.array(means) - stds_arr = np.array(stds) - sample_sizes_arr = np.array(sample_sizes) - - cm = combined_mean(means, sample_sizes) - cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) - - # Combined RMS calculation - return np.sqrt(cm**2 + cv) - - -class SevenNetMultiModalDataset(ConcatDataset): - def __init__( - self, - modal_dataset_dict: Dict[str, Dataset], - ): - datasets = [] - modals = [] - for modal, dataset in modal_dataset_dict.items(): - modals.append(modal) - datasets.append(dataset) - self.modals = modals - super().__init__(datasets) - - def __getitem__(self, idx): - graph = super().__getitem__(idx) - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - modality = self.modals[dataset_idx] - graph[KEY.DATA_MODALITY] = modality - return graph - - def _modal_wise_property(self, attribute_name: str): - dct = {} - for modal, dataset in zip(self.modals, self.datasets): - try: - if hasattr(dataset, attribute_name): - dct[modal] = getattr(dataset, attribute_name) - except AttributeError: - dct[modal] = None - return dct - - @property - def dataset_dict(self): - arr = {} - for idx, modality in enumerate(self.modals): - arr[modality] = self.datasets[idx] - return arr - - @property - def species(self): - dct = self._modal_wise_property('species') - tot = set() - for sp in dct.values(): - tot.update(sp) - dct['total'] = list(tot) - return dct - - @property - def natoms(self): - return self._modal_wise_property('natoms') - - @property - def per_atom_energy_mean(self): - dct = self._modal_wise_property('per_atom_energy_mean') - try: - means = [] - sample_sizes = [] - for modality, mean in dct.items(): - means.append(mean) - sample_sizes.append( - self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] - ) - cm = combined_mean(means, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def elemwise_reference_energies(self): - # total is not supported (it is expensive and complex, but useless) - return self._modal_wise_property('elemwise_reference_energies') - - @property - def force_rms(self): - dct = self._modal_wise_property('force_rms') - try: - means = [] - sample_sizes = [] - stds = [] - for modality in dct: - means.append(self.statistics[modality][KEY.FORCE]['mean']) - sample_sizes.append(self.statistics[modality][KEY.FORCE]['count']) - stds.append(self.statistics[modality][KEY.FORCE]['std']) - cm = combined_rms(means, stds, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def per_atom_energy_std(self): - dct = self._modal_wise_property('per_atom_energy_std') - try: - means = [] - sample_sizes = [] - stds = [] - for modality in dct: - means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean']) - sample_sizes.append( - self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] - ) - stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std']) - cm = combined_std(means, stds, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def avg_num_neigh(self): - dct = self._modal_wise_property('avg_num_neigh') - try: - means = [] - sample_sizes = [] - for modality, mean in dct.items(): - means.append(mean) - sample_sizes.append( - self.statistics[modality]['num_neighbor']['count'] - ) - cm = combined_mean(means, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def sqrt_avg_num_neigh(self): - avg_nn = self.avg_num_neigh - 
return {k: v**0.5 for k, v in avg_nn.items()} - - @property - def statistics(self): - return self._modal_wise_property('statistics') - - @staticmethod - def as_graph_dataset( - paths: List[dict], - **graph_dataset_kwargs, - ): - import sevenn.train.graph_dataset as gd - - modal_paths = _arrange_paths_by_modality(paths) - dataset_dct = {} - for modality, paths in modal_paths.items(): - kwargs = deepcopy(graph_dataset_kwargs) - if (dataset := gd.from_single_path(paths, **kwargs)) is None: - pname = kwargs.pop('processed_name', 'graph').replace('.pt', '') - dataset = gd.SevenNetGraphDataset( - files=paths, - processed_name=f'{pname}_{modality}.pt', - **kwargs, - ) - dataset_dct[modality] = dataset - return SevenNetMultiModalDataset(dataset_dct) - - -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - log = Logger() - if dataset_keys is None: - dataset_keys = [ - k for k in config if (k.startswith('load_') and k.endswith('_path')) - ] - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'root': working_dir, - 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config[KEY.DATA_FORMAT_ARGS], - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - dataset_args.update({'processed_name': name}) - datasets[name] = SevenNetMultiModalDataset.as_graph_dataset( - paths, # type: ignore - **dataset_args, - ) - - train_set = datasets['trainset'] - - modals_dataset = set() - chem_species = set() - # print statistics of each dataset - for name, dataset in datasets.items(): - for idx, modality in enumerate(dataset.modals): - log.bar() - log.writeline(f'{name} - {modality} distribution:') - log.statistic_write(dataset.statistics[modality]) - log.format_k_v( - '# structures (graph)', len(dataset.datasets[idx]), write=True - ) - modals_dataset.update([modality]) - chem_species.update(dataset.species['total']) - log.bar() - - if (modal_map := config.get(KEY.MODAL_MAP, None)) is None: - modals = sorted(list(modals_dataset)) - modal_map = {modal: i for i, modal in enumerate(modals)} - config[KEY.MODAL_MAP] = modal_map - - modals = list(modal_map.keys()) - if not modals_dataset.issubset(modal_map): - raise ValueError( - f'Found modalities in datasets: {modals_dataset} are not subset of' - + f' {modals}. 
Use sevenn_cp tool to append/assign modality' - ) - - log.writeline(f'Modalities of this model: {modals}') - - config[KEY.NUM_MODALITIES] = len(modal_map) - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - init_from_stats = [ - k for k in init_from_stats_candid if isinstance(config[k], str) - ] - - for k in init_from_stats: - input = config[k] - if not hasattr(train_set, input): - raise NotImplementedError(input) - modal_stat = getattr(train_set, input) - try: - if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat: - # conv_denominator is not modal-wise - var = modal_stat['total'] - elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]: - modal_stat.pop('total', None) - var = modal_stat - elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]: - var = modal_stat['total'] - elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]: - modal_stat.pop('total', None) - var = modal_stat - elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]: - var = modal_stat['total'] - else: - raise NotImplementedError(f'Failed to init {k} from statistics') - except KeyError as e: - if e.args[0] == 'total': - raise NotImplementedError( - f'{k}: {input} does not support total statistics. ' - + f'Set use_modal_wise_{k} True or specify numbers manually' - ) - else: - raise e - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - - return datasets +import bisect +import os +from copy import deepcopy +from typing import Any, Dict, List, Optional + +import numpy as np +from torch.utils.data import ConcatDataset, Dataset + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.logger import Logger + + +def _arrange_paths_by_modality(paths: List[dict]): + modal_dct = {} + for path in paths: + if isinstance(path, dict): + if KEY.DATA_MODALITY not in path: + raise ValueError(f'{KEY.DATA_MODALITY} is missing') + modal = path.pop(KEY.DATA_MODALITY) + else: + raise TypeError(f'{path} is not dict or str') + if modal not in modal_dct: + modal_dct[modal] = [] + modal_dct[modal].append(path) + return modal_dct + + +def combined_variance( + means: np.ndarray, stds: np.ndarray, sample_sizes: np.ndarray, ddof: int = 0 +) -> float: + """ + Calculate the combined variance for multiple datasets. + """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + # Total number of samples + total_samples = np.sum(sample_sizes) + + # Combined mean + combined_mean = np.sum(sample_sizes * means) / total_samples + + # Combined variance calculation + variance_terms = (sample_sizes - ddof) * (stds**2) + mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2) + combined_variance = (np.sum(variance_terms) + np.sum(mean_diff_terms)) / ( + total_samples - ddof + ) + + return combined_variance + + +def combined_std( + means: List[float], stds: List[float], sample_sizes: List[int] +) -> float: + """ + Calculate the combined std for multiple datasets. 
+ """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + means_arr = np.array(means) + stds_arr = np.array(stds) + sample_sizes_arr = np.array(sample_sizes) + + cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) + return np.sqrt(cv) + + +def combined_mean(means: List[float], sample_sizes: List[int]) -> float: + """ + Calculate the combined mean for multiple datasets. + """ + assert len(means) == len(sample_sizes) + means_arr = np.array(means) + sample_sizes_arr = np.array(sample_sizes) + + return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr) + + +def combined_rms( + means: List[float], stds: List[float], sample_sizes: List[int] +) -> float: + """ + Calculate the combined RMS for multiple datasets. + """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + means_arr = np.array(means) + stds_arr = np.array(stds) + sample_sizes_arr = np.array(sample_sizes) + + cm = combined_mean(means, sample_sizes) + cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) + + # Combined RMS calculation + return np.sqrt(cm**2 + cv) + + +class SevenNetMultiModalDataset(ConcatDataset): + def __init__( + self, + modal_dataset_dict: Dict[str, Dataset], + ): + datasets = [] + modals = [] + for modal, dataset in modal_dataset_dict.items(): + modals.append(modal) + datasets.append(dataset) + self.modals = modals + super().__init__(datasets) + + def __getitem__(self, idx): + graph = super().__getitem__(idx) + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + modality = self.modals[dataset_idx] + graph[KEY.DATA_MODALITY] = modality + return graph + + def _modal_wise_property(self, attribute_name: str): + dct = {} + for modal, dataset in zip(self.modals, self.datasets): + try: + if hasattr(dataset, attribute_name): + dct[modal] = getattr(dataset, attribute_name) + except AttributeError: + dct[modal] = None + return dct + + @property + def dataset_dict(self): + arr = {} + for idx, modality in enumerate(self.modals): + arr[modality] = self.datasets[idx] + return arr + + @property + def species(self): + dct = self._modal_wise_property('species') + tot = set() + for sp in dct.values(): + tot.update(sp) + dct['total'] = list(tot) + return dct + + @property + def natoms(self): + return self._modal_wise_property('natoms') + + @property + def per_atom_energy_mean(self): + dct = self._modal_wise_property('per_atom_energy_mean') + try: + means = [] + sample_sizes = [] + for modality, mean in dct.items(): + means.append(mean) + sample_sizes.append( + self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] + ) + cm = combined_mean(means, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def elemwise_reference_energies(self): + # total is not supported (it is expensive and complex, but useless) + return self._modal_wise_property('elemwise_reference_energies') + + @property + def force_rms(self): + dct = self._modal_wise_property('force_rms') + try: + means = [] + sample_sizes = [] + stds = [] + for modality in dct: + means.append(self.statistics[modality][KEY.FORCE]['mean']) + sample_sizes.append(self.statistics[modality][KEY.FORCE]['count']) + stds.append(self.statistics[modality][KEY.FORCE]['std']) + cm = combined_rms(means, stds, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def per_atom_energy_std(self): + dct = self._modal_wise_property('per_atom_energy_std') + try: + means = [] + sample_sizes = [] + stds = [] + for modality in dct: + 
means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean']) + sample_sizes.append( + self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] + ) + stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std']) + cm = combined_std(means, stds, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def avg_num_neigh(self): + dct = self._modal_wise_property('avg_num_neigh') + try: + means = [] + sample_sizes = [] + for modality, mean in dct.items(): + means.append(mean) + sample_sizes.append( + self.statistics[modality]['num_neighbor']['count'] + ) + cm = combined_mean(means, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def sqrt_avg_num_neigh(self): + avg_nn = self.avg_num_neigh + return {k: v**0.5 for k, v in avg_nn.items()} + + @property + def statistics(self): + return self._modal_wise_property('statistics') + + @staticmethod + def as_graph_dataset( + paths: List[dict], + **graph_dataset_kwargs, + ): + import sevenn.train.graph_dataset as gd + + modal_paths = _arrange_paths_by_modality(paths) + dataset_dct = {} + for modality, paths in modal_paths.items(): + kwargs = deepcopy(graph_dataset_kwargs) + if (dataset := gd.from_single_path(paths, **kwargs)) is None: + pname = kwargs.pop('processed_name', 'graph').replace('.pt', '') + dataset = gd.SevenNetGraphDataset( + files=paths, + processed_name=f'{pname}_{modality}.pt', + **kwargs, + ) + dataset_dct[modality] = dataset + return SevenNetMultiModalDataset(dataset_dct) + + +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + log = Logger() + if dataset_keys is None: + dataset_keys = [ + k for k in config if (k.startswith('load_') and k.endswith('_path')) + ] + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'root': working_dir, + 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config[KEY.DATA_FORMAT_ARGS], + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + dataset_args.update({'processed_name': name}) + datasets[name] = SevenNetMultiModalDataset.as_graph_dataset( + paths, # type: ignore + **dataset_args, + ) + + train_set = datasets['trainset'] + + modals_dataset = set() + chem_species = set() + # print statistics of each dataset + for name, dataset in datasets.items(): + for idx, modality in enumerate(dataset.modals): + log.bar() + log.writeline(f'{name} - {modality} distribution:') + log.statistic_write(dataset.statistics[modality]) + log.format_k_v( + '# structures (graph)', len(dataset.datasets[idx]), write=True + ) + modals_dataset.update([modality]) + chem_species.update(dataset.species['total']) + log.bar() + + if (modal_map := config.get(KEY.MODAL_MAP, None)) is None: + modals = sorted(list(modals_dataset)) + modal_map = {modal: i for i, modal in enumerate(modals)} + config[KEY.MODAL_MAP] = modal_map + + modals = list(modal_map.keys()) + if not modals_dataset.issubset(modal_map): + raise ValueError( + f'Found modalities in datasets: {modals_dataset} are not subset of' + + f' {modals}. 
Use sevenn_cp tool to append/assign modality'
+        )
+
+    log.writeline(f'Modalities of this model: {modals}')
+
+    config[KEY.NUM_MODALITIES] = len(modal_map)
+
+    # initialize known species from dataset if 'auto'
+    # sorted into alphabetical order (same as before)
+    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
+    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
+        log.writeline('Known species are obtained from the dataset')
+        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))
+
+    # retrieve shift, scale, conv_denominators from user input (keyword)
+    init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
+    init_from_stats = [
+        k for k in init_from_stats_candid if isinstance(config[k], str)
+    ]
+
+    for k in init_from_stats:
+        input = config[k]
+        if not hasattr(train_set, input):
+            raise NotImplementedError(input)
+        modal_stat = getattr(train_set, input)
+        try:
+            if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat:
+                # conv_denominator is not modal-wise
+                var = modal_stat['total']
+            elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]:
+                modal_stat.pop('total', None)
+                var = modal_stat
+            elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]:
+                var = modal_stat['total']
+            elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]:
+                modal_stat.pop('total', None)
+                var = modal_stat
+            elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]:
+                var = modal_stat['total']
+            else:
+                raise NotImplementedError(f'Failed to init {k} from statistics')
+        except KeyError as e:
+            if e.args[0] == 'total':
+                raise NotImplementedError(
+                    f'{k}: {input} does not support total statistics. '
+                    + f'Set use_modal_wise_{k} True or specify numbers manually'
+                )
+            else:
+                raise e
+        config.update({k: var})
+        log.writeline(f'{k} is obtained from statistics')
+
+    return datasets
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py b/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py
index 7f9894387347710d3a5903ef810104342bef3da4..10e757906dbf2b6a1afa90ae7de4fbeb9854540f 100644
--- a/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py
+++ b/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py
@@ -1,23 +1,23 @@
-import torch.nn as nn
-import torch.optim.lr_scheduler as scheduler
-from torch.optim import adagrad, adam, adamw, radam, sgd
-
-optim_dict = {
-    'sgd': sgd.SGD,
-    'adagrad': adagrad.Adagrad,
-    'adam': adam.Adam,
-    'adamw': adamw.AdamW,
-    'radam': radam.RAdam,
-}
-
-
-scheduler_dict = {
-    'steplr': scheduler.StepLR,
-    'multisteplr': scheduler.MultiStepLR,
-    'exponentiallr': scheduler.ExponentialLR,
-    'cosineannealinglr': scheduler.CosineAnnealingLR,
-    'reducelronplateau': scheduler.ReduceLROnPlateau,
-    'linearlr': scheduler.LinearLR,
-}
-
-loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss}
+import torch.nn as nn
+import torch.optim.lr_scheduler as scheduler
+from torch.optim import adagrad, adam, adamw, radam, sgd
+
+optim_dict = {
+    'sgd': sgd.SGD,
+    'adagrad': adagrad.Adagrad,
+    'adam': adam.Adam,
+    'adamw': adamw.AdamW,
+    'radam': radam.RAdam,
+}
+
+
+scheduler_dict = {
+    'steplr': scheduler.StepLR,
+    'multisteplr': scheduler.MultiStepLR,
+    'exponentiallr': scheduler.ExponentialLR,
+    'cosineannealinglr': scheduler.CosineAnnealingLR,
+    'reducelronplateau': scheduler.ReduceLROnPlateau,
+    'linearlr': scheduler.LinearLR,
+}
+
+loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss}
diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py b/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py
index 962598e2cc7e89694f94a992a86f14655bd9bd4f..cb2eb91a744a12acbf337158ce3ed899a8285c78 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py @@ -1,230 +1,230 @@ -import os -import uuid -from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn -from torch.nn.parallel import DistributedDataParallel as DDP -from tqdm import tqdm - -import sevenn._keys as KEY -from sevenn.error_recorder import ErrorRecorder -from sevenn.train.loss import LossDefinition - -from .loss import get_loss_functions_from_config -from .optim import optim_dict, scheduler_dict - - -class Trainer: - """ - Training routine specialized for this package. Depends on 'sevenn.train.loss' - - Args: - model: model to train - loss_functions: List of tuples of [LossDefinition, float]. 'float' is for - loss weight for each Loss function - optimizer_cls: torch optimizer class to initialize - optimizer_args: optimizer keyword argument except 'param' - scheduler_cls: torch scheduler class to initialize, can be None - optimizer_args: optimizer keyword argument except 'optimizer' - device: device to train model, defaults to 'auto' - distributed: whether this is distributed training - distributed_backend: torch DDP backend. Should be one of 'nccl', 'mpi' - """ - - def __init__( - self, - model: torch.nn.Module, - loss_functions: List[Tuple[LossDefinition, float]], - optimizer_cls, - optimizer_args: Optional[dict] = None, - scheduler_cls=None, - scheduler_args: Optional[dict] = None, - device: Union[torch.device, str] = 'auto', - distributed: bool = False, - distributed_backend: str = 'nccl', - ): - if device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - if distributed_backend == 'mpi': - device = 'cpu' - - if distributed: - local_rank = int(os.environ['LOCAL_RANK']) - self.rank = local_rank - if distributed_backend == 'nccl': - device = torch.device('cuda', local_rank) - self.model = DDP(model.to(device), device_ids=[device]) - elif distributed_backend == 'mpi': - self.model = DDP(model.to(device)) - else: - raise ValueError(f'Unknown DDP backend: {distributed_backend}') - dist.barrier() - self.model.module.set_is_batch_data(True) - else: - self.model = model.to(device) - self.model.set_is_batch_data(True) - self.rank = 0 - - self.device = device - self.distributed = distributed - - param = [p for p in self.model.parameters() if p.requires_grad] - self.optimizer = optimizer_cls(param, **optimizer_args) - if scheduler_cls is not None: - self.scheduler = scheduler_cls(self.optimizer, **scheduler_args) - else: - self.scheduler = None - self.loss_functions = loss_functions - - @staticmethod - def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer': - trainer = Trainer( - model, - loss_functions=get_loss_functions_from_config(config), - optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()], - optimizer_args=config.get(KEY.OPTIM_PARAM, {}), - scheduler_cls=scheduler_dict[ - config.get(KEY.SCHEDULER, 'exponentiallr').lower() - ], - scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}), - device=config.get(KEY.DEVICE, 'auto'), - distributed=config.get(KEY.IS_DDP, False), - distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'), - ) - return trainer - - @staticmethod - def args_from_checkpoint(checkpoint: str) -> Tuple[Dict, Dict, Dict]: - """ - Usage: - trainer_args, optim_stct, scheduler_stct = 
args_from_checkpoint('7net-0') - # Do what you want to do here - trainer = Trainer(**trainer_args) - trainer.load_state_dict( - optimizer_state_dict=optim_stct, - scheduler_state_dict=scheduler_stct, - """ - from sevenn.util import load_checkpoint - - cp = load_checkpoint(checkpoint) - - model = cp.build_model() - config = cp.config - optimizer_cls = optim_dict[config[KEY.OPTIMIZER].lower()] - scheduler_cls = scheduler_dict[config[KEY.SCHEDULER].lower()] - loss_functions = get_loss_functions_from_config(config) - - return ( - { - 'model': model, - 'loss_functions': loss_functions, - 'optimizer_cls': optimizer_cls, - 'optimizer_args': config[KEY.OPTIM_PARAM], - 'scheduler_cls': scheduler_cls, - 'scheduler_args': config[KEY.SCHEDULER_PARAM], - }, - cp.optimizer_state_dict, - cp.scheduler_state_dict, - ) - - def run_one_epoch( - self, - loader: Iterable, - is_train: bool = False, - error_recorder: Optional[ErrorRecorder] = None, - wrap_tqdm: Union[bool, int] = False, - ) -> None: - """ - Run single epoch with given dataloader - Args: - loader: iterable yieds AtomGraphData - is_train: if true, do backward() and optimizer step - error_recorder: ErrorRecorder instance to compute errors (RMSEm MAE, ..) - wrap_tqdm: wrap given dataloader with tqdm for progress bar - """ - if is_train: - self.model.train() - else: - self.model.eval() - - if wrap_tqdm: - total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None - loader = tqdm(loader, total=total_len) - for _, batch in enumerate(loader): - if is_train: - self.optimizer.zero_grad() - batch = batch.to(self.device, non_blocking=True) - output = self.model(batch) - if error_recorder is not None: - error_recorder.update(output) - if is_train: - total_loss = torch.tensor([0.0], device=self.device) - for loss_def, w in self.loss_functions: - indv_loss = loss_def.get_loss(output, self.model) - if indv_loss is not None: - total_loss += (indv_loss * w) - total_loss.backward() - self.optimizer.step() - - if self.distributed and error_recorder is not None: - self.recorder_all_reduce(error_recorder) - - def scheduler_step(self, metric: Optional[float] = None) -> None: - if self.scheduler is None: - return - if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - assert isinstance(metric, float) - self.scheduler.step(metric) - else: - self.scheduler.step() - - def get_lr(self) -> float: - return float(self.optimizer.param_groups[0]['lr']) - - def recorder_all_reduce(self, recorder: ErrorRecorder) -> None: - for metric in recorder.metrics: - # metric.value._ddp_reduce(self.device) - metric.ddp_reduce(self.device) - - def get_checkpoint_dict(self) -> dict: - if self.distributed: - model_state_dct = self.model.module.state_dict() - else: - model_state_dct = self.model.state_dict() - return { - 'model_state_dict': model_state_dct, - 'optimizer_state_dict': self.optimizer.state_dict(), - 'scheduler_state_dict': self.scheduler.state_dict() - if self.scheduler is not None - else None, - 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'hash': uuid.uuid4().hex, - } - - def write_checkpoint(self, path: str, **extra) -> None: - if self.distributed and self.rank != 0: - return - cp = self.get_checkpoint_dict() - cp.update(**extra) - torch.save(cp, path) - - def load_state_dicts( - self, - model_state_dict: Optional[Dict] = None, - optimizer_state_dict: Optional[Dict] = None, - scheduler_state_dict: Optional[Dict] = None, - strict: bool = True, - ) -> None: - if model_state_dict is not None: - if self.distributed: - 
self.model.module.load_state_dict(model_state_dict, strict=strict) - else: - self.model.load_state_dict(model_state_dict, strict=strict) - - if optimizer_state_dict is not None: - self.optimizer.load_state_dict(optimizer_state_dict) - if scheduler_state_dict is not None and self.scheduler is not None: - self.scheduler.load_state_dict(scheduler_state_dict) +import os +import uuid +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm + +import sevenn._keys as KEY +from sevenn.error_recorder import ErrorRecorder +from sevenn.train.loss import LossDefinition + +from .loss import get_loss_functions_from_config +from .optim import optim_dict, scheduler_dict + + +class Trainer: + """ + Training routine specialized for this package. Depends on 'sevenn.train.loss' + + Args: + model: model to train + loss_functions: List of tuples of [LossDefinition, float]. 'float' is for + loss weight for each Loss function + optimizer_cls: torch optimizer class to initialize + optimizer_args: optimizer keyword argument except 'param' + scheduler_cls: torch scheduler class to initialize, can be None + optimizer_args: optimizer keyword argument except 'optimizer' + device: device to train model, defaults to 'auto' + distributed: whether this is distributed training + distributed_backend: torch DDP backend. Should be one of 'nccl', 'mpi' + """ + + def __init__( + self, + model: torch.nn.Module, + loss_functions: List[Tuple[LossDefinition, float]], + optimizer_cls, + optimizer_args: Optional[dict] = None, + scheduler_cls=None, + scheduler_args: Optional[dict] = None, + device: Union[torch.device, str] = 'auto', + distributed: bool = False, + distributed_backend: str = 'nccl', + ): + if device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if distributed_backend == 'mpi': + device = 'cpu' + + if distributed: + local_rank = int(os.environ['LOCAL_RANK']) + self.rank = local_rank + if distributed_backend == 'nccl': + device = torch.device('cuda', local_rank) + self.model = DDP(model.to(device), device_ids=[device]) + elif distributed_backend == 'mpi': + self.model = DDP(model.to(device)) + else: + raise ValueError(f'Unknown DDP backend: {distributed_backend}') + dist.barrier() + self.model.module.set_is_batch_data(True) + else: + self.model = model.to(device) + self.model.set_is_batch_data(True) + self.rank = 0 + + self.device = device + self.distributed = distributed + + param = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = optimizer_cls(param, **optimizer_args) + if scheduler_cls is not None: + self.scheduler = scheduler_cls(self.optimizer, **scheduler_args) + else: + self.scheduler = None + self.loss_functions = loss_functions + + @staticmethod + def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer': + trainer = Trainer( + model, + loss_functions=get_loss_functions_from_config(config), + optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()], + optimizer_args=config.get(KEY.OPTIM_PARAM, {}), + scheduler_cls=scheduler_dict[ + config.get(KEY.SCHEDULER, 'exponentiallr').lower() + ], + scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}), + device=config.get(KEY.DEVICE, 'auto'), + distributed=config.get(KEY.IS_DDP, False), + distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'), + ) + return trainer + + @staticmethod + def 
args_from_checkpoint(checkpoint: str) -> Tuple[Dict, Dict, Dict]:
+        """
+        Usage:
+            trainer_args, optim_stct, scheduler_stct = args_from_checkpoint('7net-0')
+            # Do what you want to do here
+            trainer = Trainer(**trainer_args)
+            trainer.load_state_dicts(
+                optimizer_state_dict=optim_stct,
+                scheduler_state_dict=scheduler_stct,
+            )
+        """
+        from sevenn.util import load_checkpoint
+
+        cp = load_checkpoint(checkpoint)
+
+        model = cp.build_model()
+        config = cp.config
+        optimizer_cls = optim_dict[config[KEY.OPTIMIZER].lower()]
+        scheduler_cls = scheduler_dict[config[KEY.SCHEDULER].lower()]
+        loss_functions = get_loss_functions_from_config(config)
+
+        return (
+            {
+                'model': model,
+                'loss_functions': loss_functions,
+                'optimizer_cls': optimizer_cls,
+                'optimizer_args': config[KEY.OPTIM_PARAM],
+                'scheduler_cls': scheduler_cls,
+                'scheduler_args': config[KEY.SCHEDULER_PARAM],
+            },
+            cp.optimizer_state_dict,
+            cp.scheduler_state_dict,
+        )
+
+    def run_one_epoch(
+        self,
+        loader: Iterable,
+        is_train: bool = False,
+        error_recorder: Optional[ErrorRecorder] = None,
+        wrap_tqdm: Union[bool, int] = False,
+    ) -> None:
+        """
+        Run a single epoch with the given dataloader
+        Args:
+            loader: iterable that yields AtomGraphData
+            is_train: if true, do backward() and optimizer step
+            error_recorder: ErrorRecorder instance to compute errors (RMSE, MAE, ...)
+            wrap_tqdm: wrap given dataloader with tqdm for progress bar
+        """
+        if is_train:
+            self.model.train()
+        else:
+            self.model.eval()
+
+        if wrap_tqdm:
+            total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None
+            loader = tqdm(loader, total=total_len)
+        for _, batch in enumerate(loader):
+            if is_train:
+                self.optimizer.zero_grad()
+            batch = batch.to(self.device, non_blocking=True)
+            output = self.model(batch)
+            if error_recorder is not None:
+                error_recorder.update(output)
+            if is_train:
+                total_loss = torch.tensor([0.0], device=self.device)
+                for loss_def, w in self.loss_functions:
+                    indv_loss = loss_def.get_loss(output, self.model)
+                    if indv_loss is not None:
+                        total_loss += (indv_loss * w)
+                total_loss.backward()
+                self.optimizer.step()
+
+        if self.distributed and error_recorder is not None:
+            self.recorder_all_reduce(error_recorder)
+
+    def scheduler_step(self, metric: Optional[float] = None) -> None:
+        if self.scheduler is None:
+            return
+        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            assert isinstance(metric, float)
+            self.scheduler.step(metric)
+        else:
+            self.scheduler.step()
+
+    def get_lr(self) -> float:
+        return float(self.optimizer.param_groups[0]['lr'])
+
+    def recorder_all_reduce(self, recorder: ErrorRecorder) -> None:
+        for metric in recorder.metrics:
+            # metric.value._ddp_reduce(self.device)
+            metric.ddp_reduce(self.device)
+
+    def get_checkpoint_dict(self) -> dict:
+        if self.distributed:
+            model_state_dct = self.model.module.state_dict()
+        else:
+            model_state_dct = self.model.state_dict()
+        return {
+            'model_state_dict': model_state_dct,
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scheduler_state_dict': self.scheduler.state_dict()
+            if self.scheduler is not None
+            else None,
+            'time': datetime.now().strftime('%Y-%m-%d %H:%M'),
+            'hash': uuid.uuid4().hex,
+        }
+
+    def write_checkpoint(self, path: str, **extra) -> None:
+        if self.distributed and self.rank != 0:
+            return
+        cp = self.get_checkpoint_dict()
+        cp.update(**extra)
+        torch.save(cp, path)
+
+    def load_state_dicts(
+        self,
+        model_state_dict: Optional[Dict] = None,
+        optimizer_state_dict: Optional[Dict] = None,
+        scheduler_state_dict: Optional[Dict] = None,
+ strict: bool = True, + ) -> None: + if model_state_dict is not None: + if self.distributed: + self.model.module.load_state_dict(model_state_dict, strict=strict) + else: + self.model.load_state_dict(model_state_dict, strict=strict) + + if optimizer_state_dict is not None: + self.optimizer.load_state_dict(optimizer_state_dict) + if scheduler_state_dict is not None and self.scheduler is not None: + self.scheduler.load_state_dict(scheduler_state_dict) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/util.py b/mace-bench/3rdparty/SevenNet/sevenn/util.py index 7afeb4e899bf4df130481331d17a584c1d84313c..632a1d253316f1ee28b1a7c1c2f699ffd773a961 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/util.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/util.py @@ -1,330 +1,330 @@ -import os -import os.path as osp -import pathlib -import shutil -from typing import Dict, List, Tuple, Union - -import numpy as np -import requests -import torch -import torch.nn -from e3nn.o3 import FullTensorProduct, Irreps -from tqdm import tqdm - -import sevenn._const as _const -import sevenn._keys as KEY - - -def to_atom_graph_list(atom_graph_batch): - """ - torch_geometric batched data to separate list - original to_data_list() by PyG is not enough since - it doesn't handle inferred tensors - """ - is_stress = KEY.PRED_STRESS in atom_graph_batch - - data_list = atom_graph_batch.to_data_list() - - indices = atom_graph_batch[KEY.NUM_ATOMS].tolist() - - atomic_energy_list = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], indices) - inferred_total_energy_list = torch.unbind( - atom_graph_batch[KEY.PRED_TOTAL_ENERGY] - ) - inferred_force_list = torch.split(atom_graph_batch[KEY.PRED_FORCE], indices) - - inferred_stress_list = None - if is_stress: - inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) - - for i, data in enumerate(data_list): - data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] - data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] - data[KEY.PRED_FORCE] = inferred_force_list[i] - # To fit with KEY.STRESS (ref) format - if is_stress and inferred_stress_list is not None: - data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) - return data_list - - -def error_recorder_from_loss_functions(loss_functions): - from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type - from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss - - metrics = [] - for loss_function, _ in loss_functions: - ref_key = loss_function.ref_key - pred_key = loss_function.pred_key - # unit = loss_function.unit - criterion = loss_function.criterion - name = loss_function.name - base = None - if type(loss_function) is PerAtomEnergyLoss: - base = get_err_type('Energy') - elif type(loss_function) is ForceLoss: - base = get_err_type('Force') - elif type(loss_function) is StressLoss: - base = get_err_type('Stress') - else: - base = {} - base['name'] = name - base['ref_key'] = ref_key - base['pred_key'] = pred_key - if type(criterion) is torch.nn.MSELoss: - base['name'] = base['name'] + '_RMSE' - metrics.append(RMSError(**base)) - elif type(criterion) is torch.nn.L1Loss: - metrics.append(MAError(**base)) - return ErrorRecorder(metrics) - - -def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]): - from ase.data import chemical_symbols - - type_map_rev = {v: k for k, v in type_map.items()} - return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices] - - -def model_from_checkpoint( - checkpoint: str, -) -> Tuple[torch.nn.Module, Dict]: - cp = load_checkpoint(checkpoint) - 
model = cp.build_model() - - return model, cp.config - - -def model_from_checkpoint_with_backend( - checkpoint: str, - backend: str = 'e3nn', -) -> Tuple[torch.nn.Module, Dict]: - cp = load_checkpoint(checkpoint) - model = cp.build_model(backend) - - return model, cp.config - - -def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC): - from .atom_graph_data import AtomGraphData - from .train.dataload import unlabeled_atoms_to_graph - - atom_graph = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, cutoff) - ) - atom_graph[grad_key].requires_grad_(True) - atom_graph[KEY.BATCH] = torch.zeros([0]) - return atom_graph - - -def chemical_species_preprocess(input_chem: List[str], universal: bool = False): - from ase.data import atomic_numbers, chemical_symbols - - from .nn.node_embedding import get_type_mapper_from_specie - - config = {} - if not universal: - input_chem = list(set(input_chem)) - chemical_specie = sorted([x.strip() for x in input_chem]) - config[KEY.CHEMICAL_SPECIES] = chemical_specie - config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [ - atomic_numbers[x] for x in chemical_specie - ] - config[KEY.NUM_SPECIES] = len(chemical_specie) - config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie) - else: - config[KEY.CHEMICAL_SPECIES] = chemical_symbols - len_univ = len(chemical_symbols) - config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ)) - config[KEY.NUM_SPECIES] = len_univ - config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)} - return config - - -def dtype_correct( - v: Union[np.ndarray, torch.Tensor, int, float], - float_dtype: torch.dtype = torch.float32, - int_dtype: torch.dtype = torch.int64, -): - if isinstance(v, np.ndarray): - if np.issubdtype(v.dtype, np.floating): - return torch.from_numpy(v).to(float_dtype) - elif np.issubdtype(v.dtype, np.integer): - return torch.from_numpy(v).to(int_dtype) - elif isinstance(v, torch.Tensor): - if v.dtype.is_floating_point: - return v.to(float_dtype) # convert to specified float dtype - else: # assuming non-floating point tensors are integers - return v.to(int_dtype) # convert to specified int dtype - else: # scalar values - if isinstance(v, int): - return torch.tensor(v, dtype=int_dtype) - elif isinstance(v, float): - return torch.tensor(v, dtype=float_dtype) - else: # Not numeric - return v - - -def infer_irreps_out( - irreps_x: Irreps, - irreps_operand: Irreps, - drop_l: Union[bool, int] = False, - parity_mode: str = 'full', - fix_multiplicity: Union[bool, int] = False, -): - assert parity_mode in ['full', 'even', 'sph'] - # (mul, (ir, p)) - irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify() - new_irreps_elem = [] - for mul, (l, p) in irreps_out: # noqa - elem = (mul, (l, p)) - if drop_l is not False and l > drop_l: - continue - if parity_mode == 'even' and p == -1: - continue - elif parity_mode == 'sph' and p != (-1) ** l: - continue - if fix_multiplicity: - elem = (fix_multiplicity, (l, p)) - new_irreps_elem.append(elem) - return Irreps(new_irreps_elem) - - -def download_checkpoint(path: str, url: str): - fname = osp.basename(path) - temp_path = path + '.partial' - try: - # raises permission error if fails - os.makedirs(osp.dirname(path), exist_ok=True) - response = requests.get(url, stream=True, timeout=30) - response.raise_for_status() # Raise exception for bad status codes - - total_size = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 KB chunks - - progress_bar = tqdm( - total=total_size, - unit='B', - 
unit_scale=True, - desc=f'Downloading {fname}', - ) - - with open(temp_path, 'wb') as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - - shutil.move(temp_path, path) - print(f'Checkpoint downloaded: {path}') - return path - except PermissionError: - raise - except Exception as e: - # Clean up partial downloads on failure - # May not work as errors handled internally by tqdm etc. - print(f'Download failed: {str(e)}') - if os.path.exists(temp_path): - print(f'Cleaning up partial download: {temp_path}') - os.remove(temp_path) - raise - - -def pretrained_name_to_path(name: str) -> str: - name = name.lower() - heads = ['sevennet', '7net'] - checkpoint_path = None - url = None - - if ( # TODO: regex - name in [f'{n}-0_11july2024' for n in heads] - or name in [f'{n}-0_11jul2024' for n in heads] - or name in ['sevennet-0', '7net-0'] - ): - checkpoint_path = _const.SEVENNET_0_11Jul2024 - elif name in [f'{n}-0_22may2024' for n in heads]: - checkpoint_path = _const.SEVENNET_0_22May2024 - elif name in [f'{n}-l3i5' for n in heads]: - checkpoint_path = _const.SEVENNET_l3i5 - elif name in [f'{n}-mf-0' for n in heads]: - checkpoint_path = _const.SEVENNET_MF_0 - elif name in [f'{n}-mf-ompa' for n in heads]: - checkpoint_path = _const.SEVENNET_MF_ompa - elif name in [f'{n}-omat' for n in heads]: - checkpoint_path = _const.SEVENNET_omat - else: - raise ValueError('Not a valid pretrained model name') - url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path) - - paths = [ - checkpoint_path, - checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')), - ] - - for path in paths: - if osp.exists(path): - return path - - # File not found check url and try download - if url is None: - raise FileNotFoundError(checkpoint_path) - - try: - return download_checkpoint(paths[0], url) # 7net package path - except PermissionError: - return download_checkpoint(paths[1], url) # ~/.cache - - -def load_checkpoint(checkpoint: Union[pathlib.Path, str]): - from sevenn.checkpoint import SevenNetCheckpoint - suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat'] - if osp.isfile(checkpoint): - checkpoint_path = checkpoint - else: - try: - checkpoint_path = pretrained_name_to_path(str(checkpoint)) - except ValueError: - raise ValueError( - f'Given {checkpoint} is not exists and not a pre-trained name.\n' - f'Valid pretrained model names: {suggests}' - ) - return SevenNetCheckpoint(checkpoint_path) - - -def unique_filepath(filepath: str) -> str: - if not os.path.isfile(filepath): - return filepath - else: - dirname = os.path.dirname(filepath) - fname = os.path.basename(filepath) - name, ext = os.path.splitext(fname) - cnt = 0 - new_name = f'{name}{cnt}{ext}' - new_path = os.path.join(dirname, new_name) - while os.path.exists(new_path): - cnt += 1 - new_name = f'{name}{cnt}{ext}' - new_path = os.path.join(dirname, new_name) - return new_path - - -def get_error_recorder( - recorder_tuples: List[Tuple[str, str]] = [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('Energy', 'MAE'), - ('Force', 'MAE'), - ('Stress', 'MAE'), - ], -): - # TODO add criterion argument and loss recorder selections - import sevenn.error_recorder as error_recorder - - config = recorder_tuples - err_metrics = [] - for err_type, metric_name in config: - metric_kwargs = error_recorder.get_err_type(err_type).copy() - metric_kwargs['name'] += f'_{metric_name}' - metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name] - 
err_metrics.append(metric_cls(**metric_kwargs)) - return error_recorder.ErrorRecorder(err_metrics) +import os +import os.path as osp +import pathlib +import shutil +from typing import Dict, List, Tuple, Union + +import numpy as np +import requests +import torch +import torch.nn +from e3nn.o3 import FullTensorProduct, Irreps +from tqdm import tqdm + +import sevenn._const as _const +import sevenn._keys as KEY + + +def to_atom_graph_list(atom_graph_batch): + """ + torch_geometric batched data to separate list + original to_data_list() by PyG is not enough since + it doesn't handle inferred tensors + """ + is_stress = KEY.PRED_STRESS in atom_graph_batch + + data_list = atom_graph_batch.to_data_list() + + indices = atom_graph_batch[KEY.NUM_ATOMS].tolist() + + atomic_energy_list = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], indices) + inferred_total_energy_list = torch.unbind( + atom_graph_batch[KEY.PRED_TOTAL_ENERGY] + ) + inferred_force_list = torch.split(atom_graph_batch[KEY.PRED_FORCE], indices) + + inferred_stress_list = None + if is_stress: + inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) + + for i, data in enumerate(data_list): + data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] + data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] + data[KEY.PRED_FORCE] = inferred_force_list[i] + # To fit with KEY.STRESS (ref) format + if is_stress and inferred_stress_list is not None: + data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) + return data_list + + +def error_recorder_from_loss_functions(loss_functions): + from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type + from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss + + metrics = [] + for loss_function, _ in loss_functions: + ref_key = loss_function.ref_key + pred_key = loss_function.pred_key + # unit = loss_function.unit + criterion = loss_function.criterion + name = loss_function.name + base = None + if type(loss_function) is PerAtomEnergyLoss: + base = get_err_type('Energy') + elif type(loss_function) is ForceLoss: + base = get_err_type('Force') + elif type(loss_function) is StressLoss: + base = get_err_type('Stress') + else: + base = {} + base['name'] = name + base['ref_key'] = ref_key + base['pred_key'] = pred_key + if type(criterion) is torch.nn.MSELoss: + base['name'] = base['name'] + '_RMSE' + metrics.append(RMSError(**base)) + elif type(criterion) is torch.nn.L1Loss: + metrics.append(MAError(**base)) + return ErrorRecorder(metrics) + + +def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]): + from ase.data import chemical_symbols + + type_map_rev = {v: k for k, v in type_map.items()} + return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices] + + +def model_from_checkpoint( + checkpoint: str, +) -> Tuple[torch.nn.Module, Dict]: + cp = load_checkpoint(checkpoint) + model = cp.build_model() + + return model, cp.config + + +def model_from_checkpoint_with_backend( + checkpoint: str, + backend: str = 'e3nn', +) -> Tuple[torch.nn.Module, Dict]: + cp = load_checkpoint(checkpoint) + model = cp.build_model(backend) + + return model, cp.config + + +def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC): + from .atom_graph_data import AtomGraphData + from .train.dataload import unlabeled_atoms_to_graph + + atom_graph = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, cutoff) + ) + atom_graph[grad_key].requires_grad_(True) + atom_graph[KEY.BATCH] = torch.zeros([0]) + return atom_graph + + +def 
chemical_species_preprocess(input_chem: List[str], universal: bool = False): + from ase.data import atomic_numbers, chemical_symbols + + from .nn.node_embedding import get_type_mapper_from_specie + + config = {} + if not universal: + input_chem = list(set(input_chem)) + chemical_specie = sorted([x.strip() for x in input_chem]) + config[KEY.CHEMICAL_SPECIES] = chemical_specie + config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [ + atomic_numbers[x] for x in chemical_specie + ] + config[KEY.NUM_SPECIES] = len(chemical_specie) + config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie) + else: + config[KEY.CHEMICAL_SPECIES] = chemical_symbols + len_univ = len(chemical_symbols) + config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ)) + config[KEY.NUM_SPECIES] = len_univ + config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)} + return config + + +def dtype_correct( + v: Union[np.ndarray, torch.Tensor, int, float], + float_dtype: torch.dtype = torch.float32, + int_dtype: torch.dtype = torch.int64, +): + if isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.floating): + return torch.from_numpy(v).to(float_dtype) + elif np.issubdtype(v.dtype, np.integer): + return torch.from_numpy(v).to(int_dtype) + elif isinstance(v, torch.Tensor): + if v.dtype.is_floating_point: + return v.to(float_dtype) # convert to specified float dtype + else: # assuming non-floating point tensors are integers + return v.to(int_dtype) # convert to specified int dtype + else: # scalar values + if isinstance(v, int): + return torch.tensor(v, dtype=int_dtype) + elif isinstance(v, float): + return torch.tensor(v, dtype=float_dtype) + else: # Not numeric + return v + + +def infer_irreps_out( + irreps_x: Irreps, + irreps_operand: Irreps, + drop_l: Union[bool, int] = False, + parity_mode: str = 'full', + fix_multiplicity: Union[bool, int] = False, +): + assert parity_mode in ['full', 'even', 'sph'] + # (mul, (ir, p)) + irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify() + new_irreps_elem = [] + for mul, (l, p) in irreps_out: # noqa + elem = (mul, (l, p)) + if drop_l is not False and l > drop_l: + continue + if parity_mode == 'even' and p == -1: + continue + elif parity_mode == 'sph' and p != (-1) ** l: + continue + if fix_multiplicity: + elem = (fix_multiplicity, (l, p)) + new_irreps_elem.append(elem) + return Irreps(new_irreps_elem) + + +def download_checkpoint(path: str, url: str): + fname = osp.basename(path) + temp_path = path + '.partial' + try: + # raises permission error if fails + os.makedirs(osp.dirname(path), exist_ok=True) + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() # Raise exception for bad status codes + + total_size = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 KB chunks + + progress_bar = tqdm( + total=total_size, + unit='B', + unit_scale=True, + desc=f'Downloading {fname}', + ) + + with open(temp_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + + shutil.move(temp_path, path) + print(f'Checkpoint downloaded: {path}') + return path + except PermissionError: + raise + except Exception as e: + # Clean up partial downloads on failure + # May not work as errors handled internally by tqdm etc. 
+        print(f'Download failed: {str(e)}')
+        if os.path.exists(temp_path):
+            print(f'Cleaning up partial download: {temp_path}')
+            os.remove(temp_path)
+        raise
+
+
+def pretrained_name_to_path(name: str) -> str:
+    name = name.lower()
+    heads = ['sevennet', '7net']
+    checkpoint_path = None
+    url = None
+
+    if (  # TODO: regex
+        name in [f'{n}-0_11july2024' for n in heads]
+        or name in [f'{n}-0_11jul2024' for n in heads]
+        or name in ['sevennet-0', '7net-0']
+    ):
+        checkpoint_path = _const.SEVENNET_0_11Jul2024
+    elif name in [f'{n}-0_22may2024' for n in heads]:
+        checkpoint_path = _const.SEVENNET_0_22May2024
+    elif name in [f'{n}-l3i5' for n in heads]:
+        checkpoint_path = _const.SEVENNET_l3i5
+    elif name in [f'{n}-mf-0' for n in heads]:
+        checkpoint_path = _const.SEVENNET_MF_0
+    elif name in [f'{n}-mf-ompa' for n in heads]:
+        checkpoint_path = _const.SEVENNET_MF_ompa
+    elif name in [f'{n}-omat' for n in heads]:
+        checkpoint_path = _const.SEVENNET_omat
+    else:
+        raise ValueError('Not a valid pretrained model name')
+    url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)
+
+    paths = [
+        checkpoint_path,
+        checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')),
+    ]
+
+    for path in paths:
+        if osp.exists(path):
+            return path
+
+    # File not found; check url and try download
+    if url is None:
+        raise FileNotFoundError(checkpoint_path)
+
+    try:
+        return download_checkpoint(paths[0], url)  # 7net package path
+    except PermissionError:
+        return download_checkpoint(paths[1], url)  # ~/.cache
+
+
+def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
+    from sevenn.checkpoint import SevenNetCheckpoint
+    suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat']
+    if osp.isfile(checkpoint):
+        checkpoint_path = checkpoint
+    else:
+        try:
+            checkpoint_path = pretrained_name_to_path(str(checkpoint))
+        except ValueError:
+            raise ValueError(
+                f'Given {checkpoint} does not exist and is not a pretrained model name.\n'
+                f'Valid pretrained model names: {suggests}'
+            )
+    return SevenNetCheckpoint(checkpoint_path)
+
+
+def unique_filepath(filepath: str) -> str:
+    if not os.path.isfile(filepath):
+        return filepath
+    else:
+        dirname = os.path.dirname(filepath)
+        fname = os.path.basename(filepath)
+        name, ext = os.path.splitext(fname)
+        cnt = 0
+        new_name = f'{name}{cnt}{ext}'
+        new_path = os.path.join(dirname, new_name)
+        while os.path.exists(new_path):
+            cnt += 1
+            new_name = f'{name}{cnt}{ext}'
+            new_path = os.path.join(dirname, new_name)
+        return new_path
+
+
+def get_error_recorder(
+    recorder_tuples: List[Tuple[str, str]] = [
+        ('Energy', 'RMSE'),
+        ('Force', 'RMSE'),
+        ('Stress', 'RMSE'),
+        ('Energy', 'MAE'),
+        ('Force', 'MAE'),
+        ('Stress', 'MAE'),
+    ],
+):
+    # TODO add criterion argument and loss recorder selections
+    import sevenn.error_recorder as error_recorder
+
+    config = recorder_tuples
+    err_metrics = []
+    for err_type, metric_name in config:
+        metric_kwargs = error_recorder.get_err_type(err_type).copy()
+        metric_kwargs['name'] += f'_{metric_name}'
+        metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name]
+        err_metrics.append(metric_cls(**metric_kwargs))
+    return error_recorder.ErrorRecorder(err_metrics)
diff --git a/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt b/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt
index 81bd34c0016341d0c0c4f0a97eccb5c28a07b679..a7123abfdfccab38e7d90a381d9b732e7ab7f22d 100644
--- a/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt
+++ 
b/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt @@ -1,6 +1,6 @@ -Energy_RMSE (eV/atom): 18.84848889028682 -Force_RMSE (eV/Å): 0.2622841142173583 -Stress_RMSE (kbar): 163.7362768581691 -Energy_MAE (eV/atom): 18.848487854003906 -Force_MAE (eV/Å): 0.116698424021403 -Stress_MAE (kbar): 47.33086649576823 +Energy_RMSE (eV/atom): 18.84848889028682 +Force_RMSE (eV/Å): 0.2622841142173583 +Stress_RMSE (kbar): 163.7362768581691 +Energy_MAE (eV/atom): 18.848487854003906 +Force_MAE (eV/Å): 0.116698424021403 +Stress_MAE (kbar): 47.33086649576823 diff --git a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py index d4efed3210e0663dbda6621f9930daafdbd61224..8b7257de33c5d8aa6baedf1691a1bea6acbb6424 100644 --- a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py +++ b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py @@ -1,24 +1,24 @@ -import pytest - - -def pytest_addoption(parser): - parser.addoption('--lammps_cmd', default=None, help='Lammps binary to test') - parser.addoption( - '--mpirun_cmd', default=None, help='mpirun binary to test parallel' - ) - - -@pytest.fixture -def lammps_cmd(request): - bin = request.config.getoption('lammps_cmd') - if bin is None: - pytest.skip('No LAMMPS binary given, skipping test') - return bin - - -@pytest.fixture -def mpirun_cmd(request): - bin = request.config.getoption('mpirun_cmd') - if bin is None: - pytest.skip('No mpirun cmd given, skipping test') - return bin +import pytest + + +def pytest_addoption(parser): + parser.addoption('--lammps_cmd', default=None, help='Lammps binary to test') + parser.addoption( + '--mpirun_cmd', default=None, help='mpirun binary to test parallel' + ) + + +@pytest.fixture +def lammps_cmd(request): + bin = request.config.getoption('lammps_cmd') + if bin is None: + pytest.skip('No LAMMPS binary given, skipping test') + return bin + + +@pytest.fixture +def mpirun_cmd(request): + bin = request.config.getoption('mpirun_cmd') + if bin is None: + pytest.skip('No mpirun cmd given, skipping test') + return bin diff --git a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py index 25b0614f994868f871c50e4a467e1e03ba2781fb..54f53a50e0fe6e7c4f53851fb9759020a6c365e3 100644 --- a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py +++ b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py @@ -1,467 +1,467 @@ -import copy -import logging -import pathlib -import subprocess - -import ase.calculators.lammps -import ase.io.lammpsdata -import numpy as np -import pytest -import torch -from ase.build import bulk, surface -from ase.calculators.singlepoint import SinglePointCalculator - -import sevenn -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.cue_helper import is_cue_available -from sevenn.scripts.deploy import deploy, deploy_parallel -from sevenn.util import chemical_species_preprocess, pretrained_name_to_path - -logger = logging.getLogger('test_lammps') - -cutoff = 4.0 - -lmp_script_path = str( - (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() -) - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O -cp_mf_path = pretrained_name_to_path('7net-mf-0') - - -@pytest.fixture(scope='module') -def serial_potential_path(tmp_path_factory): - tmp = 
tmp_path_factory.mktemp('serial_potential') - pot_path = str(tmp / 'deployed_serial.pt') - deploy(cp_0_path, pot_path) - return pot_path - - -@pytest.fixture(scope='module') -def parallel_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('paralllel_potential') - pot_path = str(tmp / 'deployed_parallel') - deploy_parallel(cp_0_path, pot_path) - return ' '.join(['3', pot_path]) - - -@pytest.fixture(scope='module') -def serial_modal_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('serial_modal_potential') - pot_path = str(tmp / 'deployed_serial.pt') - deploy(cp_mf_path, pot_path, 'PBE') - return pot_path - - -@pytest.fixture(scope='module') -def parallel_modal_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('paralllel_modal_potential') - pot_path = str(tmp / 'deployed_parallel') - deploy_parallel(cp_mf_path, pot_path, 'PBE') - return ' '.join(['5', pot_path]) - - -@pytest.fixture(scope='module') -def ref_calculator(): - return SevenNetCalculator(cp_0_path) - - -@pytest.fixture(scope='module') -def ref_modal_calculator(): - return SevenNetCalculator(cp_mf_path, modal='PBE') - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 8, - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'self_connection_type': 'linear', # not NequIp - 'interaction_type': 'nequip', - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 30.0, - 'train_denominator': False, - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - config.update(chemical_species_preprocess(['Hf', 'O'])) - return config - - -def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): - cf = get_model_config() - if config_overwrite is not None: - cf.update(config_overwrite) - - cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} - cf.update(cueq_config) - - model = build_E3_equivariant_model(cf, parallel=False) - assert not isinstance(model, list) - return model - - -def hfo2_bulk(replicate=(2, 2, 2), a=4.0): - atoms = bulk('HfO', 'rocksalt', a, orthorhombic=True) - atoms = atoms * replicate - atoms.rattle(stdev=0.10) - return atoms - - -def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5): - atoms = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum) - atoms.set_atomic_numbers([72] * len(atoms)) # Hf - atoms = atoms * replicate - atoms.rattle(stdev=0.10) - return atoms - - -def get_system(system_name, **kwargs): - if system_name == 'bulk': - return hfo2_bulk(**kwargs) - elif system_name == 'surface': - return hf_surface(**kwargs) - else: - raise ValueError() - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), 
atoms2.get_potential_energies()) - - -def _lammps_results_to_atoms(lammps_log, force_dump): - with open(lammps_log, 'r') as f: - lines = f.readlines() - lmp_log = None - for i, line in enumerate(lines): - if not line.startswith('Per MPI rank memory allocation'): - continue - lmp_log = { - k: eval(v) for k, v in zip(lines[i + 1].split(), lines[i + 2].split()) - } - break - - assert lmp_log is not None and 'PotEng' in lmp_log - - latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':') - assert isinstance(latoms_list, list) - latoms = latoms_list[0] - assert latoms.calc is not None - latoms.calc.results['energy'] = lmp_log['PotEng'] - latoms.calc.results['free_energy'] = lmp_log['PotEng'] - latoms.info = { - 'data_from': 'lammps', - 'lmp_log': lmp_log, - 'lmp_dump': force_dump, - } - # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] - stress = np.array( - [ - [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], - [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']], - [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']], - ] - ) - stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 - latoms.calc.results['stress'] = stress - - return latoms - - -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): - wd = wd.resolve() - pbc = atoms.get_pbc() - pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) - chem = list(set(atoms.get_chemical_symbols())) - # Way to ase handle lammps structure - - prism = ase.calculators.lammps.coordinatetransform.Prism( - atoms.get_cell(), pbc=pbc - ) - lmp_stct = wd / 'lammps_structure' - ase.io.lammpsdata.write_lammps_data( - lmp_stct, atoms, prismobj=prism, specorder=chem - ) - - with open(lmp_script_path, 'r') as f: - cont = f.read() - - lammps_log = str(wd / 'log.lammps') - force_dump = str(wd / 'force.dump') - - var_dct = {} - var_dct['__ELEMENT__'] = ' '.join(chem) - var_dct['__LMP_STCT__'] = str(lmp_stct.resolve()) - var_dct['__PAIR_STYLE__'] = pair_style - var_dct['__POTENTIALS__'] = potential - var_dct['__BOUNDARY__'] = pbc_str - var_dct['__FORCE_DUMP_PATH__'] = force_dump - for key, val in var_dct.items(): - cont = cont.replace(key, val) - - input_script_path = str(wd / 'in.lmp') - with open(input_script_path, 'w') as f: - f.write(cont) - - command = f'{command} -in {input_script_path} -log {lammps_log}' - subprocess_routine(command.split(), test_name) - - lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump) - assert lmp_atoms.calc is not None - - rot_mat = prism.rot_mat - results = copy.deepcopy(lmp_atoms.calc.results) - r_force = np.dot(results['forces'], rot_mat.T) - results['forces'] = r_force - if 'stress' in results: - # see ase.calculators.lammpsrun.py - stress_tensor = results['stress'] - stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T) - results['stress'] = stress_atoms - r_cell = lmp_atoms.get_cell() @ rot_mat.T - lmp_atoms.set_cell(r_cell, scale_atoms=True) - lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() - - return lmp_atoms - - -def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): - command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) - - -def parallel_lammps_run( - atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd -): - command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) - - -def subprocess_routine(cmd, name): - res = subprocess.run(cmd, capture_output=True, timeout=30) - if 
res.returncode != 0: - logger.error(f'Subprocess {name} failed return code: {res.returncode}') - logger.error(res.stderr.decode('utf-8')) - raise RuntimeError(f'{name} failed') - - logger.info(f'stdout of {name}:') - logger.info(res.stdout.decode('utf-8')) - - -@pytest.mark.parametrize( - 'system', - ['bulk', 'surface'], -) -def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path): - atoms = get_system(system) - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=serial_potential_path, - wd=tmp_path, - test_name='serial lmp test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system,ncores', - [ - ('bulk', 1), - ('bulk', 2), - ('bulk', 4), - ('surface', 1), - ('surface', 2), - ('surface', 3), - ('surface', 4), - ], -) -def test_parallel( - system, - ncores, - parallel_potential_path, - ref_calculator, - lammps_cmd, - mpirun_cmd, - tmp_path, -): - if system == 'bulk': - rep = (6, 6, 3) - elif system == 'surface': - rep = (4, 4, 1) - else: - assert False - atoms = get_system(system, replicate=rep) - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=parallel_potential_path, - wd=tmp_path, - test_name='parallel lmp test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=ncores, - ) - atoms.calc = ref_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system', - ['bulk', 'surface'], -) -def test_modal_serial( - system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path -): - atoms = get_system(system) - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=serial_modal_potential_path, - wd=tmp_path, - test_name='serial lmp test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_modal_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system,ncores', - [ - ('bulk', 2), - ('surface', 2), - ], -) -def test_modal_parallel( - system, - ncores, - parallel_modal_potential_path, - ref_modal_calculator, - lammps_cmd, - mpirun_cmd, - tmp_path, -): - if system == 'bulk': - rep = (6, 6, 3) - elif system == 'surface': - rep = (4, 4, 1) - else: - assert False - atoms = get_system(system, replicate=rep) - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=parallel_modal_potential_path, - wd=tmp_path, - test_name='parallel lmp test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=ncores, - ) - atoms.calc = ref_modal_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_cueq_serial(lammps_cmd, tmp_path): - """ - TODO: Use already saved cueq enabled checkpoint after cueq becomes stable - """ - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = get_system('bulk') - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt') - deploy(cp_path, pot_path) - - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=pot_path, - wd=tmp_path, - test_name='cueq checkpoint serial lmp run test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_calc - assert_atoms(atoms, atoms_lammps) - - 
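# The potential string handed to pair_style e3gnn/parallel takes the form
# "<n> <deployed_dir>". Judging from test_cueq_parallel below, which builds the
# prefix from cfg['num_convolution_layer'], the leading '3' and '5' in
# parallel_potential_path and parallel_modal_potential_path above appear to be
# the number of convolution (message-passing) layers of each checkpoint.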
-@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path): - """ - TODO: Use already saved cueq enabled checkpoint after cueq becomes stable - """ - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = get_system('surface', replicate=(4, 4, 1)) - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - pot_path = str(tmp_path / 'deployed_from_cueq_parallel') - deploy_parallel(cp_path, pot_path) - - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=' '.join([str(cfg['num_convolution_layer']), pot_path]), - wd=tmp_path, - test_name='cueq checkpoint parallel lmp run test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=2, - ) - atoms.calc = ref_calc - assert_atoms(atoms, atoms_lammps) +import copy +import logging +import pathlib +import subprocess + +import ase.calculators.lammps +import ase.io.lammpsdata +import numpy as np +import pytest +import torch +from ase.build import bulk, surface +from ase.calculators.singlepoint import SinglePointCalculator + +import sevenn +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.cue_helper import is_cue_available +from sevenn.scripts.deploy import deploy, deploy_parallel +from sevenn.util import chemical_species_preprocess, pretrained_name_to_path + +logger = logging.getLogger('test_lammps') + +cutoff = 4.0 + +lmp_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() +) + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O +cp_mf_path = pretrained_name_to_path('7net-mf-0') + + +@pytest.fixture(scope='module') +def serial_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential') + pot_path = str(tmp / 'deployed_serial.pt') + deploy(cp_0_path, pot_path) + return pot_path + + +@pytest.fixture(scope='module') +def parallel_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('paralllel_potential') + pot_path = str(tmp / 'deployed_parallel') + deploy_parallel(cp_0_path, pot_path) + return ' '.join(['3', pot_path]) + + +@pytest.fixture(scope='module') +def serial_modal_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_modal_potential') + pot_path = str(tmp / 'deployed_serial.pt') + deploy(cp_mf_path, pot_path, 'PBE') + return pot_path + + +@pytest.fixture(scope='module') +def parallel_modal_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('paralllel_modal_potential') + pot_path = str(tmp / 'deployed_parallel') + deploy_parallel(cp_mf_path, pot_path, 'PBE') + return ' '.join(['5', pot_path]) + + +@pytest.fixture(scope='module') +def ref_calculator(): + return SevenNetCalculator(cp_0_path) + + +@pytest.fixture(scope='module') +def ref_modal_calculator(): + return SevenNetCalculator(cp_mf_path, modal='PBE') + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 8, + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'self_connection_type': 'linear', # not NequIp + 'interaction_type': 'nequip', + 'radial_basis': { + 'radial_basis_name': 
'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 30.0, + 'train_denominator': False, + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + config.update(chemical_species_preprocess(['Hf', 'O'])) + return config + + +def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): + cf = get_model_config() + if config_overwrite is not None: + cf.update(config_overwrite) + + cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} + cf.update(cueq_config) + + model = build_E3_equivariant_model(cf, parallel=False) + assert not isinstance(model, list) + return model + + +def hfo2_bulk(replicate=(2, 2, 2), a=4.0): + atoms = bulk('HfO', 'rocksalt', a, orthorhombic=True) + atoms = atoms * replicate + atoms.rattle(stdev=0.10) + return atoms + + +def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5): + atoms = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum) + atoms.set_atomic_numbers([72] * len(atoms)) # Hf + atoms = atoms * replicate + atoms.rattle(stdev=0.10) + return atoms + + +def get_system(system_name, **kwargs): + if system_name == 'bulk': + return hfo2_bulk(**kwargs) + elif system_name == 'surface': + return hf_surface(**kwargs) + else: + raise ValueError() + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def _lammps_results_to_atoms(lammps_log, force_dump): + with open(lammps_log, 'r') as f: + lines = f.readlines() + lmp_log = None + for i, line in enumerate(lines): + if not line.startswith('Per MPI rank memory allocation'): + continue + lmp_log = { + k: eval(v) for k, v in zip(lines[i + 1].split(), lines[i + 2].split()) + } + break + + assert lmp_log is not None and 'PotEng' in lmp_log + + latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':') + assert isinstance(latoms_list, list) + latoms = latoms_list[0] + assert latoms.calc is not None + latoms.calc.results['energy'] = lmp_log['PotEng'] + latoms.calc.results['free_energy'] = lmp_log['PotEng'] + latoms.info = { + 'data_from': 'lammps', + 'lmp_log': lmp_log, + 'lmp_dump': force_dump, + } + # atomic energy read + latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + stress = np.array( + [ + [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], + [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']], + [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']], + ] + ) + stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 + latoms.calc.results['stress'] = stress + + return latoms + + +def _run_lammps(atoms, pair_style, potential, wd, command, test_name): + wd = wd.resolve() + pbc = atoms.get_pbc() + pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) + chem = 
list(set(atoms.get_chemical_symbols())) + # Way to ase handle lammps structure + + prism = ase.calculators.lammps.coordinatetransform.Prism( + atoms.get_cell(), pbc=pbc + ) + lmp_stct = wd / 'lammps_structure' + ase.io.lammpsdata.write_lammps_data( + lmp_stct, atoms, prismobj=prism, specorder=chem + ) + + with open(lmp_script_path, 'r') as f: + cont = f.read() + + lammps_log = str(wd / 'log.lammps') + force_dump = str(wd / 'force.dump') + + var_dct = {} + var_dct['__ELEMENT__'] = ' '.join(chem) + var_dct['__LMP_STCT__'] = str(lmp_stct.resolve()) + var_dct['__PAIR_STYLE__'] = pair_style + var_dct['__POTENTIALS__'] = potential + var_dct['__BOUNDARY__'] = pbc_str + var_dct['__FORCE_DUMP_PATH__'] = force_dump + for key, val in var_dct.items(): + cont = cont.replace(key, val) + + input_script_path = str(wd / 'in.lmp') + with open(input_script_path, 'w') as f: + f.write(cont) + + command = f'{command} -in {input_script_path} -log {lammps_log}' + subprocess_routine(command.split(), test_name) + + lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump) + assert lmp_atoms.calc is not None + + rot_mat = prism.rot_mat + results = copy.deepcopy(lmp_atoms.calc.results) + r_force = np.dot(results['forces'], rot_mat.T) + results['forces'] = r_force + if 'stress' in results: + # see ase.calculators.lammpsrun.py + stress_tensor = results['stress'] + stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T) + results['stress'] = stress_atoms + r_cell = lmp_atoms.get_cell() @ rot_mat.T + lmp_atoms.set_cell(r_cell, scale_atoms=True) + lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + + return lmp_atoms + + +def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + + +def parallel_lammps_run( + atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd +): + command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' + return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + + +def subprocess_routine(cmd, name): + res = subprocess.run(cmd, capture_output=True, timeout=30) + if res.returncode != 0: + logger.error(f'Subprocess {name} failed return code: {res.returncode}') + logger.error(res.stderr.decode('utf-8')) + raise RuntimeError(f'{name} failed') + + logger.info(f'stdout of {name}:') + logger.info(res.stdout.decode('utf-8')) + + +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path): + atoms = get_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.parametrize( + 'system,ncores', + [ + ('bulk', 1), + ('bulk', 2), + ('bulk', 4), + ('surface', 1), + ('surface', 2), + ('surface', 3), + ('surface', 4), + ], +) +def test_parallel( + system, + ncores, + parallel_potential_path, + ref_calculator, + lammps_cmd, + mpirun_cmd, + tmp_path, +): + if system == 'bulk': + rep = (6, 6, 3) + elif system == 'surface': + rep = (4, 4, 1) + else: + assert False + atoms = get_system(system, replicate=rep) + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=parallel_potential_path, + wd=tmp_path, + test_name='parallel lmp test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=ncores, + ) + atoms.calc = ref_calculator + assert_atoms(atoms, 
atoms_lammps) + + +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_modal_serial( + system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_modal_potential_path, + wd=tmp_path, + test_name='serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_modal_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.parametrize( + 'system,ncores', + [ + ('bulk', 2), + ('surface', 2), + ], +) +def test_modal_parallel( + system, + ncores, + parallel_modal_potential_path, + ref_modal_calculator, + lammps_cmd, + mpirun_cmd, + tmp_path, +): + if system == 'bulk': + rep = (6, 6, 3) + elif system == 'surface': + rep = (4, 4, 1) + else: + assert False + atoms = get_system(system, replicate=rep) + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=parallel_modal_potential_path, + wd=tmp_path, + test_name='parallel lmp test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=ncores, + ) + atoms.calc = ref_modal_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_cueq_serial(lammps_cmd, tmp_path): + """ + TODO: Use already saved cueq enabled checkpoint after cueq becomes stable + """ + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('bulk') + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt') + deploy(cp_path, pot_path) + + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=pot_path, + wd=tmp_path, + test_name='cueq checkpoint serial lmp run test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path): + """ + TODO: Use already saved cueq enabled checkpoint after cueq becomes stable + """ + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('surface', replicate=(4, 4, 1)) + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + pot_path = str(tmp_path / 'deployed_from_cueq_parallel') + deploy_parallel(cp_path, pot_path) + + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=' '.join([str(cfg['num_convolution_layer']), pot_path]), + wd=tmp_path, + test_name='cueq checkpoint parallel lmp run test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=2, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py index f2e6f8e5adbdee7d5451f9eea49c570ce3ad3efb..c2494927674355bd5a3f488a1ae421b397fe3ac7 100644 --- 
a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py @@ -1,217 +1,217 @@ -import copy - -import numpy as np -import pytest -from ase.build import bulk, molecule - -from sevenn.calculator import D3Calculator, SevenNetCalculator -from sevenn.nn.cue_helper import is_cue_available -from sevenn.scripts.deploy import deploy -from sevenn.util import ( - model_from_checkpoint, - model_from_checkpoint_with_backend, - pretrained_name_to_path, -) - - -@pytest.fixture -def atoms_pbc(): - atoms1 = bulk('NaCl', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture -def atoms_mol(): - atoms2 = molecule('H2O') - atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) - return atoms2 - - -@pytest.fixture(scope='module') -def sevennet_0_cal(): - return SevenNetCalculator('7net-0_11July2024') - - -@pytest.fixture(scope='module') -def sevennet_0_cueq_cal(): - cpp = pretrained_name_to_path('7net-0_11July2024') - model, _ = model_from_checkpoint_with_backend(cpp, 'cueq') - return SevenNetCalculator(model) - - -@pytest.fixture(scope='module') -def d3_cal(): - try: - return D3Calculator() - except NotImplementedError as e: - pytest.skip(f'{e}') - - -def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal): - atoms1_ref = { - 'energy': -3.779199, - 'energies': [-1.8493923, -1.9298072], - 'force': [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ], - 'stress': [ - [ - -0.6439122, - -0.03643947, - -0.03643981, - 0.00599139, - 0.04544507, - 0.04543639, - ] - ], - } - - atoms_pbc.calc = sevennet_0_cal - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) - - -def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): - atoms2_ref = { - 'energy': -12.782808303833008, - 'energies': [-6.2493525, -3.141562, -3.3918958], - 'force': [ - [0.0, -1.3619621e01, 7.5937047e00], - [0.0, 9.3918495e00, -1.0172190e01], - [0.0, 4.2277718e00, 2.5784855e00], - ], - } - atoms_mol.calc = sevennet_0_cal - assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) - assert np.allclose( - atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] - ) - assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) - assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) - - -def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): - fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) - - calc_script = SevenNetCalculator(fname, file_type='torchscript') - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_script - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k]) - - -def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): - model, _ = 
model_from_checkpoint( - pretrained_name_to_path('7net-0_11July2024') - ) - - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - calc_instance = SevenNetCalculator(model, file_type='model_instance') - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_instance - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k]) - - -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal): - atoms1_ref = { - 'energy': -3.779199, - 'energies': [-1.8493923, -1.9298072], - 'force': [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ], - 'stress': [ - [ - -0.6439122, - -0.03643947, - -0.03643981, - 0.00599139, - 0.04544507, - 0.04543639, - ] - ], - } - - atoms_pbc.calc = sevennet_0_cueq_cal - - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) - - -def test_d3_cal_pbc(atoms_pbc, d3_cal): - atoms1_ref = { - 'energy': -0.531393751583389, - 'force': [ - [-0.00570205, 0.00107457, 0.00107459], - [0.00570205, -0.00107457, -0.00107459], - ], - 'stress': [ - [ - 1.52403705e-02, - 1.50417333e-02, - 1.50417321e-02, - -3.22684163e-05, - -5.05532863e-05, - -5.05586994e-05, - ] - ], - } - - atoms_pbc.calc = d3_cal - - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - - -def test_d3_cal_mol(atoms_mol, d3_cal): - atoms2_ref = { - 'energy': -0.009889134535170716, - 'force': [ - [0.0, 2.04263840e-03, 1.27477674e-03], - [0.0, -9.90038901e-05, 1.18046682e-06], - [0.0, -1.94363451e-03, -1.27595721e-03], - ], - } - - atoms_mol.calc = d3_cal - - assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) - assert np.allclose( - atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] - ) - assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) +import copy + +import numpy as np +import pytest +from ase.build import bulk, molecule + +from sevenn.calculator import D3Calculator, SevenNetCalculator +from sevenn.nn.cue_helper import is_cue_available +from sevenn.scripts.deploy import deploy +from sevenn.util import ( + model_from_checkpoint, + model_from_checkpoint_with_backend, + pretrained_name_to_path, +) + + +@pytest.fixture +def atoms_pbc(): + atoms1 = bulk('NaCl', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture +def atoms_mol(): + atoms2 = molecule('H2O') + atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) + return atoms2 + + +@pytest.fixture(scope='module') +def sevennet_0_cal(): + return SevenNetCalculator('7net-0_11July2024') + + +@pytest.fixture(scope='module') +def sevennet_0_cueq_cal(): + cpp 
= pretrained_name_to_path('7net-0_11July2024') + model, _ = model_from_checkpoint_with_backend(cpp, 'cueq') + return SevenNetCalculator(model) + + +@pytest.fixture(scope='module') +def d3_cal(): + try: + return D3Calculator() + except NotImplementedError as e: + pytest.skip(f'{e}') + + +def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal): + atoms1_ref = { + 'energy': -3.779199, + 'energies': [-1.8493923, -1.9298072], + 'force': [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ], + 'stress': [ + [ + -0.6439122, + -0.03643947, + -0.03643981, + 0.00599139, + 0.04544507, + 0.04543639, + ] + ], + } + + atoms_pbc.calc = sevennet_0_cal + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) + + +def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): + atoms2_ref = { + 'energy': -12.782808303833008, + 'energies': [-6.2493525, -3.141562, -3.3918958], + 'force': [ + [0.0, -1.3619621e01, 7.5937047e00], + [0.0, 9.3918495e00, -1.0172190e01], + [0.0, 4.2277718e00, 2.5784855e00], + ], + } + atoms_mol.calc = sevennet_0_cal + assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) + assert np.allclose( + atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] + ) + assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) + assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) + + +def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): + fname = str(tmp_path / '7net_0.pt') + deploy(pretrained_name_to_path('7net-0_11July2024'), fname) + + calc_script = SevenNetCalculator(fname, file_type='torchscript') + calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) + + atoms_pbc.calc = calc_cp + atoms_pbc.get_potential_energy() + res_cp = copy.copy(atoms_pbc.calc.results) + + atoms_pbc.calc = calc_script + atoms_pbc.get_potential_energy() + res_script = copy.copy(atoms_pbc.calc.results) + + for k in res_cp: + assert np.allclose(res_cp[k], res_script[k]) + + +def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): + model, _ = model_from_checkpoint( + pretrained_name_to_path('7net-0_11July2024') + ) + + calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) + calc_instance = SevenNetCalculator(model, file_type='model_instance') + + atoms_pbc.calc = calc_cp + atoms_pbc.get_potential_energy() + res_cp = copy.copy(atoms_pbc.calc.results) + + atoms_pbc.calc = calc_instance + atoms_pbc.get_potential_energy() + res_script = copy.copy(atoms_pbc.calc.results) + + for k in res_cp: + assert np.allclose(res_cp[k], res_script[k]) + + +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal): + atoms1_ref = { + 'energy': -3.779199, + 'energies': [-1.8493923, -1.9298072], + 'force': [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ], + 'stress': [ + [ + -0.6439122, + -0.03643947, + -0.03643981, + 0.00599139, + 0.04544507, + 0.04543639, + ] + ], + } + + atoms_pbc.calc = sevennet_0_cueq_cal + + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + 
atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) + + +def test_d3_cal_pbc(atoms_pbc, d3_cal): + atoms1_ref = { + 'energy': -0.531393751583389, + 'force': [ + [-0.00570205, 0.00107457, 0.00107459], + [0.00570205, -0.00107457, -0.00107459], + ], + 'stress': [ + [ + 1.52403705e-02, + 1.50417333e-02, + 1.50417321e-02, + -3.22684163e-05, + -5.05532863e-05, + -5.05586994e-05, + ] + ], + } + + atoms_pbc.calc = d3_cal + + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + + +def test_d3_cal_mol(atoms_mol, d3_cal): + atoms2_ref = { + 'energy': -0.009889134535170716, + 'force': [ + [0.0, 2.04263840e-03, 1.27477674e-03], + [0.0, -9.90038901e-05, 1.18046682e-06], + [0.0, -1.94363451e-03, -1.27595721e-03], + ], + } + + atoms_mol.calc = d3_cal + + assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) + assert np.allclose( + atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] + ) + assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py index ad0e80cdddda34d28565a67e9ce4c45c58320b25..0bdeeabf4757e7c2bcf5a09ca6183576f7138af9 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py @@ -1,233 +1,233 @@ -import csv -import os -import pathlib -from unittest import mock - -import ase.io -import numpy as np -import pytest -import yaml -from ase.build import bulk - -from sevenn.calculator import SevenNetCalculator -from sevenn.logger import Logger -from sevenn.main.sevenn import main as sevenn_main -from sevenn.main.sevenn_get_model import main as get_model_main -from sevenn.main.sevenn_graph_build import main as graph_build_main -from sevenn.main.sevenn_inference import main as inference_main -from sevenn.util import pretrained_name_to_path - -main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/') -preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/') -file_path = pathlib.Path(__file__).parent.resolve() - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2' -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') - -Logger() # init - - -@pytest.fixture -def atoms_hfo(): - atoms1 = bulk('HfO', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture(scope='module') -def sevennet_0_cal(): - return SevenNetCalculator('7net-0_11July2024') - - -def test_get_model_serial(tmp_path, capsys): - output_file = tmp_path / 'mypot.pt' - cp = pretrained_name_to_path('7net-0') - cli_args = ['-o', str(output_file), cp] - with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): - get_model_main() - _ = capsys.readouterr() # not 
used - assert output_file.is_file(), '.pt file is not written' - - -def test_get_model_parallel(tmp_path, capsys): - output_dir = tmp_path / 'my_parallel' - cp = pretrained_name_to_path('7net-0') - expected_file_cnt = 5 # 5 interaction layers - cli_args = ['-o', str(output_dir), '-p', cp] - with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): - # with pytest.raises(SystemExit): - get_model_main() - _ = capsys.readouterr() # not used - assert output_dir.is_dir(), 'parallel model directory not exist' - for i in range(expected_file_cnt): - assert (output_dir / f'deployed_parallel_{i}.pt').is_file() - - -@pytest.mark.parametrize('source', [(hfo2_path)]) -def test_graph_build(source, tmp_path): - output_dir = tmp_path / 'sevenn_data' - output_f = output_dir / 'my_graph.pt' - output_yml = output_dir / 'my_graph.yaml' - cli_args = ['-o', str(tmp_path), '-f', 'my_graph.pt', source, '4.0'] - with mock.patch('sys.argv', [f'{main}/sevenn_graph_build.py'] + cli_args): - graph_build_main() - - assert output_dir.is_dir() - assert output_f.is_file() - assert output_yml.is_file() - - -@pytest.mark.parametrize( - 'batch,device,save_graph', - [ - (1, 'cpu', False), - (2, 'cpu', False), - (1, 'cpu', True), - ], -) -def test_inference(batch, device, save_graph, tmp_path): - checkpoint = '7net-0' - target = hfo2_path - ref_path = hfo2_7net_0_inference_path - - output_dir = tmp_path / 'inference_results' - files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt'] - cli_args = [ - '--output', - str(output_dir), - '--device', - device, - '--batch', - str(batch), - checkpoint, - target, - ] - if save_graph: - cli_args.append('--save_graph') - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - assert output_dir.is_dir() - for f in files: - assert (output_dir / f).is_file() - with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f: - errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] - with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f: - errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] - assert np.allclose(np.array(errors), np.array(errors_ref)) - - """ - # TODO: commented out as currently SevenNetGraphDataset can't do this - with open(output_dir / 'info.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - assert dct['file'] == hfo2_path - assert reader.line_num == 3 - """ - - if save_graph: - assert (output_dir / 'sevenn_data').is_dir() - assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file() - assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file() - - -def test_inference_unlabeled(atoms_hfo, tmp_path): - labeled = str(hfo2_path) - unlabeled = str(tmp_path / 'unlabeled.xyz') - ase.io.write(unlabeled, atoms_hfo) - - output_dir = tmp_path / 'inference_results' - cli_args = [ - '--output', - str(output_dir), - '--allow_unlabeled', - cp_0_path, - labeled, - unlabeled, - ] - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - with open(output_dir / 'info.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - assert dct['file'] in [labeled, unlabeled] - assert reader.line_num == 4 - - -def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path): - atoms_hfo.info['my_energy'] = 1.0 - atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7) - # this should be considered as Voigt, xx, yy, zz, yz, zx, xy - atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6]) - - unlabeled = 
str(tmp_path / 'unlabeled.xyz') - ase.io.write(unlabeled, atoms_hfo) - - output_dir = tmp_path / 'inference_results' - cli_args = [ - '--output', - str(output_dir), - cp_0_path, - unlabeled, - '--kwargs', - 'energy_key=my_energy', - 'force_key=my_force', - 'stress_key=my_stress', - ] - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - per_graph = None - with open(output_dir / 'per_graph.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - per_graph = dct - assert reader.line_num == 2 - assert per_graph is not None - - stress_coeff = -1602.1766208 - assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff) - assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff) - assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff) - assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff) - - -@pytest.mark.parametrize( - 'preset_name,mode,data_path', - [ - ('fine_tune', 'train_v2', hfo2_path), - ('base', 'train_v2', hfo2_path), - ('sevennet-0', 'train_v1', hfo2_path), - ], -) -def test_sevenn_preset(preset_name, mode, data_path, tmp_path): - preset_path = os.path.join(preset, preset_name + '.yaml') - with open(preset_path, 'r') as f: - cfg = yaml.safe_load(f) - - cfg['train']['epoch'] = 1 - if mode == 'train_v2': - cfg['data']['load_trainset_path'] = data_path - cfg['data'].pop('load_testset_path', None) - elif mode == 'train_v1': - cfg['data']['load_dataset_path'] = data_path - else: - assert False - cfg['data']['load_validset_path'] = data_path - - input_yam = str(tmp_path / 'input.yaml') - with open(input_yam, 'w') as f: - yaml.dump(cfg, f) - - Logger().switch_file(str(tmp_path / 'log.sevenn')) - cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam] - with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args): - sevenn_main() - - assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file() - assert (tmp_path / 'log.sevenn').is_file() - assert (tmp_path / 'checkpoint_best.pth').is_file() +import csv +import os +import pathlib +from unittest import mock + +import ase.io +import numpy as np +import pytest +import yaml +from ase.build import bulk + +from sevenn.calculator import SevenNetCalculator +from sevenn.logger import Logger +from sevenn.main.sevenn import main as sevenn_main +from sevenn.main.sevenn_get_model import main as get_model_main +from sevenn.main.sevenn_graph_build import main as graph_build_main +from sevenn.main.sevenn_inference import main as inference_main +from sevenn.util import pretrained_name_to_path + +main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/') +preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/') +file_path = pathlib.Path(__file__).parent.resolve() + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2' +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') + +Logger() # init + + +@pytest.fixture +def atoms_hfo(): + atoms1 = bulk('HfO', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture(scope='module') +def sevennet_0_cal(): + return SevenNetCalculator('7net-0_11July2024') + + +def test_get_model_serial(tmp_path, capsys): + output_file = tmp_path / 'mypot.pt' + cp = 
pretrained_name_to_path('7net-0') + cli_args = ['-o', str(output_file), cp] + with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): + get_model_main() + _ = capsys.readouterr() # not used + assert output_file.is_file(), '.pt file is not written' + + +def test_get_model_parallel(tmp_path, capsys): + output_dir = tmp_path / 'my_parallel' + cp = pretrained_name_to_path('7net-0') + expected_file_cnt = 5 # 5 interaction layers + cli_args = ['-o', str(output_dir), '-p', cp] + with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): + # with pytest.raises(SystemExit): + get_model_main() + _ = capsys.readouterr() # not used + assert output_dir.is_dir(), 'parallel model directory not exist' + for i in range(expected_file_cnt): + assert (output_dir / f'deployed_parallel_{i}.pt').is_file() + + +@pytest.mark.parametrize('source', [(hfo2_path)]) +def test_graph_build(source, tmp_path): + output_dir = tmp_path / 'sevenn_data' + output_f = output_dir / 'my_graph.pt' + output_yml = output_dir / 'my_graph.yaml' + cli_args = ['-o', str(tmp_path), '-f', 'my_graph.pt', source, '4.0'] + with mock.patch('sys.argv', [f'{main}/sevenn_graph_build.py'] + cli_args): + graph_build_main() + + assert output_dir.is_dir() + assert output_f.is_file() + assert output_yml.is_file() + + +@pytest.mark.parametrize( + 'batch,device,save_graph', + [ + (1, 'cpu', False), + (2, 'cpu', False), + (1, 'cpu', True), + ], +) +def test_inference(batch, device, save_graph, tmp_path): + checkpoint = '7net-0' + target = hfo2_path + ref_path = hfo2_7net_0_inference_path + + output_dir = tmp_path / 'inference_results' + files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt'] + cli_args = [ + '--output', + str(output_dir), + '--device', + device, + '--batch', + str(batch), + checkpoint, + target, + ] + if save_graph: + cli_args.append('--save_graph') + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + assert output_dir.is_dir() + for f in files: + assert (output_dir / f).is_file() + with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f: + errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] + with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f: + errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] + assert np.allclose(np.array(errors), np.array(errors_ref)) + + """ + # TODO: commented out as currently SevenNetGraphDataset can't do this + with open(output_dir / 'info.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + assert dct['file'] == hfo2_path + assert reader.line_num == 3 + """ + + if save_graph: + assert (output_dir / 'sevenn_data').is_dir() + assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file() + assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file() + + +def test_inference_unlabeled(atoms_hfo, tmp_path): + labeled = str(hfo2_path) + unlabeled = str(tmp_path / 'unlabeled.xyz') + ase.io.write(unlabeled, atoms_hfo) + + output_dir = tmp_path / 'inference_results' + cli_args = [ + '--output', + str(output_dir), + '--allow_unlabeled', + cp_0_path, + labeled, + unlabeled, + ] + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + with open(output_dir / 'info.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + assert dct['file'] in [labeled, unlabeled] + assert reader.line_num == 4 + + +def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path): + atoms_hfo.info['my_energy'] = 1.0 + 
atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7) + # this should be considered as Voigt, xx, yy, zz, yz, zx, xy + atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6]) + + unlabeled = str(tmp_path / 'unlabeled.xyz') + ase.io.write(unlabeled, atoms_hfo) + + output_dir = tmp_path / 'inference_results' + cli_args = [ + '--output', + str(output_dir), + cp_0_path, + unlabeled, + '--kwargs', + 'energy_key=my_energy', + 'force_key=my_force', + 'stress_key=my_stress', + ] + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + per_graph = None + with open(output_dir / 'per_graph.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + per_graph = dct + assert reader.line_num == 2 + assert per_graph is not None + + stress_coeff = -1602.1766208 + assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff) + assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff) + assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff) + assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff) + + +@pytest.mark.parametrize( + 'preset_name,mode,data_path', + [ + ('fine_tune', 'train_v2', hfo2_path), + ('base', 'train_v2', hfo2_path), + ('sevennet-0', 'train_v1', hfo2_path), + ], +) +def test_sevenn_preset(preset_name, mode, data_path, tmp_path): + preset_path = os.path.join(preset, preset_name + '.yaml') + with open(preset_path, 'r') as f: + cfg = yaml.safe_load(f) + + cfg['train']['epoch'] = 1 + if mode == 'train_v2': + cfg['data']['load_trainset_path'] = data_path + cfg['data'].pop('load_testset_path', None) + elif mode == 'train_v1': + cfg['data']['load_dataset_path'] = data_path + else: + assert False + cfg['data']['load_validset_path'] = data_path + + input_yam = str(tmp_path / 'input.yaml') + with open(input_yam, 'w') as f: + yaml.dump(cfg, f) + + Logger().switch_file(str(tmp_path / 'log.sevenn')) + cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam] + with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args): + sevenn_main() + + assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file() + assert (tmp_path / 'log.sevenn').is_file() + assert (tmp_path / 'checkpoint_best.pth').is_file() diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py index a9c3b157c4369b1999de22cc69e4b0d6d4cad25c..4f1d5bd041d2ba394f9664219821f3d6720ab35c 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py @@ -1,282 +1,282 @@ -# TODO: add gradient test from total loss after double precision. 
-# so far, it is empirically checked by seeing learning curves -import copy - -import numpy as np -import pytest -import torch -from ase.build import bulk -from torch_geometric.loader.dataloader import Collater - -import sevenn -import sevenn.train.dataload as dl -from sevenn.atom_graph_data import AtomGraphData -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.cue_helper import is_cue_available -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.util import ( - chemical_species_preprocess, - model_from_checkpoint_with_backend, -) - -cutoff = 4.0 - -_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2) -_avg_num_neigh = 30.0 -_atoms.rattle() - -_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff)) - - -def get_graphs(batched): - # batch size 2 - cloned = [_graph.clone().to('cuda'), _graph.clone().to('cuda')] - if not batched: - return cloned - else: - return Collater(cloned)(cloned) - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 32, - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'self_connection_type': 'nequip', # not NequIp - 'interaction_type': 'nequip', - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': _avg_num_neigh, - 'train_denominator': False, - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - chems = set() - chems.update(_atoms.get_chemical_symbols()) - config.update(**chemical_species_preprocess(list(chems))) - return config - - -def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): - cf = get_model_config() - if config_overwrite is not None: - cf.update(config_overwrite) - - cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} - cf.update(cueq_config) - - model = build_E3_equivariant_model(cf, parallel=False) - assert isinstance(model, AtomGraphSequential) - model.to('cuda') - return model - - -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'cf', - [ - ({}), - ({'self_connection_type': 'linear'}), - ({'is_parity': False}), - ({'channel': 8}), - ({'lmax': 3}), - ({'num_interaction_layer': 2}), - ({'num_interaction_layer': 4}), - ], -) -def test_model_output(cf): - torch.manual_seed(777) - model_e3nn = get_model(cf) - torch.manual_seed(777) - model_cueq = get_model(cf, use_cueq=True) - - model_e3nn.set_is_batch_data(True) - model_cueq.set_is_batch_data(True) - - e3nn_out = model_e3nn._preprocess(get_graphs(batched=True)) - cueq_out = model_cueq._preprocess(get_graphs(batched=True)) - - for k, e3nn_f in model_e3nn._modules.items(): - cueq_f = model_cueq._modules[k] - e3nn_out = e3nn_f(e3nn_out) # type: ignore - cueq_out = cueq_f(cueq_out) # type: ignore - assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), ( - f'{k} \n\n {e3nn_f} \n\n {cueq_f}' - ) - - assert torch.allclose( - e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy - ) - assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy) - assert torch.allclose( - e3nn_out.inferred_force, 
cueq_out.inferred_force, atol=1e-5 - ) - assert torch.allclose( - e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5 - ) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'start_from_cueq', - [ - (True), - (False), - ], -) -def test_checkpoint_convert(tmp_path, start_from_cueq): - torch.manual_seed(123) - model_from = get_model(use_cueq=start_from_cueq) - - cfg = get_model_config() - cfg.update( - { - 'cuequivariance_config': {'use': start_from_cueq}, - 'version': sevenn.__version__, - } - ) - torch.save( - {'model_state_dict': model_from.state_dict(), 'config': cfg}, - tmp_path / 'cp_from.pth', - ) - - backend = 'e3nn' if start_from_cueq else 'cueq' - model_to, _ = model_from_checkpoint_with_backend( - str(tmp_path / 'cp_from.pth'), backend - ) - model_to.to('cuda') - - model_from.set_is_batch_data(True) - model_to.set_is_batch_data(True) - - from_out = model_from(get_graphs(batched=True)) - to_out = model_to(get_graphs(batched=True)) - - assert torch.allclose( - from_out.inferred_total_energy, to_out.inferred_total_energy - ) - assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) - assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) - assert torch.allclose( - from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 - ) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'start_from_cueq', - [ - (True), - (False), - ], -) -def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq): - torch.manual_seed(123) - model_from = get_model(use_cueq=start_from_cueq) - - cfg = get_model_config() - cfg.update( - { - 'cuequivariance_config': {'use': start_from_cueq}, - 'version': sevenn.__version__, - } - ) - torch.save( - {'model_state_dict': model_from.state_dict(), 'config': cfg}, - tmp_path / 'cp_from.pth', - ) - - backend = 'e3nn' if start_from_cueq else 'cueq' - model_to, _ = model_from_checkpoint_with_backend( - str(tmp_path / 'cp_from.pth'), backend - ) - model_to.to('cuda') - - model_from.set_is_batch_data(False) - model_to.set_is_batch_data(False) - - from_out = model_from(get_graphs(batched=False)[0]) - to_out = model_to(get_graphs(batched=False)[0]) - - assert torch.allclose( - from_out.inferred_total_energy, to_out.inferred_total_energy - ) - assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) - assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) - assert torch.allclose( - from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 - ) - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - 
reason='cueq or gpu is not available', -) -def test_calculator(tmp_path): - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = copy.deepcopy(_atoms) - atoms.calc = ref_calc - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - calc2 = SevenNetCalculator(cp_path, enable_cueq=False) - atoms2 = copy.deepcopy(_atoms) - atoms2.calc = calc2 - - assert_atoms(atoms, atoms2) +# TODO: add gradient test from total loss after double precision. +# so far, it is empirically checked by seeing learning curves +import copy + +import numpy as np +import pytest +import torch +from ase.build import bulk +from torch_geometric.loader.dataloader import Collater + +import sevenn +import sevenn.train.dataload as dl +from sevenn.atom_graph_data import AtomGraphData +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.cue_helper import is_cue_available +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.util import ( + chemical_species_preprocess, + model_from_checkpoint_with_backend, +) + +cutoff = 4.0 + +_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2) +_avg_num_neigh = 30.0 +_atoms.rattle() + +_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff)) + + +def get_graphs(batched): + # batch size 2 + cloned = [_graph.clone().to('cuda'), _graph.clone().to('cuda')] + if not batched: + return cloned + else: + return Collater(cloned)(cloned) + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 32, + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'self_connection_type': 'nequip', # not NequIp + 'interaction_type': 'nequip', + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': _avg_num_neigh, + 'train_denominator': False, + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + chems = set() + chems.update(_atoms.get_chemical_symbols()) + config.update(**chemical_species_preprocess(list(chems))) + return config + + +def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): + cf = get_model_config() + if config_overwrite is not None: + cf.update(config_overwrite) + + cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} + cf.update(cueq_config) + + model = build_E3_equivariant_model(cf, parallel=False) + assert isinstance(model, AtomGraphSequential) + model.to('cuda') + return model + + +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'cf', + [ + ({}), + ({'self_connection_type': 'linear'}), + ({'is_parity': False}), + ({'channel': 8}), + ({'lmax': 3}), + ({'num_interaction_layer': 2}), + ({'num_interaction_layer': 4}), + ], +) +def test_model_output(cf): + torch.manual_seed(777) + model_e3nn = get_model(cf) + torch.manual_seed(777) + model_cueq = get_model(cf, use_cueq=True) 
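# Both models are built after torch.manual_seed(777), so the e3nn and cueq
# backends start from identical weights; the module loop below can therefore
# compare the two kernel implementations layer by layer before checking the
# final energy, force, and stress outputs.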
+ + model_e3nn.set_is_batch_data(True) + model_cueq.set_is_batch_data(True) + + e3nn_out = model_e3nn._preprocess(get_graphs(batched=True)) + cueq_out = model_cueq._preprocess(get_graphs(batched=True)) + + for k, e3nn_f in model_e3nn._modules.items(): + cueq_f = model_cueq._modules[k] + e3nn_out = e3nn_f(e3nn_out) # type: ignore + cueq_out = cueq_f(cueq_out) # type: ignore + assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), ( + f'{k} \n\n {e3nn_f} \n\n {cueq_f}' + ) + + assert torch.allclose( + e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy + ) + assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy) + assert torch.allclose( + e3nn_out.inferred_force, cueq_out.inferred_force, atol=1e-5 + ) + assert torch.allclose( + e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'start_from_cueq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert(tmp_path, start_from_cueq): + torch.manual_seed(123) + model_from = get_model(use_cueq=start_from_cueq) + + cfg = get_model_config() + cfg.update( + { + 'cuequivariance_config': {'use': start_from_cueq}, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + backend = 'e3nn' if start_from_cueq else 'cueq' + model_to, _ = model_from_checkpoint_with_backend( + str(tmp_path / 'cp_from.pth'), backend + ) + model_to.to('cuda') + + model_from.set_is_batch_data(True) + model_to.set_is_batch_data(True) + + from_out = model_from(get_graphs(batched=True)) + to_out = model_to(get_graphs(batched=True)) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'start_from_cueq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq): + torch.manual_seed(123) + model_from = get_model(use_cueq=start_from_cueq) + + cfg = get_model_config() + cfg.update( + { + 'cuequivariance_config': {'use': start_from_cueq}, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + backend = 'e3nn' if start_from_cueq else 'cueq' + model_to, _ = model_from_checkpoint_with_backend( + str(tmp_path / 'cp_from.pth'), backend + ) + model_to.to('cuda') + + model_from.set_is_batch_data(False) + model_to.set_is_batch_data(False) + + from_out = model_from(get_graphs(batched=False)[0]) + to_out = model_to(get_graphs(batched=False)[0]) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +def assert_atoms(atoms1, 
atoms2, rtol=1e-5, atol=1e-6): + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +def test_calculator(tmp_path): + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = copy.deepcopy(_atoms) + atoms.calc = ref_calc + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + calc2 = SevenNetCalculator(cp_path, enable_cueq=False) + atoms2 = copy.deepcopy(_atoms) + atoms2.calc = calc2 + + assert_atoms(atoms, atoms2) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py index b966b2a1938d873c10ea7044cc6a01c3a0506251..fda1effd491232a7b95c94b5c0bb2987f46f819c 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py @@ -1,521 +1,521 @@ -import logging -import os -import os.path as osp -import uuid -from collections import Counter -from copy import deepcopy -from typing import Literal - -import ase.calculators.singlepoint as singlepoint -import ase.io -import numpy as np -import pytest -import torch -from ase import Atoms -from ase.build import bulk, molecule -from torch_geometric.loader import DataLoader - -import sevenn._keys as KEY -import sevenn.train.dataload as dl -import sevenn.train.graph_dataset as ds -import sevenn.train.modal_dataset as modal_dataset -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.atom_graph_data import AtomGraphData -from sevenn.util import model_from_checkpoint, pretrained_name_to_path - -cutoff = 4.0 -lattice_constant = 3.35 - -_samples = { - 'bulk': bulk('NaCl', 'rocksalt', a=5.63), - 'mol': molecule('H2O'), - 'isolated': molecule('H'), - 'small_bulk': Atoms( - symbols='Cu', - positions=[ - (0, 0, 0), # Atom at the corner of the cube - ], - cell=[ - [lattice_constant, 0, 0], - [0, lattice_constant, 0], - [0, 0, lattice_constant], - ], - pbc=True, # Periodic boundary conditions - ), -} - - -_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18} - - -def get_atoms( - atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'], - init_y_as: Literal['calc', 'info', 'none'], -): - """ - Return atoms w, w/o reference values with its - # of edges for 4.0 cutoff length - """ - assert atoms_type in _samples - atoms = deepcopy(_samples[atoms_type]) - natoms = len(atoms) - if init_y_as == 'calc': - results = { - 'energy': np.random.rand(1), - 'forces': np.random.rand(natoms, 3), - 'stress': np.random.rand(6), - } - if not atoms.pbc.all(): - del results['stress'] - calc = singlepoint.SinglePointCalculator(atoms, **results) - atoms = calc.get_atoms() - elif init_y_as == 'info': - 
atoms.info['y_energy'] = np.random.rand(1) - atoms.arrays['y_force'] = np.random.rand(natoms, 3) - atoms.info['y_stress'] = np.random.rand(6) - if not atoms.pbc.all(): - del atoms.info['y_stress'] - return atoms, _nedges_c4[atoms_type] - - -@pytest.mark.parametrize('init_y_as', ['calc', 'info']) -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_atoms_to_graph(atoms_type, init_y_as): - atoms, nedges = get_atoms(atoms_type, init_y_as) - is_stress = atoms.pbc.all() - y_from_calc = init_y_as == 'calc' - - graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - 'total_energy': ((), float), - 'force_of_atoms': ((len(atoms), 3), float), - 'cell_volume': ((), float), - 'num_atoms': ((), int), - 'per_atom_energy': ((), float), - 'stress': ((1, 6), float), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], np.ndarray - ), f'{k}: {type(graph[k])} is not np.ndarray' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - if not is_stress and k == 'stress': - assert np.isnan(graph[k]).all() - else: - assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' - - assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms)) - assert graph['num_atoms'] == len(atoms) - if not is_stress: - assert graph['cell_volume'] == np.finfo(float).eps - - -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_unlabeled_atoms_to_graph(atoms_type): - atoms, nedges = get_atoms(atoms_type, 'none') - - graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - 'cell_volume': ((), float), - 'num_atoms': ((), int), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], np.ndarray - ), f'{k}: {type(graph[k])} is not np.ndarray' - assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - assert graph['num_atoms'] == len(atoms) - if not atoms.pbc.all(): - assert graph['cell_volume'] == np.finfo(float).eps - - -@pytest.mark.parametrize('init_y_as', ['calc', 'info']) -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_atom_graph_data(atoms_type, init_y_as): - atoms, nedges = get_atoms(atoms_type, init_y_as) - y_from_calc = init_y_as == 'calc' - is_stress = atoms.pbc.all() - np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) - graph = AtomGraphData.from_numpy_dict(np_graph) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - } - auxilaray = { - 'x': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'num_atoms': ((), int), - 'cell_volume': ((), float), - 'total_energy': ((), float), - 'per_atom_energy': ((), float), - 'force_of_atoms': ((len(atoms), 3), float), - 'stress': ((1, 6), float), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: {type(graph[k])} is not an tensor' - assert 
graph[k].is_floating_point() == (dtype is float) - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - for k, (shape, dtype) in auxilaray.items(): - if k not in graph: - continue - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: {type(graph[k])} is not an tensor' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - if not is_stress and k == 'stress': - assert torch.isnan(graph[k]).all() - else: - assert graph[k].is_floating_point() == (dtype is float) - - -def test_graph_build(): - """ - Compare parallel implementation, should preserve order - """ - atoms_list = [ - get_atoms(t, 'calc')[0] # type: ignore - for t in list(_samples.keys()) - ] - one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True) - two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True) - - assert len(one_core) == len(two_core) - for g1, g2 in zip(one_core, two_core): - assert set(g1.keys()) == set(g2.keys()) - for k in g1.keys(): - if not isinstance(g1[k], torch.Tensor): - continue - if k == 'stress': # TODO: robust way to test it - assert torch.allclose(g1[k], g2[k]) or ( - torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all() - ) - else: - assert torch.allclose(g1[k], g2[k]) - - -@pytest.fixture(scope='module') -def graph_dataset_tuple(): - tmpdir = os.getenv('TMPDIR', '/tmp') - randstr = uuid.uuid4().hex - assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable' - - root = tmpdir - files = f'{root}/{randstr}.extxyz' - atoms_list = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['bulk', 'mol', 'isolated'] - ] - ase.io.write(files, atoms_list, 'extxyz') - - dataset = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=root, - files=files, - processed_name=f'{randstr}.pt', - ) - assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written' - return dataset, atoms_list - - -def test_sevenn_graph_dataset_properties(graph_dataset_tuple): - dataset, atoms_list = graph_dataset_tuple - - species = set() - natoms = Counter() - elist = [] - e_per_list = [] - flist = [] - slist = [] - for at in atoms_list: - chems = at.get_chemical_symbols() - species.update(chems) - natoms.update(chems) - elist.append(at.get_potential_energy()) - e_per_list.append(at.get_potential_energy() / len(at)) - flist.extend(at.get_forces()) - try: - slist.append(at.get_stress()) - except NotImplementedError: - slist.append(np.full(6, np.nan)) - - elist = np.array(elist) - e_per_list = np.array(e_per_list) - flist = np.array(flist) - slist = np.array(slist) - - natoms['total'] = sum([cnt for cnt in list(natoms.values())]) - - assert set(dataset.species) == species - assert dataset.natoms == natoms - assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean()) - assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean())) - - -def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple): - logger = logging.getLogger(__name__) - - dataset, atoms_list = graph_dataset_tuple - - ref_e = dataset.elemwise_reference_energies - assert len(ref_e) == NUM_UNIV_ELEMENT - z_set = set() - for atoms in atoms_list: - inferred_e = 0 - atomic_numbers = atoms.get_atomic_numbers() - z_set.update(atomic_numbers) - for z in atomic_numbers: - inferred_e += ref_e[z] - # it never be same, but should be similar - logger.info('elemwise energy should be similar:') - logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}') - - for z in range(NUM_UNIV_ELEMENT): - if z not in z_set: - assert ref_e[z] == 0 - - -def 
test_sevenn_graph_dataset_statistics(graph_dataset_tuple): - dataset, atoms_list = graph_dataset_tuple - - elist = [] - e_per_list = [] - flist = [] - slist = [] - for at in atoms_list: - elist.append(at.get_potential_energy()) - e_per_list.append(at.get_potential_energy() / len(at)) - flist.extend(at.get_forces()) - try: - slist.append(at.get_stress()) - except NotImplementedError: - slist.append(np.full(6, np.nan)) - - dct = { - 'total_energy': np.array(elist), - 'per_atom_energy': np.array(e_per_list), - 'force_of_atoms': np.array(flist).flatten(), - # 'stress': np.array(slist), # TODO: it may have nan - } - - for key in dct: - assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key - assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key - assert np.allclose( - dataset.statistics[key]['median'], np.median(dct[key]) - ), key - assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key - assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key - - -def test_sevenn_mm_dataset_statistics(tmp_path): - - files = osp.join(tmp_path, 'gd_one.extxyz') - atoms_list1 = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['bulk', 'bulk', 'bulk', 'bulk'] - ] - ase.io.write(files, atoms_list1, 'extxyz') - - gd1 = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=files, - processed_name='gd_one.pt', - ) - - files = osp.join(tmp_path, 'gd_two.extxyz') - atoms_list2 = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['mol', 'mol', 'bulk'] - ] - ase.io.write(files, atoms_list2, 'extxyz') - - gd2 = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=files, - processed_name='gd_two.pt', - ) - - ref = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=[gd1.processed_paths[0], gd2.processed_paths[0]], - processed_name='combined.pt', - ) - - mm = modal_dataset.SevenNetMultiModalDataset( - {'modal1': gd1, 'modal2': gd2} - ) - - assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total']) - assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total']) - assert np.allclose(ref.force_rms, mm.force_rms['total']) - assert set(ref.species) == set(mm.species['total']) - - -@pytest.mark.parametrize( - 'a_types,init_ys', [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])] -) -def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path): - assert len(a_types) == len(init_ys) - n_graph = len(a_types) - atoms_list = [] - tot_edges = 0 - tot_atoms = 0 - for a_type, init_y in zip(a_types, init_ys): - atoms, n_edge = get_atoms(a_type, init_y) - tot_edges += n_edge - tot_atoms += len(atoms) - atoms_list.append(atoms) - ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz') - dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp')) - loader = DataLoader(dataset, batch_size=n_graph) - graph = next(iter(loader)) - - essential = { - 'x': ((tot_atoms,), int), - 'atomic_numbers': ((tot_atoms,), int), - 'pos': ((tot_atoms, 3), float), - 'edge_index': ((2, tot_edges), int), - 'edge_vec': ((tot_edges, 3), float), - 'total_energy': ((n_graph,), float), - 'force_of_atoms': ((tot_atoms, 3), float), - 'cell_volume': ((n_graph,), float), - 'num_atoms': ((n_graph,), int), - 'per_atom_energy': ((n_graph,), float), - 'stress': ((n_graph, 6), float), - 'batch': ((tot_atoms,), int), # from PyG - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: 
{type(graph[k])} is not an tensor' - assert graph[k].is_floating_point() == (dtype is float) - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk']) -def test_graph_build_ase_and_matscipy(atoms_type): - atoms, _ = get_atoms(atoms_type, 'calc') - atoms.rattle() - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - # graph build check - # ase graph build - edge_src_ase, edge_dst_ase, edge_vec_ase, shifts_ase = dl._graph_build_ase( - cutoff, pbc, cell, pos - ) - # matscipy graph build - edge_src_matsci, edge_dst_matsci, edge_vec_matsci, shifts_matsci = ( - dl._graph_build_matscipy(cutoff, pbc, cell, pos) - ) - - # sort the graph - sorted_indices_ase = np.lexsort( - (edge_vec_ase[:, 2], edge_vec_ase[:, 1], edge_vec_ase[:, 0]) - ) - sorted_indices_matsci = np.lexsort( - (edge_vec_matsci[:, 2], edge_vec_matsci[:, 1], edge_vec_matsci[:, 0]) - ) - sorted_vec_ase = edge_vec_ase[sorted_indices_ase] - sorted_vec_matsci = edge_vec_matsci[sorted_indices_matsci] - sorted_src_ase = edge_src_ase[sorted_indices_ase] - sorted_dst_ase = edge_dst_ase[sorted_indices_ase] - sorted_src_matsci = edge_src_matsci[sorted_indices_matsci] - sorted_dst_matsci = edge_dst_matsci[sorted_indices_matsci] - sorted_shift_ase = shifts_ase[sorted_indices_ase] - sorted_shift_matsci = shifts_matsci[sorted_indices_matsci] - - # compare the result - assert np.allclose(sorted_vec_ase, sorted_vec_matsci) - assert np.array_equal(sorted_src_ase, sorted_src_matsci) - assert np.array_equal(sorted_dst_ase, sorted_dst_matsci) - assert np.array_equal(sorted_shift_ase, sorted_shift_matsci) - - # energy test - model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024')) - model.eval() - model.set_is_batch_data(False) - - # for ase energy - edge_idx_ase = np.array([edge_src_ase, edge_dst_ase]) - atomic_numbers = atoms.get_atomic_numbers() - cell = np.array(cell) - vol = dl._correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data_ase = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx_ase, - KEY.EDGE_VEC: edge_vec_ase, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts_ase, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), - } - data_ase[KEY.INFO] = {} - atom_graph_data_ase = AtomGraphData.from_numpy_dict(data_ase) - output_ase = model(atom_graph_data_ase) - ase_pred_energy = output_ase[KEY.PRED_TOTAL_ENERGY] - ase_pred_force = output_ase[KEY.PRED_FORCE] - ase_pred_stress = output_ase[KEY.PRED_STRESS] - - # for matsci energy - edge_idx_matsci = np.array([edge_src_matsci, edge_dst_matsci]) - atomic_numbers = atoms.get_atomic_numbers() - cell = np.array(cell) - vol = dl._correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data_matsci = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx_matsci, - KEY.EDGE_VEC: edge_vec_matsci, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts_matsci, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), - } - data_matsci[KEY.INFO] = {} - atom_graph_data_matsci = AtomGraphData.from_numpy_dict(data_matsci) - output_matsci = model(atom_graph_data_matsci) - matsci_pred_energy = output_matsci[KEY.PRED_TOTAL_ENERGY] - matsci_pred_force = output_matsci[KEY.PRED_FORCE] - matsci_pred_stress = 
output_matsci[KEY.PRED_STRESS] - assert torch.equal(ase_pred_energy, matsci_pred_energy) - assert torch.allclose(ase_pred_force, matsci_pred_force, atol=1e-06) - assert torch.allclose(ase_pred_stress, matsci_pred_stress) +import logging +import os +import os.path as osp +import uuid +from collections import Counter +from copy import deepcopy +from typing import Literal + +import ase.calculators.singlepoint as singlepoint +import ase.io +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.build import bulk, molecule +from torch_geometric.loader import DataLoader + +import sevenn._keys as KEY +import sevenn.train.dataload as dl +import sevenn.train.graph_dataset as ds +import sevenn.train.modal_dataset as modal_dataset +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData +from sevenn.util import model_from_checkpoint, pretrained_name_to_path + +cutoff = 4.0 +lattice_constant = 3.35 + +_samples = { + 'bulk': bulk('NaCl', 'rocksalt', a=5.63), + 'mol': molecule('H2O'), + 'isolated': molecule('H'), + 'small_bulk': Atoms( + symbols='Cu', + positions=[ + (0, 0, 0), # Atom at the corner of the cube + ], + cell=[ + [lattice_constant, 0, 0], + [0, lattice_constant, 0], + [0, 0, lattice_constant], + ], + pbc=True, # Periodic boundary conditions + ), +} + + +_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18} + + +def get_atoms( + atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'], + init_y_as: Literal['calc', 'info', 'none'], +): + """ + Return atoms w, w/o reference values with its + # of edges for 4.0 cutoff length + """ + assert atoms_type in _samples + atoms = deepcopy(_samples[atoms_type]) + natoms = len(atoms) + if init_y_as == 'calc': + results = { + 'energy': np.random.rand(1), + 'forces': np.random.rand(natoms, 3), + 'stress': np.random.rand(6), + } + if not atoms.pbc.all(): + del results['stress'] + calc = singlepoint.SinglePointCalculator(atoms, **results) + atoms = calc.get_atoms() + elif init_y_as == 'info': + atoms.info['y_energy'] = np.random.rand(1) + atoms.arrays['y_force'] = np.random.rand(natoms, 3) + atoms.info['y_stress'] = np.random.rand(6) + if not atoms.pbc.all(): + del atoms.info['y_stress'] + return atoms, _nedges_c4[atoms_type] + + +@pytest.mark.parametrize('init_y_as', ['calc', 'info']) +@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) +def test_atoms_to_graph(atoms_type, init_y_as): + atoms, nedges = get_atoms(atoms_type, init_y_as) + is_stress = atoms.pbc.all() + y_from_calc = init_y_as == 'calc' + + graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) + + essential = { + 'atomic_numbers': ((len(atoms),), int), + 'pos': ((len(atoms), 3), float), + 'edge_index': ((2, nedges), int), + 'edge_vec': ((nedges, 3), float), + 'total_energy': ((), float), + 'force_of_atoms': ((len(atoms), 3), float), + 'cell_volume': ((), float), + 'num_atoms': ((), int), + 'per_atom_energy': ((), float), + 'stress': ((1, 6), float), + } + + for k, (shape, dtype) in essential.items(): + assert k in graph, f'{k} missing in graph' + assert isinstance( + graph[k], np.ndarray + ), f'{k}: {type(graph[k])} is not np.ndarray' + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + if not is_stress and k == 'stress': + assert np.isnan(graph[k]).all() + else: + assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' + + assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms)) + assert graph['num_atoms'] == len(atoms) + 
if not is_stress:
+        assert graph['cell_volume'] == np.finfo(float).eps
+
+
+@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
+def test_unlabeled_atoms_to_graph(atoms_type):
+    atoms, nedges = get_atoms(atoms_type, 'none')
+
+    graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff)
+
+    essential = {
+        'atomic_numbers': ((len(atoms),), int),
+        'pos': ((len(atoms), 3), float),
+        'edge_index': ((2, nedges), int),
+        'edge_vec': ((nedges, 3), float),
+        'cell_volume': ((), float),
+        'num_atoms': ((), int),
+    }
+
+    for k, (shape, dtype) in essential.items():
+        assert k in graph, f'{k} missing in graph'
+        assert isinstance(
+            graph[k], np.ndarray
+        ), f'{k}: {type(graph[k])} is not np.ndarray'
+        assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}'
+        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
+
+    assert graph['num_atoms'] == len(atoms)
+    if not atoms.pbc.all():
+        assert graph['cell_volume'] == np.finfo(float).eps
+
+
+@pytest.mark.parametrize('init_y_as', ['calc', 'info'])
+@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated'])
+def test_atom_graph_data(atoms_type, init_y_as):
+    atoms, nedges = get_atoms(atoms_type, init_y_as)
+    y_from_calc = init_y_as == 'calc'
+    is_stress = atoms.pbc.all()
+    np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc)
+    graph = AtomGraphData.from_numpy_dict(np_graph)
+
+    essential = {
+        'atomic_numbers': ((len(atoms),), int),
+        'edge_index': ((2, nedges), int),
+        'edge_vec': ((nedges, 3), float),
+    }
+    auxiliary = {
+        'x': ((len(atoms),), int),
+        'pos': ((len(atoms), 3), float),
+        'num_atoms': ((), int),
+        'cell_volume': ((), float),
+        'total_energy': ((), float),
+        'per_atom_energy': ((), float),
+        'force_of_atoms': ((len(atoms), 3), float),
+        'stress': ((1, 6), float),
+    }
+
+    for k, (shape, dtype) in essential.items():
+        assert k in graph, f'{k} missing in graph'
+        assert isinstance(
+            graph[k], torch.Tensor
+        ), f'{k}: {type(graph[k])} is not a tensor'
+        assert graph[k].is_floating_point() == (dtype is float)
+        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
+
+    for k, (shape, dtype) in auxiliary.items():
+        if k not in graph:
+            continue
+        assert isinstance(
+            graph[k], torch.Tensor
+        ), f'{k}: {type(graph[k])} is not a tensor'
+        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
+        if not is_stress and k == 'stress':
+            assert torch.isnan(graph[k]).all()
+        else:
+            assert graph[k].is_floating_point() == (dtype is float)
+
+
+def test_graph_build():
+    """
+    Compare the parallel implementation against the serial one;
+    it should preserve the order of the input structures.
+    """
+    atoms_list = [
+        get_atoms(t, 'calc')[0]  # type: ignore
+        for t in list(_samples.keys())
+    ]
+    one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True)
+    two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True)
+
+    assert len(one_core) == len(two_core)
+    for g1, g2 in zip(one_core, two_core):
+        assert set(g1.keys()) == set(g2.keys())
+        for k in g1.keys():
+            if not isinstance(g1[k], torch.Tensor):
+                continue
+            if k == 'stress':  # TODO: robust way to test it
+                assert torch.allclose(g1[k], g2[k]) or (
+                    torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all()
+                )
+            else:
+                assert torch.allclose(g1[k], g2[k])
+
+
+@pytest.fixture(scope='module')
+def graph_dataset_tuple():
+    tmpdir = os.getenv('TMPDIR', '/tmp')
+    randstr = uuid.uuid4().hex
+    assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable'
+
+    root = tmpdir
+    files = f'{root}/{randstr}.extxyz'
+    atoms_list = [
+        get_atoms(atype, 'calc')[0]  # type: ignore
+        for atype in ['bulk', 'mol', 'isolated']
+    ]
+    ase.io.write(files, atoms_list, 'extxyz')
+
+    dataset = ds.SevenNetGraphDataset(
+        cutoff=cutoff,
+        root=root,
+        files=files,
+        processed_name=f'{randstr}.pt',
+    )
+    assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written'
+    return dataset, atoms_list
+
+
+def test_sevenn_graph_dataset_properties(graph_dataset_tuple):
+    dataset, atoms_list = graph_dataset_tuple
+
+    species = set()
+    natoms = Counter()
+    elist = []
+    e_per_list = []
+    flist = []
+    slist = []
+    for at in atoms_list:
+        chems = at.get_chemical_symbols()
+        species.update(chems)
+        natoms.update(chems)
+        elist.append(at.get_potential_energy())
+        e_per_list.append(at.get_potential_energy() / len(at))
+        flist.extend(at.get_forces())
+        try:
+            slist.append(at.get_stress())
+        except NotImplementedError:
+            slist.append(np.full(6, np.nan))
+
+    elist = np.array(elist)
+    e_per_list = np.array(e_per_list)
+    flist = np.array(flist)
+    slist = np.array(slist)
+
+    natoms['total'] = sum([cnt for cnt in list(natoms.values())])
+
+    assert set(dataset.species) == species
+    assert dataset.natoms == natoms
+    assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean())
+    assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean()))
+
+
+def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple):
+    logger = logging.getLogger(__name__)
+
+    dataset, atoms_list = graph_dataset_tuple
+
+    ref_e = dataset.elemwise_reference_energies
+    assert len(ref_e) == NUM_UNIV_ELEMENT
+    z_set = set()
+    for atoms in atoms_list:
+        inferred_e = 0
+        atomic_numbers = atoms.get_atomic_numbers()
+        z_set.update(atomic_numbers)
+        for z in atomic_numbers:
+            inferred_e += ref_e[z]
+        # it will never be exactly equal, but it should be similar
+        logger.info('elemwise energy should be similar:')
+        logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}')
+
+    for z in range(NUM_UNIV_ELEMENT):
+        if z not in z_set:
+            assert ref_e[z] == 0
+
+
+def test_sevenn_graph_dataset_statistics(graph_dataset_tuple):
+    dataset, atoms_list = graph_dataset_tuple
+
+    elist = []
+    e_per_list = []
+    flist = []
+    slist = []
+    for at in atoms_list:
+        elist.append(at.get_potential_energy())
+        e_per_list.append(at.get_potential_energy() / len(at))
+        flist.extend(at.get_forces())
+        try:
+            slist.append(at.get_stress())
+        except NotImplementedError:
+            slist.append(np.full(6, np.nan))
+
+    dct = {
+        'total_energy': np.array(elist),
+        'per_atom_energy': np.array(e_per_list),
+        'force_of_atoms': np.array(flist).flatten(),
+        # 'stress': np.array(slist),  # TODO: it may have nan
+    }
+
+    for key in dct:
+        assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key
+        assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key
+        assert np.allclose(
+            dataset.statistics[key]['median'], np.median(dct[key])
+        ), key
+        assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key
+        assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key
+
+
+def test_sevenn_mm_dataset_statistics(tmp_path):
+
+    files = osp.join(tmp_path, 'gd_one.extxyz')
+    atoms_list1 = [
+        get_atoms(atype, 'calc')[0]  # type: ignore
+        for atype in ['bulk', 'bulk', 'bulk', 'bulk']
+    ]
+    ase.io.write(files, atoms_list1, 'extxyz')
+
+    gd1 = ds.SevenNetGraphDataset(
+        cutoff=cutoff,
+        root=tmp_path,
+        files=files,
+        processed_name='gd_one.pt',
+    )
+
+    files = osp.join(tmp_path, 'gd_two.extxyz')
+    atoms_list2 = [
+        get_atoms(atype, 'calc')[0]  # type: ignore
+        for atype in ['mol', 'mol', 'bulk']
+    ]
+    ase.io.write(files, atoms_list2, 'extxyz')
+
+    gd2 = ds.SevenNetGraphDataset(
+        cutoff=cutoff,
+        root=tmp_path,
+        files=files,
+        processed_name='gd_two.pt',
+    )
+
+    ref = ds.SevenNetGraphDataset(
+        cutoff=cutoff,
+        root=tmp_path,
+        files=[gd1.processed_paths[0], gd2.processed_paths[0]],
+        processed_name='combined.pt',
+    )
+
+    mm = modal_dataset.SevenNetMultiModalDataset(
+        {'modal1': gd1, 'modal2': gd2}
+    )
+
+    assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total'])
+    assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total'])
+    assert np.allclose(ref.force_rms, mm.force_rms['total'])
+    assert set(ref.species) == set(mm.species['total'])
+
+
+@pytest.mark.parametrize(
+    'a_types,init_ys', [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])]
+)
+def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path):
+    assert len(a_types) == len(init_ys)
+    n_graph = len(a_types)
+    atoms_list = []
+    tot_edges = 0
+    tot_atoms = 0
+    for a_type, init_y in zip(a_types, init_ys):
+        atoms, n_edge = get_atoms(a_type, init_y)
+        tot_edges += n_edge
+        tot_atoms += len(atoms)
+        atoms_list.append(atoms)
+    ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz')
+    dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp'))
+    loader = DataLoader(dataset, batch_size=n_graph)
+    graph = next(iter(loader))
+
+    essential = {
+        'x': ((tot_atoms,), int),
+        'atomic_numbers': ((tot_atoms,), int),
+        'pos': ((tot_atoms, 3), float),
+        'edge_index': ((2, tot_edges), int),
+        'edge_vec': ((tot_edges, 3), float),
+        'total_energy': ((n_graph,), float),
+        'force_of_atoms': ((tot_atoms, 3), float),
+        'cell_volume': ((n_graph,), float),
+        'num_atoms': ((n_graph,), int),
+        'per_atom_energy': ((n_graph,), float),
+        'stress': ((n_graph, 6), float),
+        'batch': ((tot_atoms,), int),  # from PyG
+    }
+
+    for k, (shape, dtype) in essential.items():
+        assert k in graph, f'{k} missing in graph'
+        assert isinstance(
+            graph[k], torch.Tensor
+        ), f'{k}: {type(graph[k])} is not a tensor'
+        assert graph[k].is_floating_point() == (dtype is float)
+        assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}'
+
+
+@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk'])
+def test_graph_build_ase_and_matscipy(atoms_type):
+    atoms, _ = get_atoms(atoms_type, 'calc')
+    atoms.rattle()
+    pos = atoms.get_positions()
+    cell = np.array(atoms.get_cell())
+    pbc = atoms.get_pbc()
+
+    # graph build check
+    # ase graph build
+    edge_src_ase, edge_dst_ase, edge_vec_ase, shifts_ase = dl._graph_build_ase(
+        cutoff, pbc, cell, pos
+    )
+    # matscipy graph build
+    edge_src_matsci, edge_dst_matsci, edge_vec_matsci, shifts_matsci = (
+        dl._graph_build_matscipy(cutoff, pbc, cell, pos)
+    )
+
+    # sort the graph
+    sorted_indices_ase = np.lexsort(
+        (edge_vec_ase[:, 2], edge_vec_ase[:, 1], edge_vec_ase[:, 0])
+    )
+    sorted_indices_matsci = np.lexsort(
+        (edge_vec_matsci[:, 2], edge_vec_matsci[:, 1], edge_vec_matsci[:, 0])
+    )
+    sorted_vec_ase = edge_vec_ase[sorted_indices_ase]
+    sorted_vec_matsci = edge_vec_matsci[sorted_indices_matsci]
+    sorted_src_ase = edge_src_ase[sorted_indices_ase]
+    sorted_dst_ase = edge_dst_ase[sorted_indices_ase]
+    sorted_src_matsci = edge_src_matsci[sorted_indices_matsci]
+    sorted_dst_matsci = edge_dst_matsci[sorted_indices_matsci]
+    sorted_shift_ase = shifts_ase[sorted_indices_ase]
+    sorted_shift_matsci = shifts_matsci[sorted_indices_matsci]
+
+    # compare the result
+    assert np.allclose(sorted_vec_ase, sorted_vec_matsci)
+    assert
np.array_equal(sorted_src_ase, sorted_src_matsci) + assert np.array_equal(sorted_dst_ase, sorted_dst_matsci) + assert np.array_equal(sorted_shift_ase, sorted_shift_matsci) + + # energy test + model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024')) + model.eval() + model.set_is_batch_data(False) + + # for ase energy + edge_idx_ase = np.array([edge_src_ase, edge_dst_ase]) + atomic_numbers = atoms.get_atomic_numbers() + cell = np.array(cell) + vol = dl._correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data_ase = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx_ase, + KEY.EDGE_VEC: edge_vec_ase, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts_ase, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), + } + data_ase[KEY.INFO] = {} + atom_graph_data_ase = AtomGraphData.from_numpy_dict(data_ase) + output_ase = model(atom_graph_data_ase) + ase_pred_energy = output_ase[KEY.PRED_TOTAL_ENERGY] + ase_pred_force = output_ase[KEY.PRED_FORCE] + ase_pred_stress = output_ase[KEY.PRED_STRESS] + + # for matsci energy + edge_idx_matsci = np.array([edge_src_matsci, edge_dst_matsci]) + atomic_numbers = atoms.get_atomic_numbers() + cell = np.array(cell) + vol = dl._correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data_matsci = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx_matsci, + KEY.EDGE_VEC: edge_vec_matsci, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts_matsci, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), + } + data_matsci[KEY.INFO] = {} + atom_graph_data_matsci = AtomGraphData.from_numpy_dict(data_matsci) + output_matsci = model(atom_graph_data_matsci) + matsci_pred_energy = output_matsci[KEY.PRED_TOTAL_ENERGY] + matsci_pred_force = output_matsci[KEY.PRED_FORCE] + matsci_pred_stress = output_matsci[KEY.PRED_STRESS] + assert torch.equal(ase_pred_energy, matsci_pred_energy) + assert torch.allclose(ase_pred_force, matsci_pred_force, atol=1e-06) + assert torch.allclose(ase_pred_stress, matsci_pred_stress) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py index 455fd762da41d5614559821a9213c9d17b6a70fe..36024a845358625ade699db0041723fadf41c356 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py @@ -1,285 +1,285 @@ -# test_errors: error recorder.py, loss.py -from copy import deepcopy - -import numpy as np -import pytest -import torch -import torch.nn -from torch import tensor - -import sevenn.error_recorder as erc -import sevenn.train.loss as loss -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.optim import loss_dict - -_default_config = { - 'loss': 'mse', - 'loss_param': {}, - 'error_record': [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('Energy', 'MAE'), - ('Force', 'MAE'), - ('Stress', 'MAE'), - ('TotalLoss', 'None'), - ], - 'is_train_stress': True, - 'force_loss_weight': 1.0, - 'stress_loss_weight': 0.001, -} - -_erc_test_params = [ - ('TotalEnergy', 4, 3), - ('Energy', 4, 3), - ('Force', 4, 3), - ('Stress', 4, 3), - ('Stress_GPa', 4, 3), - ('Energy', 4, 1), - ('Energy', 1, 1), - ('Force', 1, 3), - ('Stress', 1, 3), -] - - -def acl(a, b): - return torch.allclose(a, b, atol=1e-6) - - -def 
config(**overwrite): # to make it read-only - cf = deepcopy(_default_config) - for k, v in overwrite.items(): - cf[k] = v - return cf - - -def test_per_atom_energy_loss(): - loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand(2) - pred = torch.rand(2) - natoms = torch.randint(1, 10, (2,)) - tmp = AtomGraphData( - total_energy=ref, - inferred_total_energy=pred, - num_atoms=natoms, - ).to_dict() - ret = loss_f.get_loss(tmp) - assert loss_f.criterion is not None - assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret) - - -def test_force_loss(): - loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand((4, 3)) - pred = torch.rand((4, 3)) - batch = tensor([0, 0, 0, 1]) - tmp = AtomGraphData( - force_of_atoms=ref, - inferred_force=pred, - batch=batch, - ).to_dict() - ret = loss_f.get_loss(tmp) - assert loss_f.criterion is not None - assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret) - - -def test_stress_loss(): - loss_f = loss.StressLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand((2, 6)) - pred = torch.rand((2, 6)) - tmp = AtomGraphData( - stress=ref, - inferred_stress=pred, - ).to_dict() - ret = loss_f.get_loss(tmp) - KB = 1602.1766208 - assert loss_f.criterion is not None - assert torch.allclose( - loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret - ) - - -@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) -def test_loss_from_config(conf): - loss_functions = loss.get_loss_functions_from_config(conf) - - if conf['is_train_stress']: - assert len(loss_functions) == 3 - else: - assert len(loss_functions) == 2 - - for loss_def, w in loss_functions: - assert isinstance(loss_def, loss.LossDefinition) - if isinstance(loss_def, loss.PerAtomEnergyLoss): - assert w == 1.0 - elif isinstance(loss_def, loss.ForceLoss): - assert w == conf['force_loss_weight'] - elif isinstance(loss_def, loss.StressLoss): - assert w == conf['stress_loss_weight'] - else: - raise ValueError(f'Unexpected loss function: {loss_def}') - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_rms_error(err_type, ndata, natoms): - err_dct = erc.get_err_type(err_type) - err = erc.RMSError(**err_dct) - ref = torch.rand((ndata, err.vdim)).squeeze(1) - pred = torch.rand((ndata, err.vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - # natoms = natoms.unsqueeze(-1) - _ref = _ref / natoms - _pred = _pred / natoms - val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata) # not ndata*natoms - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_mae_error(err_type, ndata, natoms): - err_dct = erc.get_err_type(err_type) - vdim = err_dct['vdim'] - err = erc.MAError(**err_dct) - ref = torch.rand((ndata, vdim)).squeeze(1) - pred = torch.rand((ndata, vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - _ref /= natoms - _pred /= natoms - 
- val = abs(_ref - _pred).sum() / (ndata * vdim) - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -# TODO: test_component_rms_error - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_custom_error(err_type, ndata, natoms): - def func(a, b): - return a * b - - err_dct = erc.get_err_type(err_type) - vdim = err_dct['vdim'] - err = erc.CustomError(func, **err_dct) - ref = torch.rand((ndata, vdim)).squeeze(1) - pred = torch.rand((ndata, vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - _ref /= natoms - _pred /= natoms - - tmp = AtomGraphData(**_data) - err.update(tmp) - val = func(_ref, _pred).mean() - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) -def test_total_loss_metric_from_config(conf): - def func(a, b): - return a * b - - err = erc.ErrorRecorder.init_total_loss_metric(conf, func) - ndata = 3 - natoms = 4 - - e1, e2 = torch.rand(ndata), torch.rand(ndata) - f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3) - s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6)) - _data = { - 'total_energy': e1, - 'inferred_total_energy': e2, - 'force_of_atoms': f1, - 'inferred_force': f2, - 'stress': s1, - 'inferred_stress': s2, - 'num_atoms': torch.tensor([natoms] * ndata), - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - val = (func(e1 / natoms, e2 / natoms)).mean() + conf['force_loss_weight'] * func( - f1, f2 - ).mean() - if conf['is_train_stress']: - KB = 1602.1766208 - val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean() - - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize( - 'conf', [config(), config(is_train_stress=False), config(loss='huber')] -) -def test_error_recorder_from_config(conf): - recorder = erc.ErrorRecorder.from_config(conf) - - total_loss_flag = False - for metric in recorder.metrics: - if conf['is_train_stress'] is False: - assert 'stress' not in metric.name - if metric.name == 'TotalLoss': - total_loss_flag = True - for loss_metric, _ in metric.metrics: # type: ignore - assert isinstance(loss_metric.func, loss_dict[conf['loss']]) - assert total_loss_flag - - -@pytest.mark.parametrize( - 'conf', [config(), config(is_train_stress=False), config(loss='huber')] -) -def test_error_recorder_from_config_and_loss_functions(conf): - loss_functions = loss.get_loss_functions_from_config(conf) - recorder = erc.ErrorRecorder.from_config(conf, loss_functions) - - total_loss_flag = False - for metric in recorder.metrics: - if conf['is_train_stress'] is False: - assert 'stress' not in metric.name - if metric.name == 'TotalLoss': - total_loss_flag = True - for loss_metric, _ in metric.metrics: # type: ignore - assert isinstance( - loss_metric.loss_def.criterion, loss_dict[conf['loss']] - ) - assert total_loss_flag +# test_errors: error recorder.py, loss.py +from copy import deepcopy + +import numpy as np +import pytest +import torch +import torch.nn +from torch import tensor + +import sevenn.error_recorder as erc +import sevenn.train.loss as loss +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.optim import 
loss_dict + +_default_config = { + 'loss': 'mse', + 'loss_param': {}, + 'error_record': [ + ('Energy', 'RMSE'), + ('Force', 'RMSE'), + ('Stress', 'RMSE'), + ('Energy', 'MAE'), + ('Force', 'MAE'), + ('Stress', 'MAE'), + ('TotalLoss', 'None'), + ], + 'is_train_stress': True, + 'force_loss_weight': 1.0, + 'stress_loss_weight': 0.001, +} + +_erc_test_params = [ + ('TotalEnergy', 4, 3), + ('Energy', 4, 3), + ('Force', 4, 3), + ('Stress', 4, 3), + ('Stress_GPa', 4, 3), + ('Energy', 4, 1), + ('Energy', 1, 1), + ('Force', 1, 3), + ('Stress', 1, 3), +] + + +def acl(a, b): + return torch.allclose(a, b, atol=1e-6) + + +def config(**overwrite): # to make it read-only + cf = deepcopy(_default_config) + for k, v in overwrite.items(): + cf[k] = v + return cf + + +def test_per_atom_energy_loss(): + loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand(2) + pred = torch.rand(2) + natoms = torch.randint(1, 10, (2,)) + tmp = AtomGraphData( + total_energy=ref, + inferred_total_energy=pred, + num_atoms=natoms, + ).to_dict() + ret = loss_f.get_loss(tmp) + assert loss_f.criterion is not None + assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret) + + +def test_force_loss(): + loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand((4, 3)) + pred = torch.rand((4, 3)) + batch = tensor([0, 0, 0, 1]) + tmp = AtomGraphData( + force_of_atoms=ref, + inferred_force=pred, + batch=batch, + ).to_dict() + ret = loss_f.get_loss(tmp) + assert loss_f.criterion is not None + assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret) + + +def test_stress_loss(): + loss_f = loss.StressLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand((2, 6)) + pred = torch.rand((2, 6)) + tmp = AtomGraphData( + stress=ref, + inferred_stress=pred, + ).to_dict() + ret = loss_f.get_loss(tmp) + KB = 1602.1766208 + assert loss_f.criterion is not None + assert torch.allclose( + loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret + ) + + +@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) +def test_loss_from_config(conf): + loss_functions = loss.get_loss_functions_from_config(conf) + + if conf['is_train_stress']: + assert len(loss_functions) == 3 + else: + assert len(loss_functions) == 2 + + for loss_def, w in loss_functions: + assert isinstance(loss_def, loss.LossDefinition) + if isinstance(loss_def, loss.PerAtomEnergyLoss): + assert w == 1.0 + elif isinstance(loss_def, loss.ForceLoss): + assert w == conf['force_loss_weight'] + elif isinstance(loss_def, loss.StressLoss): + assert w == conf['stress_loss_weight'] + else: + raise ValueError(f'Unexpected loss function: {loss_def}') + + +@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_rms_error(err_type, ndata, natoms): + err_dct = erc.get_err_type(err_type) + err = erc.RMSError(**err_dct) + ref = torch.rand((ndata, err.vdim)).squeeze(1) + pred = torch.rand((ndata, err.vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + # natoms = natoms.unsqueeze(-1) + _ref = _ref / natoms + _pred = _pred / natoms + val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata) # not ndata*natoms + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + 
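+# test_mae_error below mirrors test_rms_error above: the recorder is updated
+# twice with identical data, so the reported running average must not change.
+# The expected value is the plain mean over every scalar component,
+# abs(_ref - _pred).sum() / (ndata * vdim), with the per-atom variants
+# dividing by num_atoms first.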
+@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_mae_error(err_type, ndata, natoms): + err_dct = erc.get_err_type(err_type) + vdim = err_dct['vdim'] + err = erc.MAError(**err_dct) + ref = torch.rand((ndata, vdim)).squeeze(1) + pred = torch.rand((ndata, vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + _ref /= natoms + _pred /= natoms + + val = abs(_ref - _pred).sum() / (ndata * vdim) + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +# TODO: test_component_rms_error + + +@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_custom_error(err_type, ndata, natoms): + def func(a, b): + return a * b + + err_dct = erc.get_err_type(err_type) + vdim = err_dct['vdim'] + err = erc.CustomError(func, **err_dct) + ref = torch.rand((ndata, vdim)).squeeze(1) + pred = torch.rand((ndata, vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + _ref /= natoms + _pred /= natoms + + tmp = AtomGraphData(**_data) + err.update(tmp) + val = func(_ref, _pred).mean() + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) +def test_total_loss_metric_from_config(conf): + def func(a, b): + return a * b + + err = erc.ErrorRecorder.init_total_loss_metric(conf, func) + ndata = 3 + natoms = 4 + + e1, e2 = torch.rand(ndata), torch.rand(ndata) + f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3) + s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6)) + _data = { + 'total_energy': e1, + 'inferred_total_energy': e2, + 'force_of_atoms': f1, + 'inferred_force': f2, + 'stress': s1, + 'inferred_stress': s2, + 'num_atoms': torch.tensor([natoms] * ndata), + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + val = (func(e1 / natoms, e2 / natoms)).mean() + conf['force_loss_weight'] * func( + f1, f2 + ).mean() + if conf['is_train_stress']: + KB = 1602.1766208 + val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean() + + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +@pytest.mark.parametrize( + 'conf', [config(), config(is_train_stress=False), config(loss='huber')] +) +def test_error_recorder_from_config(conf): + recorder = erc.ErrorRecorder.from_config(conf) + + total_loss_flag = False + for metric in recorder.metrics: + if conf['is_train_stress'] is False: + assert 'stress' not in metric.name + if metric.name == 'TotalLoss': + total_loss_flag = True + for loss_metric, _ in metric.metrics: # type: ignore + assert isinstance(loss_metric.func, loss_dict[conf['loss']]) + assert total_loss_flag + + +@pytest.mark.parametrize( + 'conf', [config(), config(is_train_stress=False), config(loss='huber')] +) +def test_error_recorder_from_config_and_loss_functions(conf): + loss_functions = loss.get_loss_functions_from_config(conf) + recorder = erc.ErrorRecorder.from_config(conf, loss_functions) + + total_loss_flag = False + for metric in 
recorder.metrics: + if conf['is_train_stress'] is False: + assert 'stress' not in metric.name + if metric.name == 'TotalLoss': + total_loss_flag = True + for loss_metric, _ in metric.metrics: # type: ignore + assert isinstance( + loss_metric.loss_def.criterion, loss_dict[conf['loss']] + ) + assert total_loss_flag diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py index 3b384bd2582e48993babc2e419e13d630bb719f2..000476f0f75ff24cfe4375d3cc2695a8182fd8ba 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py @@ -1,136 +1,136 @@ -# # deploy is test on lammps -# test append modality -# from no modality model to modality yes model -# from modality model to more modality model -# different shift scale settings -# test modality options (check num param) -# calculators with modality - -import copy -# + modal checkpoint continue and test_train -# + sevenn_cp test things in test_cli -import pathlib - -import pytest -from ase.build import bulk - -import sevenn.train.graph_dataset as graph_ds -import sevenn.util as util -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model - -cutoff = 5.0 -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024') - - -@pytest.fixture(scope='module') -def graph_dataset_path(tmp_path_factory): - gd_path = tmp_path_factory.mktemp('gd') - ds = graph_ds.SevenNetGraphDataset( - cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' - ) - return ds.processed_paths[0] - - -_modal_cfg = { - 'use_modal_node_embedding': False, - 'use_modal_self_inter_intro': True, - 'use_modal_self_inter_outro': True, - 'use_modal_output_block': True, - 'use_modality': True, - 'use_modal_wise_shift': True, # T/F should be tested - 'use_modal_wise_scale': False, # T/F should be tested - 'load_trainset_path': [ - { - 'data_modality': 'modal_new', - 'file_list': [{'file': hfo2_path}], - } - ], -} - - -@pytest.fixture(scope='module') -def snet_0_cp(): - return util.load_checkpoint(sevennet_0_path) - - -@pytest.fixture(scope='module') -def snet_0_calc(): - return SevenNetCalculator() - - -@pytest.fixture() -def bulk_atoms(): - atoms = bulk('Si') * 3 - atoms.rattle() - return atoms - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - import numpy as np - - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -def get_modal_cfg(overwrite=None): - modal_cfg = copy.deepcopy(_modal_cfg).copy() - if overwrite: - modal_cfg.update(overwrite) - return modal_cfg - - -@pytest.mark.parametrize( - 'cfg_overwrite', - [ - ({}), - ({'use_modal_wise_scale': True}), - ({'use_modal_wise_shift': False}), - ({'use_modal_self_inter_intro': False}), - ], -) -def test_append_modal_sevennet_0( - cfg_overwrite, - snet_0_cp, - snet_0_calc, - bulk_atoms, - graph_dataset_path, - 
tmp_path, -): - modal_cfg = snet_0_cp.config - modal_cfg.pop('load_dataset_path') - modal_cfg.pop('load_validset_path') - modal_cfg.update(get_modal_cfg(cfg_overwrite)) - modal_cfg['shift'] = 'elemwise_reference_energies' - modal_cfg['scale'] = 'per_atom_energy_std' - modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}] - - new_state_dict = snet_0_cp.append_modal( - modal_cfg, original_modal_name='pbe', working_dir=tmp_path - ) - sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg) - sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True) - - atoms1 = bulk_atoms - atoms2 = copy.deepcopy(atoms1) - - atoms1.calc = snet_0_calc - atoms2.calc = SevenNetCalculator( - model=sevennet_0_w_modal, file_type='model_instance', modal='pbe' - ) - - assert_atoms(atoms1, atoms2) +# # deploy is test on lammps +# test append modality +# from no modality model to modality yes model +# from modality model to more modality model +# different shift scale settings +# test modality options (check num param) +# calculators with modality + +import copy +# + modal checkpoint continue and test_train +# + sevenn_cp test things in test_cli +import pathlib + +import pytest +from ase.build import bulk + +import sevenn.train.graph_dataset as graph_ds +import sevenn.util as util +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model + +cutoff = 5.0 +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024') + + +@pytest.fixture(scope='module') +def graph_dataset_path(tmp_path_factory): + gd_path = tmp_path_factory.mktemp('gd') + ds = graph_ds.SevenNetGraphDataset( + cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' + ) + return ds.processed_paths[0] + + +_modal_cfg = { + 'use_modal_node_embedding': False, + 'use_modal_self_inter_intro': True, + 'use_modal_self_inter_outro': True, + 'use_modal_output_block': True, + 'use_modality': True, + 'use_modal_wise_shift': True, # T/F should be tested + 'use_modal_wise_scale': False, # T/F should be tested + 'load_trainset_path': [ + { + 'data_modality': 'modal_new', + 'file_list': [{'file': hfo2_path}], + } + ], +} + + +@pytest.fixture(scope='module') +def snet_0_cp(): + return util.load_checkpoint(sevennet_0_path) + + +@pytest.fixture(scope='module') +def snet_0_calc(): + return SevenNetCalculator() + + +@pytest.fixture() +def bulk_atoms(): + atoms = bulk('Si') * 3 + atoms.rattle() + return atoms + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + import numpy as np + + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def get_modal_cfg(overwrite=None): + modal_cfg = copy.deepcopy(_modal_cfg).copy() + if overwrite: + modal_cfg.update(overwrite) + return modal_cfg + + +@pytest.mark.parametrize( + 'cfg_overwrite', + [ + ({}), + ({'use_modal_wise_scale': True}), + ({'use_modal_wise_shift': False}), + ({'use_modal_self_inter_intro': False}), + ], +) +def 
test_append_modal_sevennet_0( + cfg_overwrite, + snet_0_cp, + snet_0_calc, + bulk_atoms, + graph_dataset_path, + tmp_path, +): + modal_cfg = snet_0_cp.config + modal_cfg.pop('load_dataset_path') + modal_cfg.pop('load_validset_path') + modal_cfg.update(get_modal_cfg(cfg_overwrite)) + modal_cfg['shift'] = 'elemwise_reference_energies' + modal_cfg['scale'] = 'per_atom_energy_std' + modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}] + + new_state_dict = snet_0_cp.append_modal( + modal_cfg, original_modal_name='pbe', working_dir=tmp_path + ) + sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg) + sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True) + + atoms1 = bulk_atoms + atoms2 = copy.deepcopy(atoms1) + + atoms1.calc = snet_0_calc + atoms2.calc = SevenNetCalculator( + model=sevennet_0_w_modal, file_type='model_instance', modal='pbe' + ) + + assert_atoms(atoms1, atoms2) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py index 843a3550311bdc74b57fd36cb0254871fd49c615..d75976f8cdc3f5761796d703edbbf7dd849a81b6 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py @@ -1,213 +1,213 @@ -import pytest -import torch -from ase.build import bulk, molecule -from ase.data import chemical_symbols -from torch_geometric.loader.dataloader import Collater - -import sevenn.train.dataload as dl -from sevenn.atom_graph_data import AtomGraphData -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.util import chemical_species_preprocess - -cutoff = 4.0 - - -_samples = { - 'bulk': bulk('NaCl', 'rocksalt', a=5.63), - 'mol': molecule('H2O'), - 'isolated': molecule('H'), -} -n_samples = len(_samples) -n_atoms_total = sum([len(at) for at in _samples.values()]) - -_graph_list = [ - AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff)) - for at in list(_samples.values()) -] - - -def test_chemical_species_preprocess(): - chems = ['He', 'H', 'Be', 'H'] - cf = chemical_species_preprocess(chems, universal=False) - assert cf['chemical_species'] == ['Be', 'H', 'He'] - assert cf['_number_of_species'] == 3 - assert cf['_type_map'] == {4: 0, 1: 1, 2: 2} - - cf = chemical_species_preprocess(chems, universal=True) - assert cf['chemical_species'] == chemical_symbols - assert cf['_number_of_species'] == len(chemical_symbols) - assert len(cf['_type_map']) == len(chemical_symbols) - for z, node_idx in cf['_type_map'].items(): - assert z == node_idx - - -def get_graphs(batched): - cloned = [g.clone() for g in _graph_list] - if not batched: - return cloned - else: - return Collater(cloned)(cloned) - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 4, - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'interaction_type': 'nequip', - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 30.0, - 'train_denominator': False, - 'self_connection_type': 'nequip', - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } 
- chems = set() - for at in list(_samples.values()): - chems.update(at.get_chemical_symbols()) - config.update(**chemical_species_preprocess(list(chems))) - return config - - -def get_model(config_overwrite={}): - cf = get_model_config() - cf.update(**config_overwrite) - model = build_E3_equivariant_model(cf, parallel=False) - assert isinstance(model, AtomGraphSequential) - return model - - -@pytest.mark.parametrize('batched', [False, True]) -@pytest.mark.parametrize('cf', [{}]) -def test_shape(cf, batched): - model = get_model(cf) - model.set_is_batch_data(batched) - - graph = get_graphs(batched) - if not batched: - output_shapes = { - 'inferred_total_energy': (), - 'inferred_stress': (6,), - } - for g in graph: - natoms = g['num_atoms'] - output_shapes.update( - { - 'atomic_energy': (natoms, 1), # intended - 'inferred_force': (natoms, 3), - } - ) - output = model(g) - for k, shape in output_shapes.items(): - assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' - else: - output_shapes = { - 'inferred_total_energy': (n_samples,), - 'atomic_energy': (n_atoms_total, 1), # intended - 'inferred_force': (n_atoms_total, 3), - 'inferred_stress': (n_samples, 6), - } - output = model(graph) - for k, shape in output_shapes.items(): - assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' - - -def test_batch(): - model = get_model() - model.set_is_batch_data(False) - - graph_list = get_graphs(batched=False) - output_list = [model(g) for g in graph_list] - - model.set_is_batch_data(True) - graph_batch = get_graphs(batched=True) - output_batched = model(graph_batch) - - e_concat = torch.concat( - [g['inferred_total_energy'].unsqueeze(-1) for g in output_list] - ) - ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list]) - f_concat = torch.concat([g['inferred_force'] for g in output_list]) - s_concat = torch.stack([g['inferred_stress'] for g in output_list]) - - assert torch.allclose(e_concat, output_batched['inferred_total_energy']) - assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1)) - assert torch.allclose( - torch.round(f_concat, decimals=5), - torch.round(output_batched['inferred_force'], decimals=5), - atol=1e-5, - ) - - assert torch.allclose( # TODO, hard-coded, assumes the first structure is bulk - torch.round(s_concat[0], decimals=5), - torch.round(output_batched['inferred_stress'][0], decimals=5), - ) - - -_n_param_tests = [ - ({}, 20642), - ({'train_denominator': True}, 20642 + 3), - ({'train_shift_scale': True}, 20642 + 2), - ({'shift': [1.0] * 4}, 20642), - ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8), - ({'num_convolution_layer': 4}, 33458), - ({'lmax': 3}, 26866), - ({'channel': 2}, 16883), - ({'is_parity': False}, 20386), - ({'self_connection_type': 'linear'}, 20114), -] - - -@pytest.mark.parametrize('cf,ref', _n_param_tests) -def test_num_params(cf, ref): - model = get_model(cf) - param = sum([p.numel() for p in model.parameters() if p.requires_grad]) - assert param == ref, f'ref: {ref} != given: {param}' - - -_n_modal_param_tests = [ - ({}, 20642), - ({'use_modal_node_embedding': True}, 20642 + 8), - ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3), - ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)), - ({'use_modal_output_block': True}, 20642 + 2 * 4 / 2), -] - - -@pytest.mark.parametrize('cf,ref', _n_modal_param_tests) -def test_modal_num_params(cf, ref): - modal_cfg = { - 'use_modality': True, - '_number_of_modalities': 2, - '_modal_map': {'x1': 0, 'x2': 1}, - 
'use_modal_node_embedding': False, - 'use_modal_self_inter_intro': False, - 'use_modal_self_inter_outro': False, - 'use_modal_output_block': False, - 'use_modal_wise_shift': False, - 'use_modal_wise_scale': False, - } - modal_cfg.update(cf) - model = get_model(modal_cfg) - param = sum([p.numel() for p in model.parameters() if p.requires_grad]) - assert param == ref, f'ref: {ref} != given: {param}' - - -# TODO: test_irreps, test_gard, test_equivariance +import pytest +import torch +from ase.build import bulk, molecule +from ase.data import chemical_symbols +from torch_geometric.loader.dataloader import Collater + +import sevenn.train.dataload as dl +from sevenn.atom_graph_data import AtomGraphData +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.util import chemical_species_preprocess + +cutoff = 4.0 + + +_samples = { + 'bulk': bulk('NaCl', 'rocksalt', a=5.63), + 'mol': molecule('H2O'), + 'isolated': molecule('H'), +} +n_samples = len(_samples) +n_atoms_total = sum([len(at) for at in _samples.values()]) + +_graph_list = [ + AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff)) + for at in list(_samples.values()) +] + + +def test_chemical_species_preprocess(): + chems = ['He', 'H', 'Be', 'H'] + cf = chemical_species_preprocess(chems, universal=False) + assert cf['chemical_species'] == ['Be', 'H', 'He'] + assert cf['_number_of_species'] == 3 + assert cf['_type_map'] == {4: 0, 1: 1, 2: 2} + + cf = chemical_species_preprocess(chems, universal=True) + assert cf['chemical_species'] == chemical_symbols + assert cf['_number_of_species'] == len(chemical_symbols) + assert len(cf['_type_map']) == len(chemical_symbols) + for z, node_idx in cf['_type_map'].items(): + assert z == node_idx + + +def get_graphs(batched): + cloned = [g.clone() for g in _graph_list] + if not batched: + return cloned + else: + return Collater(cloned)(cloned) + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 4, + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'interaction_type': 'nequip', + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 30.0, + 'train_denominator': False, + 'self_connection_type': 'nequip', + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + chems = set() + for at in list(_samples.values()): + chems.update(at.get_chemical_symbols()) + config.update(**chemical_species_preprocess(list(chems))) + return config + + +def get_model(config_overwrite={}): + cf = get_model_config() + cf.update(**config_overwrite) + model = build_E3_equivariant_model(cf, parallel=False) + assert isinstance(model, AtomGraphSequential) + return model + + +@pytest.mark.parametrize('batched', [False, True]) +@pytest.mark.parametrize('cf', [{}]) +def test_shape(cf, batched): + model = get_model(cf) + model.set_is_batch_data(batched) + + graph = get_graphs(batched) + if not batched: + output_shapes = { + 'inferred_total_energy': (), + 'inferred_stress': (6,), + } + for g in graph: + natoms = g['num_atoms'] + output_shapes.update( + { + 'atomic_energy': (natoms, 1), # intended + 'inferred_force': (natoms, 3), + 
} + ) + output = model(g) + for k, shape in output_shapes.items(): + assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' + else: + output_shapes = { + 'inferred_total_energy': (n_samples,), + 'atomic_energy': (n_atoms_total, 1), # intended + 'inferred_force': (n_atoms_total, 3), + 'inferred_stress': (n_samples, 6), + } + output = model(graph) + for k, shape in output_shapes.items(): + assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' + + +def test_batch(): + model = get_model() + model.set_is_batch_data(False) + + graph_list = get_graphs(batched=False) + output_list = [model(g) for g in graph_list] + + model.set_is_batch_data(True) + graph_batch = get_graphs(batched=True) + output_batched = model(graph_batch) + + e_concat = torch.concat( + [g['inferred_total_energy'].unsqueeze(-1) for g in output_list] + ) + ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list]) + f_concat = torch.concat([g['inferred_force'] for g in output_list]) + s_concat = torch.stack([g['inferred_stress'] for g in output_list]) + + assert torch.allclose(e_concat, output_batched['inferred_total_energy']) + assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1)) + assert torch.allclose( + torch.round(f_concat, decimals=5), + torch.round(output_batched['inferred_force'], decimals=5), + atol=1e-5, + ) + + assert torch.allclose( # TODO, hard-coded, assumes the first structure is bulk + torch.round(s_concat[0], decimals=5), + torch.round(output_batched['inferred_stress'][0], decimals=5), + ) + + +_n_param_tests = [ + ({}, 20642), + ({'train_denominator': True}, 20642 + 3), + ({'train_shift_scale': True}, 20642 + 2), + ({'shift': [1.0] * 4}, 20642), + ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8), + ({'num_convolution_layer': 4}, 33458), + ({'lmax': 3}, 26866), + ({'channel': 2}, 16883), + ({'is_parity': False}, 20386), + ({'self_connection_type': 'linear'}, 20114), +] + + +@pytest.mark.parametrize('cf,ref', _n_param_tests) +def test_num_params(cf, ref): + model = get_model(cf) + param = sum([p.numel() for p in model.parameters() if p.requires_grad]) + assert param == ref, f'ref: {ref} != given: {param}' + + +_n_modal_param_tests = [ + ({}, 20642), + ({'use_modal_node_embedding': True}, 20642 + 8), + ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3), + ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)), + ({'use_modal_output_block': True}, 20642 + 2 * 4 / 2), +] + + +@pytest.mark.parametrize('cf,ref', _n_modal_param_tests) +def test_modal_num_params(cf, ref): + modal_cfg = { + 'use_modality': True, + '_number_of_modalities': 2, + '_modal_map': {'x1': 0, 'x2': 1}, + 'use_modal_node_embedding': False, + 'use_modal_self_inter_intro': False, + 'use_modal_self_inter_outro': False, + 'use_modal_output_block': False, + 'use_modal_wise_shift': False, + 'use_modal_wise_scale': False, + } + modal_cfg.update(cf) + model = get_model(modal_cfg) + param = sum([p.numel() for p in model.parameters() if p.requires_grad]) + assert param == ref, f'ref: {ref} != given: {param}' + + +# TODO: test_irreps, test_gard, test_equivariance diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py index d82d7ab0ff4f243b7ce400061446485f8a7113f9..4676ca47e11a209c7ada6b9992ebfa2a252c1912 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py @@ -1,344 +1,344 
@@ -# test_pretrained: output consistency for pretrained models - -import pytest -import torch -from ase.build import bulk, molecule - -import sevenn._keys as KEY -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.dataload import unlabeled_atoms_to_graph -from sevenn.util import model_from_checkpoint, pretrained_name_to_path - - -def acl(a, b, atol=1e-6): - return torch.allclose(a, b, atol=atol) - - -@pytest.fixture -def atoms_pbc(): - atoms1 = bulk('NaCl', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture -def atoms_mol(): - atoms2 = molecule('H2O') - atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) - return atoms2 - - -def test_7net0_22May2024(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-0_22May2024') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - g1_ref_e = torch.tensor([-3.4140868186950684]) - g1_ref_f = torch.tensor( - [ - [1.2628037e01, 7.5093508e-03, 1.3480943e-02], - [-1.2628037e01, -7.5093508e-03, -1.3480917e-02], - ] - ) - g1_ref_s = -1 * torch.tensor( - [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973] - ) - - g2_ref_e = torch.tensor([-12.808363914489746]) - g2_ref_f = torch.tensor( - [ - [9.31322575e-10, -1.30241165e01, 6.93116236e00], - [-1.39698386e-09, 9.28001022e00, -9.51867390e00], - [5.23868948e-10, 3.74410582e00, 2.58751225e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net0_11July2024(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-0_11July2024') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.779199]) - g1_ref_f = torch.tensor( - [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507] - ) - - g2_ref_e = torch.tensor([-12.782808303833008]) - g2_ref_f = torch.tensor( - [ - [0.0, -1.3619621e01, 7.5937047e00], - [0.0, 9.3918495e00, -1.0172190e01], - [0.0, 4.2277718e00, 2.5784855e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_l3i5(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-l3i5') - model, config = model_from_checkpoint(cp_path) 
- cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.611131191253662]) - g1_ref_f = torch.tensor( - [ - [13.430887, 0.08655541, 0.08754013], - [-13.430886, -0.08655544, -0.08754011], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138] - ) - - g2_ref_e = torch.tensor([-12.700481414794922]) - g2_ref_f = torch.tensor( - [ - [0.0, -1.4547814e01, 8.1347866], - [0.0, 1.0308369e01, -1.0880318e01], - [0.0, 4.2394452, 2.7455316], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f, 1e-5) - assert acl(g1.inferred_stress, g1_ref_s, 1e-5) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_0(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-0') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - g1[KEY.DATA_MODALITY] = 'R2SCAN' - g2[KEY.DATA_MODALITY] = 'R2SCAN' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-11.607587814331055]) - g1_ref_f = torch.tensor( - [ - [8.512259, 0.07307914, 0.06676716], - [-8.512257, -0.07307915, -0.06676716], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402] - ) - - g2_ref_e = torch.tensor([-14.172412872314453]) - g2_ref_f = torch.tensor( - [ - [4.6566129e-10, -1.3429364e01, 6.9344816e00], - [2.3283064e-09, 8.9132404e00, -9.6807365e00], - [-2.7939677e-09, 4.5161238e00, 2.7462559e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-ompa') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - # mpa - g1[KEY.DATA_MODALITY] = 'mpa' - g2[KEY.DATA_MODALITY] = 'mpa' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.490943193435669]) - g1_ref_f = torch.tensor( - [ - [1.2680445e01, -2.7985498e-04, -2.7979910e-04], - [-1.2680446e01, 2.7984008e-04, 2.7981028e-04], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467] - ) - - g2_ref_e = torch.tensor([-12.597525596618652]) - g2_ref_f = torch.tensor( - [ - [0.0, -12.245223, 7.26795], - [0.0, 8.816763, -9.423925], - [0.0, 3.4284601, 2.1559749], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - 
assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-ompa') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - # mpa - g1[KEY.DATA_MODALITY] = 'omat24' - g2[KEY.DATA_MODALITY] = 'omat24' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.5094668865203857]) - g1_ref_f = torch.tensor( - [ - [1.2562084e01, -1.4219694e-03, -1.4219843e-03], - [-1.2562084e01, 1.4219508e-03, 1.4219955e-03], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343] - ) - - g2_ref_e = torch.tensor([-12.6202974319458]) - g2_ref_f = torch.tensor( - [ - [0.0, -12.205926, 7.2050343], - [0.0, 8.790399, -9.368677], - [0.0, 3.4155273, 2.163643], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_omat(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-omat') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.5033323764801025]) - g1_ref_f = torch.tensor( - [ - [12.533154, 0.02358698, 0.02358694], - [-12.533153, -0.02358699, -0.02358697], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445] - ) - - g2_ref_e = torch.tensor([-12.403768539428711]) - g2_ref_f = torch.tensor( - [ - [0, -12.848297, 7.11432], - [0.0, 9.265477, -9.564951], - [0.0, 3.58282, 2.4506311], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) +# test_pretrained: output consistency for pretrained models + +import pytest +import torch +from ase.build import bulk, molecule + +import sevenn._keys as KEY +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.dataload import unlabeled_atoms_to_graph +from sevenn.util import model_from_checkpoint, pretrained_name_to_path + + +def acl(a, b, atol=1e-6): + return torch.allclose(a, b, atol=atol) + + +@pytest.fixture +def atoms_pbc(): + atoms1 = bulk('NaCl', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture +def atoms_mol(): + atoms2 = molecule('H2O') + atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) + return atoms2 + + +def test_7net0_22May2024(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-0_22May2024') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = 
AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + g1_ref_e = torch.tensor([-3.4140868186950684]) + g1_ref_f = torch.tensor( + [ + [1.2628037e01, 7.5093508e-03, 1.3480943e-02], + [-1.2628037e01, -7.5093508e-03, -1.3480917e-02], + ] + ) + g1_ref_s = -1 * torch.tensor( + [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973] + ) + + g2_ref_e = torch.tensor([-12.808363914489746]) + g2_ref_f = torch.tensor( + [ + [9.31322575e-10, -1.30241165e01, 6.93116236e00], + [-1.39698386e-09, 9.28001022e00, -9.51867390e00], + [5.23868948e-10, 3.74410582e00, 2.58751225e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net0_11July2024(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-0_11July2024') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.779199]) + g1_ref_f = torch.tensor( + [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507] + ) + + g2_ref_e = torch.tensor([-12.782808303833008]) + g2_ref_f = torch.tensor( + [ + [0.0, -1.3619621e01, 7.5937047e00], + [0.0, 9.3918495e00, -1.0172190e01], + [0.0, 4.2277718e00, 2.5784855e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_l3i5(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-l3i5') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.611131191253662]) + g1_ref_f = torch.tensor( + [ + [13.430887, 0.08655541, 0.08754013], + [-13.430886, -0.08655544, -0.08754011], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138] + ) + + g2_ref_e = torch.tensor([-12.700481414794922]) + g2_ref_f = torch.tensor( + [ + [0.0, -1.4547814e01, 8.1347866], + [0.0, 1.0308369e01, -1.0880318e01], + [0.0, 4.2394452, 2.7455316], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f, 1e-5) + assert acl(g1.inferred_stress, g1_ref_s, 1e-5) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_0(atoms_pbc, 
atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-0') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + g1[KEY.DATA_MODALITY] = 'R2SCAN' + g2[KEY.DATA_MODALITY] = 'R2SCAN' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-11.607587814331055]) + g1_ref_f = torch.tensor( + [ + [8.512259, 0.07307914, 0.06676716], + [-8.512257, -0.07307915, -0.06676716], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402] + ) + + g2_ref_e = torch.tensor([-14.172412872314453]) + g2_ref_f = torch.tensor( + [ + [4.6566129e-10, -1.3429364e01, 6.9344816e00], + [2.3283064e-09, 8.9132404e00, -9.6807365e00], + [-2.7939677e-09, 4.5161238e00, 2.7462559e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-ompa') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + # mpa + g1[KEY.DATA_MODALITY] = 'mpa' + g2[KEY.DATA_MODALITY] = 'mpa' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.490943193435669]) + g1_ref_f = torch.tensor( + [ + [1.2680445e01, -2.7985498e-04, -2.7979910e-04], + [-1.2680446e01, 2.7984008e-04, 2.7981028e-04], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467] + ) + + g2_ref_e = torch.tensor([-12.597525596618652]) + g2_ref_f = torch.tensor( + [ + [0.0, -12.245223, 7.26795], + [0.0, 8.816763, -9.423925], + [0.0, 3.4284601, 2.1559749], + ] + ) + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-ompa') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + # mpa + g1[KEY.DATA_MODALITY] = 'omat24' + g2[KEY.DATA_MODALITY] = 'omat24' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.5094668865203857]) + g1_ref_f = torch.tensor( + [ + [1.2562084e01, -1.4219694e-03, -1.4219843e-03], + [-1.2562084e01, 1.4219508e-03, 1.4219955e-03], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343] + ) + + g2_ref_e = torch.tensor([-12.6202974319458]) + g2_ref_f = torch.tensor( + [ + [0.0, -12.205926, 7.2050343], + [0.0, 8.790399, -9.368677], + [0.0, 3.4155273, 2.163643], + ] + ) + 
assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_omat(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-omat') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.5033323764801025]) + g1_ref_f = torch.tensor( + [ + [12.533154, 0.02358698, 0.02358694], + [-12.533153, -0.02358699, -0.02358697], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445] + ) + + g2_ref_e = torch.tensor([-12.403768539428711]) + g2_ref_f = torch.tensor( + [ + [0, -12.848297, 7.11432], + [0.0, 9.265477, -9.564951], + [0.0, 3.58282, 2.4506311], + ] + ) + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py index 2fe3245061ead2bad8eab8c206fe522c8cab9f6c..3d5a4eba46902e4c67600400b6fb38b0a81a25a1 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py @@ -1,494 +1,494 @@ -import pytest -import torch - -import sevenn._keys as KEY -from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType -from sevenn.nn.scale import ( - ModalWiseRescale, - Rescale, - SpeciesWiseRescale, - get_resolved_shift_scale, -) - -################################################################################ -# Tests for Rescale # -################################################################################ - - -@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)]) -def test_rescale_init(shift, scale): - """ - Test that Rescale can be initialized properly without errors - and that parameters are set correctly. - """ - module = Rescale(shift=shift, scale=scale) - assert module.shift.item() == shift - assert module.scale.item() == scale - assert module.key_input == KEY.SCALED_ATOMIC_ENERGY - assert module.key_output == KEY.ATOMIC_ENERGY - - -def test_rescale_forward(): - """ - Test that Rescale forward pass correctly applies: - output = input * scale + shift - """ - # Setup - shift, scale = 1.0, 2.0 - module = Rescale(shift=shift, scale=scale) - # Make some fake data - input_data = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) - data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: input_data.clone()} - - # Forward - out_data = module(data) - - # Check correctness - expected_output = input_data * scale + shift - assert torch.allclose(out_data[KEY.ATOMIC_ENERGY], expected_output) - - -def test_rescale_get_shift_and_scale(): - """ - Test get_shift() and get_scale() methods in Rescale. 
- """ - module = Rescale(shift=1.5, scale=3.5) - assert module.get_shift() == pytest.approx(1.5) - assert module.get_scale() == pytest.approx(3.5) - - -################################################################################ -# Tests for SpeciesWiseRescale # -################################################################################ - - -def test_specieswise_rescale_init_float(): - """ - Test SpeciesWiseRescale when both shift and scale are floats - (should expand to same length lists). - """ - module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0) - # Expect a parameter of length = 1 in this scenario, but can differ - # if we raise an error for "Both shift and scale is not a list". - # Usually, you'd specify a known number of species or do from_mappers. - # The code as-is throws ValueError if both are float. Let's do from_mappers: - # We'll do direct init if your code allows it. If not, use from_mappers. - assert module.shift.shape == module.scale.shape - # They must be single-parameter (or expanded) if not from mappers. - - -def test_specieswise_rescale_init_list(): - """ - Test initialization with list-based shift/scale of same length. - """ - shift = [1.0, 2.0, 3.0] - scale = [2.0, 3.0, 4.0] - module = SpeciesWiseRescale(shift=shift, scale=scale) - assert len(module.shift) == 3 - assert len(module.scale) == 3 - assert torch.allclose(module.shift, torch.tensor([1.0, 2.0, 3.0])) - assert torch.allclose(module.scale, torch.tensor([2.0, 3.0, 4.0])) - - -def test_specieswise_rescale_forward(): - """ - Test that SpeciesWiseRescale forward pass applies: - output[i] = input[i]*scale[atom_type[i]] + shift[atom_type[i]] - """ - # Suppose we have two species types: - # 0 -> shift=1, scale=2, 1 -> shift=5, scale=10 - # (we'll pass them as lists in the correct order) - shift = [1.0, 5.0] - scale = [2.0, 10.0] - module = SpeciesWiseRescale( - shift=shift, - scale=scale, - data_key_in='in', - data_key_out='out', - data_key_indices='z', - ) - - # Create mock data - # Suppose we have three atoms: species => [0, 1, 0] - # input => [ [1.], [1.], [3.] ] - data: AtomGraphDataType = { - 'z': torch.tensor([0, 1, 0], dtype=torch.long), - 'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float), - } - - out = module(data) - # Now let's manually compute expected: - # For atom 0: scale=2, shift=1, input=1 => 1*2+1=3 - # For atom 1: scale=10, shift=5, input=1 => 1*10+5=15 - # For atom 2: scale=2, shift=1, input=3 => 3*2+1=7 - expected = torch.tensor([[3.0], [15.0], [7.0]]) - - assert torch.allclose(out['out'], expected) - - -def test_specieswise_rescale_get_shift_scale(): - """ - Test get_shift() and get_scale() with/without type_map. - """ - shift = [1.0, 2.0] - scale = [3.0, 4.0] - module = SpeciesWiseRescale(shift=shift, scale=scale) - - # Without type_map - # Should return the raw parameter values (list form). - s = module.get_shift() - sc = module.get_scale() - assert s == [1.0, 2.0] - assert sc == [3.0, 4.0] - - # With a type_map (example: atomic_number 1 -> 0, 8 -> 1) - type_map = {1: 0, 8: 1} # hydrogen, oxygen - s_univ = module.get_shift(type_map) - sc_univ = module.get_scale(type_map) - # In this small example with NUM_UNIV_ELEMENT = 2, the _as_univ will produce - # a list of length = NUM_UNIV_ELEMENT. If your real NUM_UNIV_ELEMENT is bigger, - # the rest would be padded with default values. - # For demonstration let's assume it returns [1.0, 2.0]. 
- # Check at least the known mapped portion: - assert len(s_univ) == NUM_UNIV_ELEMENT - assert len(sc_univ) == NUM_UNIV_ELEMENT - assert s_univ[1] == 1.0 # atomic_number=1 -> idx=0 -> shift=1.0 - assert s_univ[8] == 2.0 - - -################################################################################ -# Tests for ModalWiseRescale # -################################################################################ - - -def test_modalwise_rescale_init(): - """ - Basic sanity check for ModalWiseRescale initialization with - certain shapes. - """ - # Suppose we have 2 modals, 3 species => shift, scale is shape [2,3] - shift = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] - scale = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - module = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - # Check shape - assert module.shift.shape == torch.Size([2, 3]) - assert module.scale.shape == torch.Size([2, 3]) - - -def test_modalwise_rescale_forward(): - """ - Test that the forward pass of ModalWiseRescale matches - output[i] = input[i] * scale[modal_i, atom_i] + shift[modal_i, atom_i] - when both use_modal_wise_{shift,scale} are True. - """ - shift = [[0.0, 10.0], [5.0, 15.0]] # shape [2 (modals), 2 (species)] - scale = [[1.0, 2.0], [10.0, 20.0]] - module = ModalWiseRescale( - shift=shift, - scale=scale, - data_key_in='in', - data_key_out='out', - data_key_modal_indices='modal_idx', - data_key_atom_indices='atom_idx', - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - - data: AtomGraphDataType = { - 'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]), - 'modal_idx': torch.tensor([0, 1], dtype=torch.long), - 'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long), - 'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long), - } - - out = module(data) - # i=0 => modal_idx=0, atom_idx=0 => shift=0.0, scale=1.0 => out=1*1+0=1 - # i=1 => modal_idx=0, atom_idx=1 => shift=10.0, scale=2.0 => out=1*2+10=12 - # i=2 => modal_idx=1, atom_idx=0 => shift=5.0, scale=10.0 => out=2*10+5=25 - # i=3 => modal_idx=1, atom_idx=1 => shift=15.0, scale=20.0 => out=2*20+15=55 - expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]]) - assert torch.allclose(out['out'], expected) - - -def test_modalwise_rescale_get_shift_scale(): - """ - Test get_shift() and get_scale() with type_map and modal_map. - """ - # Setup - shift = [[0.0, 10.0], [5.0, 15.0]] - scale = [[1.0, 2.0], [10.0, 20.0]] - mod = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - - # Suppose we have type_map and modal_map - type_map = {1: 0, 8: 1} # Example: H->0, O->1 - modal_map = {'a': 0, 'b': 1} - - # get_shift, get_scale - s = mod.get_shift(type_map=type_map, modal_map=modal_map) - sc = mod.get_scale(type_map=type_map, modal_map=modal_map) - # Expect dict with keys "ambient", "pressure". - # Example: s["ambient"] = [ shift(0,0), shift(0,1) ] mapped to H,O - # s["pressure"] = [ shift(1,0), shift(1,1) ] - assert isinstance(s, dict) and isinstance(sc, dict) - assert set(s.keys()) == {'a', 'b'} - assert set(sc.keys()) == {'a', 'b'} - - -################################################################################ -# Tests for get_resolved_shift_scale function # -################################################################################ - - -def test_get_resolved_shift_scale_rescale(): - """ - Test get_resolved_shift_scale for a Rescale instance. 
- """ - from_m = Rescale(shift=2.0, scale=5.0) - shift, scale = get_resolved_shift_scale(from_m) - assert shift == 2.0 - assert scale == 5.0 - - -def test_get_resolved_shift_scale_specieswise(): - """ - Test get_resolved_shift_scale for a SpeciesWiseRescale instance. - """ - shift_list = [1.0, 2.0] - scale_list = [3.0, 4.0] - module = SpeciesWiseRescale(shift=shift_list, scale=scale_list) - type_map = {1: 0, 8: 1} - s, sc = get_resolved_shift_scale(module, type_map=type_map) - # The result should be extended to NUM_UNIV_ELEMENT length in real usage, - # but at least the first few should match shift_list, scale_list mapped. - assert isinstance(s, list) - assert isinstance(sc, list) - # Check mapped values - assert s[1] == shift_list[0] - assert s[8] == shift_list[1] - assert sc[1] == scale_list[0] - assert sc[8] == scale_list[1] - - -def test_get_resolved_shift_scale_modalwise(): - """ - Test get_resolved_shift_scale for a ModalWiseRescale instance. - """ - shift = [[0.0, 10.0], [5.0, 15.0]] - scale = [[1.0, 2.0], [10.0, 20.0]] - mmod = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - type_map = {1: 0, 8: 1} - modal_map = {'a': 0, 'b': 1} - s, sc = get_resolved_shift_scale(mmod, type_map=type_map, modal_map=modal_map) - # We expect dictionaries - assert isinstance(s, dict) and isinstance(sc, dict) - # Keys "a", "pressure" - assert 'a' in s - assert 'b' in s - # Check one example - # s["a"] => [0.0, 10.0] - # sc["a"] => [1.0, 2.0] - assert s['a'][1] == 0.0 - assert s['a'][8] == 10.0 - assert sc['a'][1] == 1.0 - assert sc['a'][8] == 2.0 - - -################################################################################ -# Tests for from_mappers function # -################################################################################ - - -@pytest.mark.parametrize( - 'shift, scale, type_map, expected_shift, expected_scale', - [ - # Both shift and scale are floats -> broadcast to each species - ( - 2.0, - 3.0, - {1: 0, 8: 1}, # e.g., H -> index 0, O -> index 1 - [2.0, 2.0], # broadcast - [3.0, 3.0], - ), - # shift, scale are same-length lists => directly used - ( - [0.5, 0.6], - [1.0, 1.1], - {1: 0, 8: 1}, - [0.5, 0.6], - [1.0, 1.1], - ), - # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118), - # but we only map out the subset for the actual species in type_map - ( - [0.1] * NUM_UNIV_ELEMENT, - [1.1] * NUM_UNIV_ELEMENT, - {1: 0, 8: 1}, - [0.1, 0.1], - [1.1, 1.1], - ), - # shift is a list, scale is float => shift is used directly, scale broadcast - ( - [1.0, 2.0], - 5.0, - {6: 0, 14: 1}, # C -> 0, Si -> 1 - [1.0, 2.0], - [5.0, 5.0], - ), - ], -) -def test_specieswise_rescale_from_mappers( - shift, scale, type_map, expected_shift, expected_scale -): - """ - Test SpeciesWiseRescale.from_mappers with various combinations of - shift/scale (float, list, universal list) and a given type_map. 
- """ - module = SpeciesWiseRescale.from_mappers( # type: ignore - shift=shift, - scale=scale, - type_map=type_map, - ) - # Check that the module's internal shift and scale have the correct shape - # The length must match number of species in type_map - assert module.shift.shape[0] == len(type_map) - assert module.scale.shape[0] == len(type_map) - - # Check that the content matches expected - actual_shift = module.shift.detach().cpu().tolist() - actual_scale = module.scale.detach().cpu().tolist() - - assert pytest.approx(actual_shift) == expected_shift - assert pytest.approx(actual_scale) == expected_scale - - -@pytest.mark.parametrize( - 'shift, scale, use_modal_wise_shift, use_modal_wise_scale, ' - 'type_map, modal_map, expected_shift, expected_scale', - [ - # Example 1: single float for shift/scale, - # broadcast over 2 modals and 2 species - ( - 1.0, - 2.0, - True, # shift depends on modal - True, # scale depends on modal - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - # expect 2D => [2 modals x 2 species] - [[1.0, 1.0], [1.0, 1.0]], - [[2.0, 2.0], [2.0, 2.0]], - ), - # Example 2: shift/scale are universal element-lists => use_modal=False => 1D - ( - [0.5] * NUM_UNIV_ELEMENT, - [1.5] * NUM_UNIV_ELEMENT, - False, # shift is not modal-wise - False, # scale is not modal-wise - {6: 0, 14: 1}, # e.g. C->0, Si->1 - {'modA': 0, 'modB': 1}, - # 1D => length = n_atom_types(=2) - [0.5, 0.5], - [1.5, 1.5], - ), - # Example 3: shift is dict of modals -> each is float - # => broadcast for each species - ( - {'modA': 0.0, 'modB': 2.0}, - {'modA': 1.0, 'modB': 3.0}, - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - # shift => shape [2 modals, 2 species] - [[0.0, 0.0], [2.0, 2.0]], - [[1.0, 1.0], [3.0, 3.0]], - ), - # Example 4: already in "modal-wise + species-wise" shape, direct pass - ( - [[0.0, 10.0], [5.0, 15.0]], - [[1.0, 2.0], [10.0, 20.0]], - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - [[0.0, 10.0], [5.0, 15.0]], - [[1.0, 2.0], [10.0, 20.0]], - ), - # Example 5: shift is a list of floats (one per modal), - # but we want modal-wise => broadcast for each species - ( - [0.0, 10.0], # length=2 => same as #modals - [1.0, 2.0], - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - [[0.0, 0.0], [10.0, 10.0]], - [[1.0, 1.0], [2.0, 2.0]], - ), - ], -) -def test_modalwise_rescale_from_mappers( - shift, - scale, - use_modal_wise_shift, - use_modal_wise_scale, - type_map, - modal_map, - expected_shift, - expected_scale, -): - """ - Test ModalWiseRescale.from_mappers for different shapes of shift/scale, - combined with type_map and modal_map. 
- """ - - module = ModalWiseRescale.from_mappers( # type: ignore - shift=shift, - scale=scale, - use_modal_wise_shift=use_modal_wise_shift, - use_modal_wise_scale=use_modal_wise_scale, - type_map=type_map, - modal_map=modal_map, - ) - # Check shape of the resulting shift, scale - # If modal-wise, we expect a 2D shape: [n_modals, n_species] - # Otherwise, a 1D shape: [n_species] - if use_modal_wise_shift: - assert module.shift.dim() == 2 - assert module.shift.shape[0] == len(modal_map) - assert module.shift.shape[1] == len(type_map) - else: - assert module.shift.dim() == 1 - assert module.shift.shape[0] == len(type_map) - - # Similarly for scale - if use_modal_wise_scale: - assert module.scale.dim() == 2 - assert module.scale.shape[0] == len(modal_map) - assert module.scale.shape[1] == len(type_map) - else: - assert module.scale.dim() == 1 - assert module.scale.shape[0] == len(type_map) - - # Verify the content matches our expectation - actual_shift = module.shift.detach().cpu().tolist() - actual_scale = module.scale.detach().cpu().tolist() - - assert actual_shift == expected_shift - assert actual_scale == expected_scale +import pytest +import torch + +import sevenn._keys as KEY +from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType +from sevenn.nn.scale import ( + ModalWiseRescale, + Rescale, + SpeciesWiseRescale, + get_resolved_shift_scale, +) + +################################################################################ +# Tests for Rescale # +################################################################################ + + +@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)]) +def test_rescale_init(shift, scale): + """ + Test that Rescale can be initialized properly without errors + and that parameters are set correctly. + """ + module = Rescale(shift=shift, scale=scale) + assert module.shift.item() == shift + assert module.scale.item() == scale + assert module.key_input == KEY.SCALED_ATOMIC_ENERGY + assert module.key_output == KEY.ATOMIC_ENERGY + + +def test_rescale_forward(): + """ + Test that Rescale forward pass correctly applies: + output = input * scale + shift + """ + # Setup + shift, scale = 1.0, 2.0 + module = Rescale(shift=shift, scale=scale) + # Make some fake data + input_data = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: input_data.clone()} + + # Forward + out_data = module(data) + + # Check correctness + expected_output = input_data * scale + shift + assert torch.allclose(out_data[KEY.ATOMIC_ENERGY], expected_output) + + +def test_rescale_get_shift_and_scale(): + """ + Test get_shift() and get_scale() methods in Rescale. + """ + module = Rescale(shift=1.5, scale=3.5) + assert module.get_shift() == pytest.approx(1.5) + assert module.get_scale() == pytest.approx(3.5) + + +################################################################################ +# Tests for SpeciesWiseRescale # +################################################################################ + + +def test_specieswise_rescale_init_float(): + """ + Test SpeciesWiseRescale when both shift and scale are floats + (should expand to same length lists). + """ + module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0) + # Expect a parameter of length = 1 in this scenario, but can differ + # if we raise an error for "Both shift and scale is not a list". + # Usually, you'd specify a known number of species or do from_mappers. + # The code as-is throws ValueError if both are float. 
Let's do from_mappers: + # We'll do direct init if your code allows it. If not, use from_mappers. + assert module.shift.shape == module.scale.shape + # They must be single-parameter (or expanded) if not from mappers. + + +def test_specieswise_rescale_init_list(): + """ + Test initialization with list-based shift/scale of same length. + """ + shift = [1.0, 2.0, 3.0] + scale = [2.0, 3.0, 4.0] + module = SpeciesWiseRescale(shift=shift, scale=scale) + assert len(module.shift) == 3 + assert len(module.scale) == 3 + assert torch.allclose(module.shift, torch.tensor([1.0, 2.0, 3.0])) + assert torch.allclose(module.scale, torch.tensor([2.0, 3.0, 4.0])) + + +def test_specieswise_rescale_forward(): + """ + Test that SpeciesWiseRescale forward pass applies: + output[i] = input[i]*scale[atom_type[i]] + shift[atom_type[i]] + """ + # Suppose we have two species types: + # 0 -> shift=1, scale=2, 1 -> shift=5, scale=10 + # (we'll pass them as lists in the correct order) + shift = [1.0, 5.0] + scale = [2.0, 10.0] + module = SpeciesWiseRescale( + shift=shift, + scale=scale, + data_key_in='in', + data_key_out='out', + data_key_indices='z', + ) + + # Create mock data + # Suppose we have three atoms: species => [0, 1, 0] + # input => [ [1.], [1.], [3.] ] + data: AtomGraphDataType = { + 'z': torch.tensor([0, 1, 0], dtype=torch.long), + 'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float), + } + + out = module(data) + # Now let's manually compute expected: + # For atom 0: scale=2, shift=1, input=1 => 1*2+1=3 + # For atom 1: scale=10, shift=5, input=1 => 1*10+5=15 + # For atom 2: scale=2, shift=1, input=3 => 3*2+1=7 + expected = torch.tensor([[3.0], [15.0], [7.0]]) + + assert torch.allclose(out['out'], expected) + + +def test_specieswise_rescale_get_shift_scale(): + """ + Test get_shift() and get_scale() with/without type_map. + """ + shift = [1.0, 2.0] + scale = [3.0, 4.0] + module = SpeciesWiseRescale(shift=shift, scale=scale) + + # Without type_map + # Should return the raw parameter values (list form). + s = module.get_shift() + sc = module.get_scale() + assert s == [1.0, 2.0] + assert sc == [3.0, 4.0] + + # With a type_map (example: atomic_number 1 -> 0, 8 -> 1) + type_map = {1: 0, 8: 1} # hydrogen, oxygen + s_univ = module.get_shift(type_map) + sc_univ = module.get_scale(type_map) + # In this small example with NUM_UNIV_ELEMENT = 2, the _as_univ will produce + # a list of length = NUM_UNIV_ELEMENT. If your real NUM_UNIV_ELEMENT is bigger, + # the rest would be padded with default values. + # For demonstration let's assume it returns [1.0, 2.0]. + # Check at least the known mapped portion: + assert len(s_univ) == NUM_UNIV_ELEMENT + assert len(sc_univ) == NUM_UNIV_ELEMENT + assert s_univ[1] == 1.0 # atomic_number=1 -> idx=0 -> shift=1.0 + assert s_univ[8] == 2.0 + + +################################################################################ +# Tests for ModalWiseRescale # +################################################################################ + + +def test_modalwise_rescale_init(): + """ + Basic sanity check for ModalWiseRescale initialization with + certain shapes. 
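+    Shift and scale are given as nested lists of shape [n_modals][n_species]
+    (here 2 modals x 3 species, as set up below).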
+ """ + # Suppose we have 2 modals, 3 species => shift, scale is shape [2,3] + shift = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + scale = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + module = ModalWiseRescale( + shift=shift, + scale=scale, + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + # Check shape + assert module.shift.shape == torch.Size([2, 3]) + assert module.scale.shape == torch.Size([2, 3]) + + +def test_modalwise_rescale_forward(): + """ + Test that the forward pass of ModalWiseRescale matches + output[i] = input[i] * scale[modal_i, atom_i] + shift[modal_i, atom_i] + when both use_modal_wise_{shift,scale} are True. + """ + shift = [[0.0, 10.0], [5.0, 15.0]] # shape [2 (modals), 2 (species)] + scale = [[1.0, 2.0], [10.0, 20.0]] + module = ModalWiseRescale( + shift=shift, + scale=scale, + data_key_in='in', + data_key_out='out', + data_key_modal_indices='modal_idx', + data_key_atom_indices='atom_idx', + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + + data: AtomGraphDataType = { + 'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]), + 'modal_idx': torch.tensor([0, 1], dtype=torch.long), + 'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long), + 'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long), + } + + out = module(data) + # i=0 => modal_idx=0, atom_idx=0 => shift=0.0, scale=1.0 => out=1*1+0=1 + # i=1 => modal_idx=0, atom_idx=1 => shift=10.0, scale=2.0 => out=1*2+10=12 + # i=2 => modal_idx=1, atom_idx=0 => shift=5.0, scale=10.0 => out=2*10+5=25 + # i=3 => modal_idx=1, atom_idx=1 => shift=15.0, scale=20.0 => out=2*20+15=55 + expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]]) + assert torch.allclose(out['out'], expected) + + +def test_modalwise_rescale_get_shift_scale(): + """ + Test get_shift() and get_scale() with type_map and modal_map. + """ + # Setup + shift = [[0.0, 10.0], [5.0, 15.0]] + scale = [[1.0, 2.0], [10.0, 20.0]] + mod = ModalWiseRescale( + shift=shift, + scale=scale, + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + + # Suppose we have type_map and modal_map + type_map = {1: 0, 8: 1} # Example: H->0, O->1 + modal_map = {'a': 0, 'b': 1} + + # get_shift, get_scale + s = mod.get_shift(type_map=type_map, modal_map=modal_map) + sc = mod.get_scale(type_map=type_map, modal_map=modal_map) + # Expect dict with keys "ambient", "pressure". + # Example: s["ambient"] = [ shift(0,0), shift(0,1) ] mapped to H,O + # s["pressure"] = [ shift(1,0), shift(1,1) ] + assert isinstance(s, dict) and isinstance(sc, dict) + assert set(s.keys()) == {'a', 'b'} + assert set(sc.keys()) == {'a', 'b'} + + +################################################################################ +# Tests for get_resolved_shift_scale function # +################################################################################ + + +def test_get_resolved_shift_scale_rescale(): + """ + Test get_resolved_shift_scale for a Rescale instance. + """ + from_m = Rescale(shift=2.0, scale=5.0) + shift, scale = get_resolved_shift_scale(from_m) + assert shift == 2.0 + assert scale == 5.0 + + +def test_get_resolved_shift_scale_specieswise(): + """ + Test get_resolved_shift_scale for a SpeciesWiseRescale instance. + """ + shift_list = [1.0, 2.0] + scale_list = [3.0, 4.0] + module = SpeciesWiseRescale(shift=shift_list, scale=scale_list) + type_map = {1: 0, 8: 1} + s, sc = get_resolved_shift_scale(module, type_map=type_map) + # The result should be extended to NUM_UNIV_ELEMENT length in real usage, + # but at least the first few should match shift_list, scale_list mapped. 
+    assert isinstance(s, list)
+    assert isinstance(sc, list)
+    # Check mapped values
+    assert s[1] == shift_list[0]
+    assert s[8] == shift_list[1]
+    assert sc[1] == scale_list[0]
+    assert sc[8] == scale_list[1]
+
+
+def test_get_resolved_shift_scale_modalwise():
+    """
+    Test get_resolved_shift_scale for a ModalWiseRescale instance.
+    """
+    shift = [[0.0, 10.0], [5.0, 15.0]]
+    scale = [[1.0, 2.0], [10.0, 20.0]]
+    mmod = ModalWiseRescale(
+        shift=shift,
+        scale=scale,
+        use_modal_wise_shift=True,
+        use_modal_wise_scale=True,
+    )
+    type_map = {1: 0, 8: 1}
+    modal_map = {'a': 0, 'b': 1}
+    s, sc = get_resolved_shift_scale(mmod, type_map=type_map, modal_map=modal_map)
+    # We expect dictionaries keyed by the modal names 'a' and 'b'
+    assert isinstance(s, dict) and isinstance(sc, dict)
+    assert 'a' in s
+    assert 'b' in s
+    # Check one example: within each modal, values sit at the atomic-number
+    # indices of the universal list:
+    #   s['a'][1] (H) => 0.0,  s['a'][8] (O) => 10.0
+    #   sc['a'][1] (H) => 1.0, sc['a'][8] (O) => 2.0
+    assert s['a'][1] == 0.0
+    assert s['a'][8] == 10.0
+    assert sc['a'][1] == 1.0
+    assert sc['a'][8] == 2.0
+
+
+################################################################################
+#                      Tests for from_mappers function                        #
+################################################################################
+
+
+@pytest.mark.parametrize(
+    'shift, scale, type_map, expected_shift, expected_scale',
+    [
+        # Both shift and scale are floats -> broadcast to each species
+        (
+            2.0,
+            3.0,
+            {1: 0, 8: 1},  # e.g., H -> index 0, O -> index 1
+            [2.0, 2.0],  # broadcast
+            [3.0, 3.0],
+        ),
+        # shift, scale are same-length lists => directly used
+        (
+            [0.5, 0.6],
+            [1.0, 1.1],
+            {1: 0, 8: 1},
+            [0.5, 0.6],
+            [1.0, 1.1],
+        ),
+        # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118),
+        # but we only map out the subset for the actual species in type_map
+        (
+            [0.1] * NUM_UNIV_ELEMENT,
+            [1.1] * NUM_UNIV_ELEMENT,
+            {1: 0, 8: 1},
+            [0.1, 0.1],
+            [1.1, 1.1],
+        ),
+        # shift is a list, scale is float => shift is used directly, scale broadcast
+        (
+            [1.0, 2.0],
+            5.0,
+            {6: 0, 14: 1},  # C -> 0, Si -> 1
+            [1.0, 2.0],
+            [5.0, 5.0],
+        ),
+    ],
+)
+def test_specieswise_rescale_from_mappers(
+    shift, scale, type_map, expected_shift, expected_scale
+):
+    """
+    Test SpeciesWiseRescale.from_mappers with various combinations of
+    shift/scale (float, list, universal list) and a given type_map.
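+
+    Resolution rules exercised by the cases below (a summary of these cases,
+    not an exhaustive spec): a float broadcasts over every species in
+    type_map; a list with one entry per species is used directly; a
+    NUM_UNIV_ELEMENT-long list is subset down via type_map.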
+ """ + module = SpeciesWiseRescale.from_mappers( # type: ignore + shift=shift, + scale=scale, + type_map=type_map, + ) + # Check that the module's internal shift and scale have the correct shape + # The length must match number of species in type_map + assert module.shift.shape[0] == len(type_map) + assert module.scale.shape[0] == len(type_map) + + # Check that the content matches expected + actual_shift = module.shift.detach().cpu().tolist() + actual_scale = module.scale.detach().cpu().tolist() + + assert pytest.approx(actual_shift) == expected_shift + assert pytest.approx(actual_scale) == expected_scale + + +@pytest.mark.parametrize( + 'shift, scale, use_modal_wise_shift, use_modal_wise_scale, ' + 'type_map, modal_map, expected_shift, expected_scale', + [ + # Example 1: single float for shift/scale, + # broadcast over 2 modals and 2 species + ( + 1.0, + 2.0, + True, # shift depends on modal + True, # scale depends on modal + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + # expect 2D => [2 modals x 2 species] + [[1.0, 1.0], [1.0, 1.0]], + [[2.0, 2.0], [2.0, 2.0]], + ), + # Example 2: shift/scale are universal element-lists => use_modal=False => 1D + ( + [0.5] * NUM_UNIV_ELEMENT, + [1.5] * NUM_UNIV_ELEMENT, + False, # shift is not modal-wise + False, # scale is not modal-wise + {6: 0, 14: 1}, # e.g. C->0, Si->1 + {'modA': 0, 'modB': 1}, + # 1D => length = n_atom_types(=2) + [0.5, 0.5], + [1.5, 1.5], + ), + # Example 3: shift is dict of modals -> each is float + # => broadcast for each species + ( + {'modA': 0.0, 'modB': 2.0}, + {'modA': 1.0, 'modB': 3.0}, + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + # shift => shape [2 modals, 2 species] + [[0.0, 0.0], [2.0, 2.0]], + [[1.0, 1.0], [3.0, 3.0]], + ), + # Example 4: already in "modal-wise + species-wise" shape, direct pass + ( + [[0.0, 10.0], [5.0, 15.0]], + [[1.0, 2.0], [10.0, 20.0]], + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + [[0.0, 10.0], [5.0, 15.0]], + [[1.0, 2.0], [10.0, 20.0]], + ), + # Example 5: shift is a list of floats (one per modal), + # but we want modal-wise => broadcast for each species + ( + [0.0, 10.0], # length=2 => same as #modals + [1.0, 2.0], + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + [[0.0, 0.0], [10.0, 10.0]], + [[1.0, 1.0], [2.0, 2.0]], + ), + ], +) +def test_modalwise_rescale_from_mappers( + shift, + scale, + use_modal_wise_shift, + use_modal_wise_scale, + type_map, + modal_map, + expected_shift, + expected_scale, +): + """ + Test ModalWiseRescale.from_mappers for different shapes of shift/scale, + combined with type_map and modal_map. 
+ """ + + module = ModalWiseRescale.from_mappers( # type: ignore + shift=shift, + scale=scale, + use_modal_wise_shift=use_modal_wise_shift, + use_modal_wise_scale=use_modal_wise_scale, + type_map=type_map, + modal_map=modal_map, + ) + # Check shape of the resulting shift, scale + # If modal-wise, we expect a 2D shape: [n_modals, n_species] + # Otherwise, a 1D shape: [n_species] + if use_modal_wise_shift: + assert module.shift.dim() == 2 + assert module.shift.shape[0] == len(modal_map) + assert module.shift.shape[1] == len(type_map) + else: + assert module.shift.dim() == 1 + assert module.shift.shape[0] == len(type_map) + + # Similarly for scale + if use_modal_wise_scale: + assert module.scale.dim() == 2 + assert module.scale.shape[0] == len(modal_map) + assert module.scale.shape[1] == len(type_map) + else: + assert module.scale.dim() == 1 + assert module.scale.shape[0] == len(type_map) + + # Verify the content matches our expectation + actual_shift = module.shift.detach().cpu().tolist() + actual_scale = module.scale.detach().cpu().tolist() + + assert actual_shift == expected_shift + assert actual_scale == expected_scale diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py index f758a6cb1df8c990aae3f3ee1c6787127f8fbd32..089fe8212b89f583006e04b0591bceafefd576e4 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py @@ -1,402 +1,402 @@ -import pathlib - -import ase.io -import numpy as np -import pytest -import torch -from torch_geometric.loader import DataLoader - -import sevenn.train.graph_dataset as graph_ds -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.error_recorder import ErrorRecorder -from sevenn.logger import Logger -from sevenn.scripts.processing_continue import processing_continue_v2 -from sevenn.scripts.processing_epoch import processing_epoch_v2 -from sevenn.train.dataload import graph_build -from sevenn.train.graph_dataset import from_config as dataset_from_config -from sevenn.train.loss import get_loss_functions_from_config -from sevenn.train.trainer import Trainer -from sevenn.util import ( - chemical_species_preprocess, - get_error_recorder, - pretrained_name_to_path, -) - -cutoff = 4.0 - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() - -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') -sevennet_0_path = pretrained_name_to_path('7net-0_11July2024') - -known_elements = ['Hf', 'O'] -_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924} - -Logger() # init - - -@pytest.fixture() -def HfO2_atoms(): - atoms = ase.io.read(hfo2_path) - return atoms - - -@pytest.fixture(scope='module') -def HfO2_loader(): - atoms = ase.io.read(hfo2_path, index=':') - assert isinstance(atoms, list) - graphs = graph_build(atoms, cutoff, y_from_calc=True) - return DataLoader(graphs, batch_size=2) - - -@pytest.fixture(scope='module') -def graph_dataset_path(tmp_path_factory): - gd_path = tmp_path_factory.mktemp('gd') - ds = graph_ds.SevenNetGraphDataset( - cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' - ) - return ds.processed_paths[0] - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 4, - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'interaction_type': 'nequip', - 'lmax': 2, - 'is_parity': True, - 
'num_convolution_layer': 3, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 'avg_num_neigh', - 'train_denominator': False, - 'self_connection_type': 'nequip', - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - config.update(**chemical_species_preprocess(known_elements)) - return config - - -def get_train_config(): - config = { - 'random_seed': 1, - 'epoch': 2, - 'loss': 'mse', - 'loss_param': {}, - 'optimizer': 'adam', - 'optim_param': {}, - 'scheduler': 'exponentiallr', - 'scheduler_param': {'gamma': 0.99}, - 'force_loss_weight': 1.0, - 'stress_loss_weight': 0.1, - 'per_epoch': 1, - 'continue': { - 'checkpoint': False, - 'reset_optimizer': False, - 'reset_scheduler': False, - 'reset_epoch': False, - }, - 'is_train_stress': True, - 'train_shuffle': True, - 'best_metric': 'TotalLoss', - 'error_record': [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('TotalLoss', 'None'), - ], - 'use_modality': False, - 'use_weight': False, - 'device': 'cpu', - 'is_ddp': False, - } - return config - - -def get_data_config(): - config = { - 'batch_size': 2, - 'shift': 'per_atom_energy_mean', - 'scale': 'force_rms', - 'preprocess_num_cores': 1, - 'data_format_args': {}, - 'load_trainset_path': hfo2_path, - } - return config - - -def get_config(overwrite=None): - cf = {} - cf.update(get_model_config()) - cf.update(get_train_config()) - cf.update(get_data_config()) - if overwrite: - cf.update(overwrite) - return cf - - -def test_processing_continue_v2_7net0(tmp_path): - cp = torch.load(sevennet_0_path, weights_only=False, map_location='cpu') - - cfg = get_config( - { - 'continue': { - 'checkpoint': sevennet_0_path, - 'reset_optimizer': False, - 'reset_scheduler': True, - 'reset_epoch': False, - } - } - ) - shift_ref = cp['model_state_dict']['rescale_atomic_energy.shift'].cpu().numpy() - scale_ref = np.array([1.73] * 89) - conv_denominator_ref = np.array([35.989574] * 5) - - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - state_dicts, epoch = processing_continue_v2(cfg) - assert epoch == 601 - assert np.allclose(np.array(cfg['shift']), shift_ref) - assert np.allclose(np.array(cfg['shift'])[0], -5.062768) - assert np.allclose(np.array(cfg['scale']), scale_ref) - assert np.allclose(np.array(cfg['conv_denominator']), conv_denominator_ref) - assert cfg['_number_of_species'] == 89 - assert cfg['_type_map'][89] == 0 # Ac - assert cfg['_type_map'][40] == 88 # Zr - assert state_dicts[2] is None # scheduler reset - - -@pytest.mark.parametrize( - 'cfg_overwrite,ds_names', - [ - ({}, ['trainset']), - ({'load_myset_path': hfo2_path}, ['trainset', 'myset']), - ], -) -def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path): - cfg = get_config(cfg_overwrite) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - datasets = dataset_from_config(cfg, tmp_path) - - assert set(ds_names) == set(datasets.keys()) - for ds_name in ds_names: - assert (tmp_path / 'sevenn_data' / f'{ds_name}.pt').is_file() - assert (tmp_path / 'sevenn_data' / f'{ds_name}.yaml').is_file() - - -def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path): - cfg = get_config({'load_trainset_path': graph_dataset_path}) - new_wd = tmp_path / 'tmp_wd' - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, str(new_wd)) 
- print((tmp_path / 'tmp_wd' / 'sevenn_data')) - assert not (tmp_path / 'tmp_wd' / 'sevenn_data').is_dir() - - -@pytest.mark.parametrize( - 'cfg_overwrite,shift,scale,conv', - [ - ( - {}, - -28.978, - 0.113304, - 25.333333, - ), - ( - { - 'shift': -1.2345678, - }, - -1.234567, - 0.113304, - 25.333333, - ), - ( - { - 'conv_denominator': 'sqrt_avg_num_neigh', - }, - -28.978, - 0.113304, - 25.333333**0.5, - ), - ( - { - 'shift': 'force_rms', - }, - 0.113304, - 0.113304, - 25.333333, - ), - ( - { - 'shift': 'elemwise_reference_energies', - }, - [ - 0.0 - if z not in _elemwise_ref_energy_dct - else _elemwise_ref_energy_dct[z] - for z in range(NUM_UNIV_ELEMENT) - ], - 0.113304, - 25.333333, - ), - ], -) -def test_dataset_from_config_statistics_init( - cfg_overwrite, shift, scale, conv, tmp_path -): - cfg = get_config(cfg_overwrite) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, tmp_path) - - assert np.allclose(cfg['shift'], shift) - assert np.allclose(cfg['scale'], scale) - assert np.allclose(cfg['conv_denominator'], conv) - - -def test_dataset_from_config_chem_auto(tmp_path): - cfg = get_config( - { - 'chemical_species': 'auto', - '_number_of_species': 'auto', - '_type_map': 'auto', - } - ) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, tmp_path) - assert cfg['chemical_species'] == ['Hf', 'O'] - assert cfg['_number_of_species'] == 2 - assert cfg['_type_map'] == {72: 0, 8: 1} - - -def test_run_one_epoch(HfO2_loader): - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer = Trainer(**trainer_args) - erc = get_error_recorder() - - ref1 = { - 'Energy_RMSE': '28.977758', - 'Force_RMSE': '0.214107', - 'Stress_RMSE': '190.014237', - } - - ref2 = { - 'Energy_RMSE': '28.977878', - 'Force_RMSE': '0.213105', - 'Stress_RMSE': '188.772557', - } - - trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) - ret1 = erc.get_dct() - erc.epoch_forward() - - for k in ref1: - assert np.allclose(float(ret1[k]), float(ref1[k])) - - trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=erc) - erc.epoch_forward() - - trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) - ret2 = erc.get_dct() - erc.epoch_forward() - - for k in ref2: - assert np.allclose(float(ret2[k]), float(ref2[k])) - - -def test_processing_epoch_v2(HfO2_loader, tmp_path): - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer = Trainer(**trainer_args) - erc = get_error_recorder() - start_epoch = 10 - total_epoch = 12 - per_epoch = 1 - best_metric = 'Energy_RMSE' - best_metric_loader_key = 'myset' - loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader} - - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - processing_epoch_v2( - config={}, - trainer=trainer, - loaders=loaders, - start_epoch=start_epoch, - error_recorder=erc, - total_epoch=total_epoch, - per_epoch=per_epoch, - best_metric_loader_key=best_metric_loader_key, - best_metric=best_metric, - working_dir=tmp_path, - ) - assert (tmp_path / 'checkpoint_10.pth').is_file() - assert (tmp_path / 'checkpoint_11.pth').is_file() - assert (tmp_path / 'checkpoint_12.pth').is_file() - assert (tmp_path / 'checkpoint_best.pth').is_file() - assert (tmp_path / 'lc.csv').is_file() - with open(tmp_path / 'lc.csv', 'r') as f: - lines = f.readlines() - heads = [ll.strip() for ll in lines[0].split(',')] - assert 'epoch' in heads - assert 'lr' in heads - assert 'trainset_Energy_RMSE' in heads - assert 'myset_Stress_MAE' in heads - lasts = 
[ll.strip() for ll in lines[-1].split(',')] - assert lasts[0] == '12' - assert lasts[1] == '0.000980' # lr - assert lasts[-2] == '0.087873' # myset Force MAE - - -def test_data_weight(graph_dataset_path, tmp_path): - cfg = get_config( - { - 'load_trainset_path': [{ - 'file_list': [{'file': graph_dataset_path}], - 'data_weight': {'energy': 0.1, 'force': 3.0, 'stress': 1.0}, - }], - 'error_record': [ - ('Energy', 'Loss'), - ('Force', 'Loss'), - ('Stress', 'Loss'), - ('TotalLoss', 'None'), - ], - 'use_weight': True - } - ) - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer_args['loss_functions'] = get_loss_functions_from_config(cfg) - trainer = Trainer(**trainer_args) - erc = ErrorRecorder.from_config(cfg, trainer.loss_functions) - - db = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset'] - loader_w_weight = DataLoader(db, batch_size=len(db)) - - trainer.run_one_epoch(loader_w_weight, False, erc) - loss = erc.epoch_forward() - assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1) - assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0) - assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0) - - -def _write_empty_checkpoint(): - from sevenn.model_build import build_E3_equivariant_model - - # Function I used to make empty checkpoint, to write the test - cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0}) - model = build_E3_equivariant_model(cfg) - trainer = Trainer.from_config(model, cfg) # type: ignore - trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0) - - -if __name__ == '__main__': - _write_empty_checkpoint() +import pathlib + +import ase.io +import numpy as np +import pytest +import torch +from torch_geometric.loader import DataLoader + +import sevenn.train.graph_dataset as graph_ds +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.error_recorder import ErrorRecorder +from sevenn.logger import Logger +from sevenn.scripts.processing_continue import processing_continue_v2 +from sevenn.scripts.processing_epoch import processing_epoch_v2 +from sevenn.train.dataload import graph_build +from sevenn.train.graph_dataset import from_config as dataset_from_config +from sevenn.train.loss import get_loss_functions_from_config +from sevenn.train.trainer import Trainer +from sevenn.util import ( + chemical_species_preprocess, + get_error_recorder, + pretrained_name_to_path, +) + +cutoff = 4.0 + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() + +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') +sevennet_0_path = pretrained_name_to_path('7net-0_11July2024') + +known_elements = ['Hf', 'O'] +_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924} + +Logger() # init + + +@pytest.fixture() +def HfO2_atoms(): + atoms = ase.io.read(hfo2_path) + return atoms + + +@pytest.fixture(scope='module') +def HfO2_loader(): + atoms = ase.io.read(hfo2_path, index=':') + assert isinstance(atoms, list) + graphs = graph_build(atoms, cutoff, y_from_calc=True) + return DataLoader(graphs, batch_size=2) + + +@pytest.fixture(scope='module') +def graph_dataset_path(tmp_path_factory): + gd_path = tmp_path_factory.mktemp('gd') + ds = graph_ds.SevenNetGraphDataset( + cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' + ) + return ds.processed_paths[0] + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 4, + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, 
+ 'interaction_type': 'nequip', + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 'avg_num_neigh', + 'train_denominator': False, + 'self_connection_type': 'nequip', + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + config.update(**chemical_species_preprocess(known_elements)) + return config + + +def get_train_config(): + config = { + 'random_seed': 1, + 'epoch': 2, + 'loss': 'mse', + 'loss_param': {}, + 'optimizer': 'adam', + 'optim_param': {}, + 'scheduler': 'exponentiallr', + 'scheduler_param': {'gamma': 0.99}, + 'force_loss_weight': 1.0, + 'stress_loss_weight': 0.1, + 'per_epoch': 1, + 'continue': { + 'checkpoint': False, + 'reset_optimizer': False, + 'reset_scheduler': False, + 'reset_epoch': False, + }, + 'is_train_stress': True, + 'train_shuffle': True, + 'best_metric': 'TotalLoss', + 'error_record': [ + ('Energy', 'RMSE'), + ('Force', 'RMSE'), + ('Stress', 'RMSE'), + ('TotalLoss', 'None'), + ], + 'use_modality': False, + 'use_weight': False, + 'device': 'cpu', + 'is_ddp': False, + } + return config + + +def get_data_config(): + config = { + 'batch_size': 2, + 'shift': 'per_atom_energy_mean', + 'scale': 'force_rms', + 'preprocess_num_cores': 1, + 'data_format_args': {}, + 'load_trainset_path': hfo2_path, + } + return config + + +def get_config(overwrite=None): + cf = {} + cf.update(get_model_config()) + cf.update(get_train_config()) + cf.update(get_data_config()) + if overwrite: + cf.update(overwrite) + return cf + + +def test_processing_continue_v2_7net0(tmp_path): + cp = torch.load(sevennet_0_path, weights_only=False, map_location='cpu') + + cfg = get_config( + { + 'continue': { + 'checkpoint': sevennet_0_path, + 'reset_optimizer': False, + 'reset_scheduler': True, + 'reset_epoch': False, + } + } + ) + shift_ref = cp['model_state_dict']['rescale_atomic_energy.shift'].cpu().numpy() + scale_ref = np.array([1.73] * 89) + conv_denominator_ref = np.array([35.989574] * 5) + + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + state_dicts, epoch = processing_continue_v2(cfg) + assert epoch == 601 + assert np.allclose(np.array(cfg['shift']), shift_ref) + assert np.allclose(np.array(cfg['shift'])[0], -5.062768) + assert np.allclose(np.array(cfg['scale']), scale_ref) + assert np.allclose(np.array(cfg['conv_denominator']), conv_denominator_ref) + assert cfg['_number_of_species'] == 89 + assert cfg['_type_map'][89] == 0 # Ac + assert cfg['_type_map'][40] == 88 # Zr + assert state_dicts[2] is None # scheduler reset + + +@pytest.mark.parametrize( + 'cfg_overwrite,ds_names', + [ + ({}, ['trainset']), + ({'load_myset_path': hfo2_path}, ['trainset', 'myset']), + ], +) +def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path): + cfg = get_config(cfg_overwrite) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + datasets = dataset_from_config(cfg, tmp_path) + + assert set(ds_names) == set(datasets.keys()) + for ds_name in ds_names: + assert (tmp_path / 'sevenn_data' / f'{ds_name}.pt').is_file() + assert (tmp_path / 'sevenn_data' / f'{ds_name}.yaml').is_file() + + +def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path): + cfg = get_config({'load_trainset_path': graph_dataset_path}) + new_wd = tmp_path / 'tmp_wd' + with 
Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, str(new_wd)) + print((tmp_path / 'tmp_wd' / 'sevenn_data')) + assert not (tmp_path / 'tmp_wd' / 'sevenn_data').is_dir() + + +@pytest.mark.parametrize( + 'cfg_overwrite,shift,scale,conv', + [ + ( + {}, + -28.978, + 0.113304, + 25.333333, + ), + ( + { + 'shift': -1.2345678, + }, + -1.234567, + 0.113304, + 25.333333, + ), + ( + { + 'conv_denominator': 'sqrt_avg_num_neigh', + }, + -28.978, + 0.113304, + 25.333333**0.5, + ), + ( + { + 'shift': 'force_rms', + }, + 0.113304, + 0.113304, + 25.333333, + ), + ( + { + 'shift': 'elemwise_reference_energies', + }, + [ + 0.0 + if z not in _elemwise_ref_energy_dct + else _elemwise_ref_energy_dct[z] + for z in range(NUM_UNIV_ELEMENT) + ], + 0.113304, + 25.333333, + ), + ], +) +def test_dataset_from_config_statistics_init( + cfg_overwrite, shift, scale, conv, tmp_path +): + cfg = get_config(cfg_overwrite) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, tmp_path) + + assert np.allclose(cfg['shift'], shift) + assert np.allclose(cfg['scale'], scale) + assert np.allclose(cfg['conv_denominator'], conv) + + +def test_dataset_from_config_chem_auto(tmp_path): + cfg = get_config( + { + 'chemical_species': 'auto', + '_number_of_species': 'auto', + '_type_map': 'auto', + } + ) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, tmp_path) + assert cfg['chemical_species'] == ['Hf', 'O'] + assert cfg['_number_of_species'] == 2 + assert cfg['_type_map'] == {72: 0, 8: 1} + + +def test_run_one_epoch(HfO2_loader): + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer = Trainer(**trainer_args) + erc = get_error_recorder() + + ref1 = { + 'Energy_RMSE': '28.977758', + 'Force_RMSE': '0.214107', + 'Stress_RMSE': '190.014237', + } + + ref2 = { + 'Energy_RMSE': '28.977878', + 'Force_RMSE': '0.213105', + 'Stress_RMSE': '188.772557', + } + + trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) + ret1 = erc.get_dct() + erc.epoch_forward() + + for k in ref1: + assert np.allclose(float(ret1[k]), float(ref1[k])) + + trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=erc) + erc.epoch_forward() + + trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) + ret2 = erc.get_dct() + erc.epoch_forward() + + for k in ref2: + assert np.allclose(float(ret2[k]), float(ref2[k])) + + +def test_processing_epoch_v2(HfO2_loader, tmp_path): + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer = Trainer(**trainer_args) + erc = get_error_recorder() + start_epoch = 10 + total_epoch = 12 + per_epoch = 1 + best_metric = 'Energy_RMSE' + best_metric_loader_key = 'myset' + loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader} + + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + processing_epoch_v2( + config={}, + trainer=trainer, + loaders=loaders, + start_epoch=start_epoch, + error_recorder=erc, + total_epoch=total_epoch, + per_epoch=per_epoch, + best_metric_loader_key=best_metric_loader_key, + best_metric=best_metric, + working_dir=tmp_path, + ) + assert (tmp_path / 'checkpoint_10.pth').is_file() + assert (tmp_path / 'checkpoint_11.pth').is_file() + assert (tmp_path / 'checkpoint_12.pth').is_file() + assert (tmp_path / 'checkpoint_best.pth').is_file() + assert (tmp_path / 'lc.csv').is_file() + with open(tmp_path / 'lc.csv', 'r') as f: + lines = f.readlines() + heads = [ll.strip() for ll in lines[0].split(',')] + assert 'epoch' in heads + assert 'lr' in 
heads + assert 'trainset_Energy_RMSE' in heads + assert 'myset_Stress_MAE' in heads + lasts = [ll.strip() for ll in lines[-1].split(',')] + assert lasts[0] == '12' + assert lasts[1] == '0.000980' # lr + assert lasts[-2] == '0.087873' # myset Force MAE + + +def test_data_weight(graph_dataset_path, tmp_path): + cfg = get_config( + { + 'load_trainset_path': [{ + 'file_list': [{'file': graph_dataset_path}], + 'data_weight': {'energy': 0.1, 'force': 3.0, 'stress': 1.0}, + }], + 'error_record': [ + ('Energy', 'Loss'), + ('Force', 'Loss'), + ('Stress', 'Loss'), + ('TotalLoss', 'None'), + ], + 'use_weight': True + } + ) + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer_args['loss_functions'] = get_loss_functions_from_config(cfg) + trainer = Trainer(**trainer_args) + erc = ErrorRecorder.from_config(cfg, trainer.loss_functions) + + db = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset'] + loader_w_weight = DataLoader(db, batch_size=len(db)) + + trainer.run_one_epoch(loader_w_weight, False, erc) + loss = erc.epoch_forward() + assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1) + assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0) + assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0) + + +def _write_empty_checkpoint(): + from sevenn.model_build import build_E3_equivariant_model + + # Function I used to make empty checkpoint, to write the test + cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0}) + model = build_E3_equivariant_model(cfg) + trainer = Trainer.from_config(model, cfg) # type: ignore + trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0) + + +if __name__ == '__main__': + _write_empty_checkpoint() diff --git a/mace-bench/3rdparty/mace/mace/__init__.py b/mace-bench/3rdparty/mace/mace/__init__.py index 711d144a64f47f63db182cdaeae25d1102dfdaea..490c4c01eb79794be90ab73ee32623ac1ca35db2 100644 --- a/mace-bench/3rdparty/mace/mace/__init__.py +++ b/mace-bench/3rdparty/mace/mace/__init__.py @@ -1,5 +1,5 @@ -import os - -from .__version__ import __version__ - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" +import os + +from .__version__ import __version__ + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index ad042d7871652a5f00a02c0a2bd7d6b845651629..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 7895a9a9967da9436428f55384836e3af3f7eb44..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-310.pyc deleted file mode 100644 index 80500965fe3f406a376e590e2e344288b00a7ef7..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc deleted file mode 100644 index 
b95961fd2994bd9f45690e25ef385f17125df39a..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/__version__.py b/mace-bench/3rdparty/mace/mace/__version__.py index ee8fdbf147b87b64ce3ca38d0393ccabb5d53cfa..343c940b13c83494dd96428a1d62eb93759c7353 100644 --- a/mace-bench/3rdparty/mace/mace/__version__.py +++ b/mace-bench/3rdparty/mace/mace/__version__.py @@ -1,3 +1,3 @@ -__version__ = "0.3.13" - -__all__ = ["__version__"] +__version__ = "0.3.13" + +__all__ = ["__version__"] diff --git a/mace-bench/3rdparty/mace/mace/calculators/__init__.py b/mace-bench/3rdparty/mace/mace/calculators/__init__.py index 7f5a5597ff7023b954e76c84794216c2e4fee159..8511eb9e327962ff8a464d8c0a7bd52d1bea91b7 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/__init__.py +++ b/mace-bench/3rdparty/mace/mace/calculators/__init__.py @@ -1,11 +1,11 @@ -from .foundations_models import mace_anicc, mace_mp, mace_off -from .lammps_mace import LAMMPS_MACE -from .mace import MACECalculator - -__all__ = [ - "MACECalculator", - "LAMMPS_MACE", - "mace_mp", - "mace_off", - "mace_anicc", -] +from .foundations_models import mace_anicc, mace_mp, mace_off +from .lammps_mace import LAMMPS_MACE +from .mace import MACECalculator + +__all__ = [ + "MACECalculator", + "LAMMPS_MACE", + "mace_mp", + "mace_off", + "mace_anicc", +] diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index fe16df03b666be8e459a7b528b576eb91ad3b9a9..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index ae361bd38b101f99adba5713cb075aba54828ae2..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc deleted file mode 100644 index bb99d2344cf448d25f3fcc8dad6b1338b701a038..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc deleted file mode 100644 index 2fd1d3854b730e59ae0f75ad30a93cd31a017eda..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc deleted file mode 100644 index a6b198ba203b9013730ec7d32f36396b1a006933..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc and /dev/null differ diff --git 
a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc deleted file mode 100644 index 41757dccad7c2a2c8b7d0d2dd9d3d4ccb8d2ace4..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-310.pyc deleted file mode 100644 index 074b740bcd1d5da79c2f9fa954196d6959ce60af..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc deleted file mode 100644 index aebabb3272caef29c67a20f0255b8730c15b8cf1..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py b/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py index f4666ea0beb180af949d2c4bd683451555f2b392..75e47c1f817e0b480cb8a010c8e387d9b49f66ea 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py +++ b/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py @@ -1,339 +1,339 @@ -import os -import urllib.request -from pathlib import Path -from typing import Union - -import torch -from ase import units -from ase.calculators.mixing import SumCalculator - -from .mace import MACECalculator - -module_dir = os.path.dirname(__file__) -local_model_path = os.path.join( - module_dir, "foundations_models/mace-mpa-0-medium.model" -) - - -def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: - """ - Downloads or locates the MACE-MP checkpoint file. - - Args: - model (str, optional): Path to the model or size specification. - Defaults to None which uses the medium model. - - Returns: - str: Path to the downloaded (or cached, if previously loaded) checkpoint file. 
- """ - if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path): - return local_model_path - - urls = { - "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", - "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", - "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", - "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model", - "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model", - "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", - "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", - "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", - "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model", - "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model", - "medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model", - "mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model", - "mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model", - } - - checkpoint_url = ( - urls.get(model, urls["medium-mpa-0"]) - if model - in ( - None, - "small", - "medium", - "large", - "small-0b", - "medium-0b", - "small-0b2", - "medium-0b2", - "large-0b2", - "medium-0b3", - "medium-mpa-0", - "medium-omat-0", - ) - else model - ) - - if checkpoint_url == urls["medium-mpa-0"]: - print( - "Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument" - ) - ASL_checkpoint_urls = { - urls["medium-omat-0"], - urls["mace-matpes-pbe-0"], - urls["mace-matpes-r2scan-0"], - } - if checkpoint_url in ASL_checkpoint_urls: - print( - "Using model under Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use this model you accept the terms of the license." 
- ) - - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - print(f"Downloading MACE model from {checkpoint_url!r}") - _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Model download failed, please check the URL {checkpoint_url}" - ) - print(f"Cached MACE model to {cached_model_path}") - - return cached_model_path - - -def mace_mp( - model: Union[str, Path] = None, - device: str = "", - default_dtype: str = "float32", - dispersion: bool = False, - damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] - dispersion_xc: str = "pbe", - dispersion_cutoff: float = 40.0 * units.Bohr, - return_raw_model: bool = False, - **kwargs, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements). - The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models. - Note: - If you are using this function, please cite the relevant paper for the Materials Project, - any paper associated with the MACE model, and also the following: - - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena, - Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096 - - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b, - DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal - - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang, - Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920 - - Args: - model (str, optional): Path to the model. Defaults to None which first checks for - a local model and then downloads the default model from figshare. Specify "small", - "medium" or "large" to download a smaller or larger model from figshare. - device (str, optional): Device to use for the model. Defaults to "cuda" if available. - default_dtype (str, optional): Default dtype for the model. Defaults to "float32". - dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. - damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). - dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. - dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. - return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. - **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. - - Returns: - MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). 
- """ - try: - if model in ( - None, - "small", - "medium", - "large", - "medium-mpa-0", - "small-0b", - "medium-0b", - "small-0b2", - "medium-0b2", - "medium-0b3", - "large-0b2", - "medium-omat-0", - ) or str(model).startswith("https:"): - model_path = download_mace_mp_checkpoint(model) - print(f"Using Materials Project MACE for MACECalculator with {model_path}") - else: - if not Path(model).exists(): - raise FileNotFoundError(f"{model} not found locally") - model_path = model - except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - if default_dtype == "float64": - print( - "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." - ) - if default_dtype == "float32": - print( - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." - ) - - if return_raw_model: - return torch.load(model_path, map_location=device) - - mace_calc = MACECalculator( - model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs - ) - - if not dispersion: - return mace_calc - - try: - from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator - except ImportError as exc: - raise RuntimeError( - "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" - ) from exc - - print("Using TorchDFTD3Calculator for D3 dispersion corrections") - dtype = torch.float32 if default_dtype == "float32" else torch.float64 - d3_calc = TorchDFTD3Calculator( - device=device, - damping=damping, - dtype=dtype, - xc=dispersion_xc, - cutoff=dispersion_cutoff, - **kwargs, - ) - - return SumCalculator([mace_calc, d3_calc]) - - -def mace_off( - model: Union[str, Path] = None, - device: str = "", - default_dtype: str = "float64", - return_raw_model: bool = False, - **kwargs, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models. - The model is released under the ASL license. - Note: - If you are using this function, please cite the relevant paper by Kovacs et.al., arXiv:2312.15211 - - Args: - model (str, optional): Path to the model. Defaults to None which first checks for - a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off. - Specify "small", "medium" or "large" to download a smaller or larger model. - device (str, optional): Device to use for the model. Defaults to "cuda". - default_dtype (str, optional): Default dtype for the model. Defaults to "float64". - return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. - **kwargs: Passed to MACECalculator. 
- - Returns: - MACECalculator: trained on the MACE-OFF23 dataset - """ - try: - if model in (None, "small", "medium", "large") or str(model).startswith( - "https:" - ): - urls = dict( - small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", - medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", - large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", - ) - checkpoint_url = ( - urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") - else model - ) - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - print(f"Downloading MACE model from {checkpoint_url!r}") - print( - "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." - ) - print( - "ASL is based on the Gnu Public License, but does not permit commercial use" - ) - urllib.request.urlretrieve(checkpoint_url, cached_model_path) - print(f"Cached MACE model to {cached_model_path}") - model = cached_model_path - msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" - print(msg) - else: - if not Path(model).exists(): - raise FileNotFoundError(f"{model} not found locally") - except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - - if return_raw_model: - return torch.load(model, map_location=device) - - if default_dtype == "float64": - print( - "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." - ) - if default_dtype == "float32": - print( - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." - ) - mace_calc = MACECalculator( - model_paths=model, device=device, default_dtype=default_dtype, **kwargs - ) - return mace_calc - - -def mace_anicc( - device: str = "cuda", - model_path: str = None, - return_raw_model: bool = False, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). - The model is released under the MIT license. - Note: - If you are using this function, please cite the relevant paper associated with the MACE model, ANI dataset, and also the following: - - "Evaluation of the MACE Force Field Architecture by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322 - """ - if model_path is None: - model_path = os.path.join( - module_dir, "foundations_models/ani500k_large_CC.model" - ) - print( - "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" - ) - - if not os.path.exists(model_path): - model_dir = os.path.dirname(model_path) - os.makedirs(model_dir, exist_ok=True) - - # Download the model - print(f"Model not found at {model_path}. 
Downloading...") - model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model" - - try: - - def report_progress(block_num, block_size, total_size): - downloaded = block_num * block_size - percent = min(100, downloaded * 100 / total_size) - if total_size > 0: - print( - f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)", - end="", - ) - - urllib.request.urlretrieve( - model_url, model_path, reporthook=report_progress - ) - print("\nDownload complete!") - - except Exception as e: - raise RuntimeError(f"Failed to download model: {e}") from e - - if return_raw_model: - return torch.load(model_path, map_location=device) - return MACECalculator( - model_paths=model_path, device=device, default_dtype="float64" - ) +import os +import urllib.request +from pathlib import Path +from typing import Union + +import torch +from ase import units +from ase.calculators.mixing import SumCalculator + +from .mace import MACECalculator + +module_dir = os.path.dirname(__file__) +local_model_path = os.path.join( + module_dir, "foundations_models/mace-mpa-0-medium.model" +) + + +def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: + """ + Downloads or locates the MACE-MP checkpoint file. + + Args: + model (str, optional): Path to the model or size specification. + Defaults to None which uses the medium model. + + Returns: + str: Path to the downloaded (or cached, if previously loaded) checkpoint file. + """ + if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path): + return local_model_path + + urls = { + "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", + "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", + "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", + "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model", + "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model", + "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", + "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", + "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", + "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model", + "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model", + "medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model", + "mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model", + "mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model", + } + + checkpoint_url = ( + urls.get(model, urls["medium-mpa-0"]) + if model + in ( + None, + "small", + "medium", + "large", + "small-0b", + "medium-0b", + "small-0b2", + "medium-0b2", + "large-0b2", + "medium-0b3", + "medium-mpa-0", + "medium-omat-0", + ) + else model + ) + + if checkpoint_url == urls["medium-mpa-0"]: + print( + "Using medium MPA-0 model as default MACE-MP 
model, to use previous (before 3.10) default model please specify 'medium' as model argument" + ) + ASL_checkpoint_urls = { + urls["medium-omat-0"], + urls["mace-matpes-pbe-0"], + urls["mace-matpes-r2scan-0"], + } + if checkpoint_url in ASL_checkpoint_urls: + print( + "Using model under Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use this model you accept the terms of the license." + ) + + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + print(f"Downloading MACE model from {checkpoint_url!r}") + _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Model download failed, please check the URL {checkpoint_url}" + ) + print(f"Cached MACE model to {cached_model_path}") + + return cached_model_path + + +def mace_mp( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float32", + dispersion: bool = False, + damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] + dispersion_xc: str = "pbe", + dispersion_cutoff: float = 40.0 * units.Bohr, + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements). + The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models. + Note: + If you are using this function, please cite the relevant paper for the Materials Project, + any paper associated with the MACE model, and also the following: + - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena, + Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096 + - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b, + DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal + - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang, + Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default model from figshare. Specify "small", + "medium" or "large" to download a smaller or larger model from figshare. + device (str, optional): Device to use for the model. Defaults to "cuda" if available. + default_dtype (str, optional): Default dtype for the model. Defaults to "float32". + dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. + damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). + dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. + dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. + + Returns: + MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). 
+ """ + try: + if model in ( + None, + "small", + "medium", + "large", + "medium-mpa-0", + "small-0b", + "medium-0b", + "small-0b2", + "medium-0b2", + "medium-0b3", + "large-0b2", + "medium-omat-0", + ) or str(model).startswith("https:"): + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + model_path = model + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + + if return_raw_model: + return torch.load(model_path, map_location=device) + + mace_calc = MACECalculator( + model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs + ) + + if not dispersion: + return mace_calc + + try: + from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator + except ImportError as exc: + raise RuntimeError( + "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" + ) from exc + + print("Using TorchDFTD3Calculator for D3 dispersion corrections") + dtype = torch.float32 if default_dtype == "float32" else torch.float64 + d3_calc = TorchDFTD3Calculator( + device=device, + damping=damping, + dtype=dtype, + xc=dispersion_xc, + cutoff=dispersion_cutoff, + **kwargs, + ) + + return SumCalculator([mace_calc, d3_calc]) + + +def mace_off( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float64", + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models. + The model is released under the ASL license. + Note: + If you are using this function, please cite the relevant paper by Kovacs et.al., arXiv:2312.15211 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off. + Specify "small", "medium" or "large" to download a smaller or larger model. + device (str, optional): Device to use for the model. Defaults to "cuda". + default_dtype (str, optional): Default dtype for the model. Defaults to "float64". + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator. 
+ + Returns: + MACECalculator: trained on the MACE-OFF23 dataset + """ + try: + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + urls = dict( + small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", + medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", + large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", + ) + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model + ) + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + print(f"Downloading MACE model from {checkpoint_url!r}") + print( + "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + ) + print( + "ASL is based on the Gnu Public License, but does not permit commercial use" + ) + urllib.request.urlretrieve(checkpoint_url, cached_model_path) + print(f"Cached MACE model to {cached_model_path}") + model = cached_model_path + msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" + print(msg) + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + if return_raw_model: + return torch.load(model, map_location=device) + + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + mace_calc = MACECalculator( + model_paths=model, device=device, default_dtype=default_dtype, **kwargs + ) + return mace_calc + + +def mace_anicc( + device: str = "cuda", + model_path: str = None, + return_raw_model: bool = False, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). + The model is released under the MIT license. + Note: + If you are using this function, please cite the relevant paper associated with the MACE model, ANI dataset, and also the following: + - "Evaluation of the MACE Force Field Architecture by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322 + """ + if model_path is None: + model_path = os.path.join( + module_dir, "foundations_models/ani500k_large_CC.model" + ) + print( + "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" + ) + + if not os.path.exists(model_path): + model_dir = os.path.dirname(model_path) + os.makedirs(model_dir, exist_ok=True) + + # Download the model + print(f"Model not found at {model_path}. 
Downloading...") + model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model" + + try: + + def report_progress(block_num, block_size, total_size): + downloaded = block_num * block_size + percent = min(100, downloaded * 100 / total_size) + if total_size > 0: + print( + f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)", + end="", + ) + + urllib.request.urlretrieve( + model_url, model_path, reporthook=report_progress + ) + print("\nDownload complete!") + + except Exception as e: + raise RuntimeError(f"Failed to download model: {e}") from e + + if return_raw_model: + return torch.load(model_path, map_location=device) + return MACECalculator( + model_paths=model_path, device=device, default_dtype="float64" + ) diff --git a/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py b/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py index 4a1edc00e0605efcf25e38b651117ad7d238a488..4211c37f6a5001d55510dbb110fa7d795ca1e57c 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py @@ -1,105 +1,105 @@ -from typing import Dict, List, Optional - -import torch -from e3nn.util.jit import compile_mode - -from mace.tools.scatter import scatter_sum - - -@compile_mode("script") -class LAMMPS_MACE(torch.nn.Module): - def __init__(self, model, **kwargs): - super().__init__() - self.model = model - self.register_buffer("atomic_numbers", model.atomic_numbers) - self.register_buffer("r_max", model.r_max) - self.register_buffer("num_interactions", model.num_interactions) - if not hasattr(model, "heads"): - model.heads = [None] - self.register_buffer( - "head", - torch.tensor( - self.model.heads.index(kwargs.get("head", self.model.heads[-1])), - dtype=torch.long, - ).unsqueeze(0), - ) - - for param in self.model.parameters(): - param.requires_grad = False - - def forward( - self, - data: Dict[str, torch.Tensor], - local_or_ghost: torch.Tensor, - compute_virials: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - num_graphs = data["ptr"].numel() - 1 - compute_displacement = False - if compute_virials: - compute_displacement = True - data["head"] = self.head - out = self.model( - data, - training=False, - compute_force=False, - compute_virials=False, - compute_stress=False, - compute_displacement=compute_displacement, - ) - node_energy = out["node_energy"] - if node_energy is None: - return { - "total_energy_local": None, - "node_energy": None, - "forces": None, - "virials": None, - } - positions = data["positions"] - displacement = out["displacement"] - forces: Optional[torch.Tensor] = torch.zeros_like(positions) - virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) - # accumulate energies of local atoms - node_energy_local = node_energy * local_or_ghost - total_energy_local = scatter_sum( - src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs - ) - # compute partial forces and (possibly) partial virials - grad_outputs: List[Optional[torch.Tensor]] = [ - torch.ones_like(total_energy_local) - ] - if compute_virials and displacement is not None: - forces, virials = torch.autograd.grad( - outputs=[total_energy_local], - inputs=[positions, displacement], - grad_outputs=grad_outputs, - retain_graph=False, - create_graph=False, - allow_unused=True, - ) - if forces is not None: - forces = -1 * forces - else: - forces = torch.zeros_like(positions) - if virials is not None: - virials = -1 * 
virials - else: - virials = torch.zeros_like(displacement) - else: - forces = torch.autograd.grad( - outputs=[total_energy_local], - inputs=[positions], - grad_outputs=grad_outputs, - retain_graph=False, - create_graph=False, - allow_unused=True, - )[0] - if forces is not None: - forces = -1 * forces - else: - forces = torch.zeros_like(positions) - return { - "total_energy_local": total_energy_local, - "node_energy": node_energy, - "forces": forces, - "virials": virials, - } +from typing import Dict, List, Optional + +import torch +from e3nn.util.jit import compile_mode + +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class LAMMPS_MACE(torch.nn.Module): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + self.register_buffer("atomic_numbers", model.atomic_numbers) + self.register_buffer("r_max", model.r_max) + self.register_buffer("num_interactions", model.num_interactions) + if not hasattr(model, "heads"): + model.heads = [None] + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[-1])), + dtype=torch.long, + ).unsqueeze(0), + ) + + for param in self.model.parameters(): + param.requires_grad = False + + def forward( + self, + data: Dict[str, torch.Tensor], + local_or_ghost: torch.Tensor, + compute_virials: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + num_graphs = data["ptr"].numel() - 1 + compute_displacement = False + if compute_virials: + compute_displacement = True + data["head"] = self.head + out = self.model( + data, + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=compute_displacement, + ) + node_energy = out["node_energy"] + if node_energy is None: + return { + "total_energy_local": None, + "node_energy": None, + "forces": None, + "virials": None, + } + positions = data["positions"] + displacement = out["displacement"] + forces: Optional[torch.Tensor] = torch.zeros_like(positions) + virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) + # accumulate energies of local atoms + node_energy_local = node_energy * local_or_ghost + total_energy_local = scatter_sum( + src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs + ) + # compute partial forces and (possibly) partial virials + grad_outputs: List[Optional[torch.Tensor]] = [ + torch.ones_like(total_energy_local) + ] + if compute_virials and displacement is not None: + forces, virials = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions, displacement], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + ) + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + if virials is not None: + virials = -1 * virials + else: + virials = torch.zeros_like(displacement) + else: + forces = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + )[0] + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + return { + "total_energy_local": total_energy_local, + "node_energy": node_energy, + "forces": forces, + "virials": virials, + } diff --git a/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py b/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py index 036931c12e3c99deea45b01c27a6cb00c56a4c1d..f40e9b813bb8a08b70f0dc0a405d1e2942f87c7d 100644 --- 
a/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py @@ -1,214 +1,214 @@ -import logging -import os -import sys -import time -from contextlib import contextmanager -from typing import Dict, Tuple - -import torch -from ase.data import chemical_symbols -from e3nn.util.jit import compile_mode - -try: - from lammps.mliap.mliap_unified_abc import MLIAPUnified -except ImportError: - - class MLIAPUnified: - def __init__(self): - pass - - -class MACELammpsConfig: - """Configuration settings for MACE-LAMMPS integration.""" - - def __init__(self): - self.debug_time = self._get_env_bool("MACE_TIME", False) - self.debug_profile = self._get_env_bool("MACE_PROFILE", False) - self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5")) - self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10")) - self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False) - self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False) - - @staticmethod - def _get_env_bool(var_name: str, default: bool) -> bool: - return os.environ.get(var_name, str(default)).lower() in ( - "true", - "1", - "t", - "yes", - ) - - -@contextmanager -def timer(name: str, enabled: bool = True): - """Context manager for timing code blocks.""" - if not enabled: - yield - return - - start = time.perf_counter() - try: - yield - finally: - elapsed = time.perf_counter() - start - logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms") - - -@compile_mode("script") -class MACEEdgeForcesWrapper(torch.nn.Module): - """Wrapper that adds per-pair force computation to a MACE model.""" - - def __init__(self, model: torch.nn.Module, **kwargs): - super().__init__() - self.model = model - self.register_buffer("atomic_numbers", model.atomic_numbers) - self.register_buffer("r_max", model.r_max) - self.register_buffer("num_interactions", model.num_interactions) - - if not hasattr(model, "heads"): - model.heads = ["Default"] - - head_name = kwargs.get("head", model.heads[-1]) - head_idx = model.heads.index(head_name) - self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long)) - - for p in self.model.parameters(): - p.requires_grad = False - - def forward( - self, data: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute energies and per-pair forces.""" - data["head"] = self.head - - out = self.model( - data, - training=False, - compute_force=False, - compute_virials=False, - compute_stress=False, - compute_displacement=False, - compute_edge_forces=True, - lammps_mliap=True, - ) - - node_energy = out["node_energy"] - pair_forces = out["edge_forces"] - total_energy = out["energy"][0] - - if pair_forces is None: - pair_forces = torch.zeros_like(data["vectors"]) - - return total_energy, node_energy, pair_forces - - -class LAMMPS_MLIAP_MACE(MLIAPUnified): - """MACE integration for LAMMPS using the MLIAP interface.""" - - def __init__(self, model, **kwargs): - super().__init__() - self.config = MACELammpsConfig() - self.model = MACEEdgeForcesWrapper(model, **kwargs) - self.element_types = [chemical_symbols[s] for s in model.atomic_numbers] - self.num_species = len(self.element_types) - self.rcutfac = 0.5 * float(model.r_max) - self.ndescriptors = 1 - self.nparams = 1 - self.dtype = model.r_max.dtype - self.device = "cpu" - self.initialized = False - self.step = 0 - - def _initialize_device(self, data): - using_kokkos = "kokkos" in data.__class__.__module__.lower() - - if using_kokkos and not self.config.force_cpu: 
- device = torch.as_tensor(data.elems).device - if device.type == "cpu" and not self.config.allow_cpu: - raise ValueError( - "GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation." - ) - else: - device = torch.device("cpu") - - self.device = device - self.model = self.model.to(device) - logging.info(f"MACE model initialized on device: {device}") - self.initialized = True - - def compute_forces(self, data): - natoms = data.nlocal - ntotal = data.ntotal - nghosts = ntotal - natoms - npairs = data.npairs - species = torch.as_tensor(data.elems, dtype=torch.int64) - - if not self.initialized: - self._initialize_device(data) - - self.step += 1 - self._manage_profiling() - - if natoms == 0 or npairs <= 1: - return - - with timer("total_step", enabled=self.config.debug_time): - with timer("prepare_batch", enabled=self.config.debug_time): - batch = self._prepare_batch(data, natoms, nghosts, species) - - with timer("model_forward", enabled=self.config.debug_time): - _, atom_energies, pair_forces = self.model(batch) - - if self.device.type != "cpu": - torch.cuda.synchronize() - - with timer("update_lammps", enabled=self.config.debug_time): - self._update_lammps_data(data, atom_energies, pair_forces, natoms) - - def _prepare_batch(self, data, natoms, nghosts, species): - """Prepare the input batch for the MACE model.""" - return { - "vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device), - "node_attrs": torch.nn.functional.one_hot( - species.to(self.device), num_classes=self.num_species - ).to(self.dtype), - "edge_index": torch.stack( - [ - torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device), - torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device), - ], - dim=0, - ), - "batch": torch.zeros(natoms, dtype=torch.int64, device=self.device), - "lammps_class": data, - "natoms": (natoms, nghosts), - } - - def _update_lammps_data(self, data, atom_energies, pair_forces, natoms): - """Update LAMMPS data structures with computed energies and forces.""" - if self.dtype == torch.float32: - pair_forces = pair_forces.double() - eatoms = torch.as_tensor(data.eatoms) - eatoms.copy_(atom_energies[:natoms]) - data.energy = torch.sum(atom_energies[:natoms]) - data.update_pair_forces_gpu(pair_forces) - - def _manage_profiling(self): - if not self.config.debug_profile: - return - - if self.step == self.config.profile_start_step: - logging.info(f"Starting CUDA profiler at step {self.step}") - torch.cuda.profiler.start() - - if self.step == self.config.profile_end_step: - logging.info(f"Stopping CUDA profiler at step {self.step}") - torch.cuda.profiler.stop() - logging.info("Profiling complete. 
Exiting.") - sys.exit() - - def compute_descriptors(self, data): - pass - - def compute_gradients(self, data): - pass +import logging +import os +import sys +import time +from contextlib import contextmanager +from typing import Dict, Tuple + +import torch +from ase.data import chemical_symbols +from e3nn.util.jit import compile_mode + +try: + from lammps.mliap.mliap_unified_abc import MLIAPUnified +except ImportError: + + class MLIAPUnified: + def __init__(self): + pass + + +class MACELammpsConfig: + """Configuration settings for MACE-LAMMPS integration.""" + + def __init__(self): + self.debug_time = self._get_env_bool("MACE_TIME", False) + self.debug_profile = self._get_env_bool("MACE_PROFILE", False) + self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5")) + self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10")) + self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False) + self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False) + + @staticmethod + def _get_env_bool(var_name: str, default: bool) -> bool: + return os.environ.get(var_name, str(default)).lower() in ( + "true", + "1", + "t", + "yes", + ) + + +@contextmanager +def timer(name: str, enabled: bool = True): + """Context manager for timing code blocks.""" + if not enabled: + yield + return + + start = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - start + logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms") + + +@compile_mode("script") +class MACEEdgeForcesWrapper(torch.nn.Module): + """Wrapper that adds per-pair force computation to a MACE model.""" + + def __init__(self, model: torch.nn.Module, **kwargs): + super().__init__() + self.model = model + self.register_buffer("atomic_numbers", model.atomic_numbers) + self.register_buffer("r_max", model.r_max) + self.register_buffer("num_interactions", model.num_interactions) + + if not hasattr(model, "heads"): + model.heads = ["Default"] + + head_name = kwargs.get("head", model.heads[-1]) + head_idx = model.heads.index(head_name) + self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long)) + + for p in self.model.parameters(): + p.requires_grad = False + + def forward( + self, data: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute energies and per-pair forces.""" + data["head"] = self.head + + out = self.model( + data, + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=False, + compute_edge_forces=True, + lammps_mliap=True, + ) + + node_energy = out["node_energy"] + pair_forces = out["edge_forces"] + total_energy = out["energy"][0] + + if pair_forces is None: + pair_forces = torch.zeros_like(data["vectors"]) + + return total_energy, node_energy, pair_forces + + +class LAMMPS_MLIAP_MACE(MLIAPUnified): + """MACE integration for LAMMPS using the MLIAP interface.""" + + def __init__(self, model, **kwargs): + super().__init__() + self.config = MACELammpsConfig() + self.model = MACEEdgeForcesWrapper(model, **kwargs) + self.element_types = [chemical_symbols[s] for s in model.atomic_numbers] + self.num_species = len(self.element_types) + self.rcutfac = 0.5 * float(model.r_max) + self.ndescriptors = 1 + self.nparams = 1 + self.dtype = model.r_max.dtype + self.device = "cpu" + self.initialized = False + self.step = 0 + + def _initialize_device(self, data): + using_kokkos = "kokkos" in data.__class__.__module__.lower() + + if using_kokkos and not self.config.force_cpu: + device = 
torch.as_tensor(data.elems).device + if device.type == "cpu" and not self.config.allow_cpu: + raise ValueError( + "GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation." + ) + else: + device = torch.device("cpu") + + self.device = device + self.model = self.model.to(device) + logging.info(f"MACE model initialized on device: {device}") + self.initialized = True + + def compute_forces(self, data): + natoms = data.nlocal + ntotal = data.ntotal + nghosts = ntotal - natoms + npairs = data.npairs + species = torch.as_tensor(data.elems, dtype=torch.int64) + + if not self.initialized: + self._initialize_device(data) + + self.step += 1 + self._manage_profiling() + + if natoms == 0 or npairs <= 1: + return + + with timer("total_step", enabled=self.config.debug_time): + with timer("prepare_batch", enabled=self.config.debug_time): + batch = self._prepare_batch(data, natoms, nghosts, species) + + with timer("model_forward", enabled=self.config.debug_time): + _, atom_energies, pair_forces = self.model(batch) + + if self.device.type != "cpu": + torch.cuda.synchronize() + + with timer("update_lammps", enabled=self.config.debug_time): + self._update_lammps_data(data, atom_energies, pair_forces, natoms) + + def _prepare_batch(self, data, natoms, nghosts, species): + """Prepare the input batch for the MACE model.""" + return { + "vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device), + "node_attrs": torch.nn.functional.one_hot( + species.to(self.device), num_classes=self.num_species + ).to(self.dtype), + "edge_index": torch.stack( + [ + torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device), + torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device), + ], + dim=0, + ), + "batch": torch.zeros(natoms, dtype=torch.int64, device=self.device), + "lammps_class": data, + "natoms": (natoms, nghosts), + } + + def _update_lammps_data(self, data, atom_energies, pair_forces, natoms): + """Update LAMMPS data structures with computed energies and forces.""" + if self.dtype == torch.float32: + pair_forces = pair_forces.double() + eatoms = torch.as_tensor(data.eatoms) + eatoms.copy_(atom_energies[:natoms]) + data.energy = torch.sum(atom_energies[:natoms]) + data.update_pair_forces_gpu(pair_forces) + + def _manage_profiling(self): + if not self.config.debug_profile: + return + + if self.step == self.config.profile_start_step: + logging.info(f"Starting CUDA profiler at step {self.step}") + torch.cuda.profiler.start() + + if self.step == self.config.profile_end_step: + logging.info(f"Stopping CUDA profiler at step {self.step}") + torch.cuda.profiler.stop() + logging.info("Profiling complete. 
Exiting.") + sys.exit() + + def compute_descriptors(self, data): + pass + + def compute_gradients(self, data): + pass diff --git a/mace-bench/3rdparty/mace/mace/calculators/mace.py b/mace-bench/3rdparty/mace/mace/calculators/mace.py index 31dbeb1263cb5223d410d37130567b82e9c0c75e..794065b72adc6725d591821861c3c022247ceb5e 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/mace.py @@ -1,705 +1,705 @@ -########################################################################################### -# The ASE Calculator for MACE -# Authors: Ilyes Batatia, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging - -# pylint: disable=wrong-import-position -import os -from glob import glob -from pathlib import Path -from typing import List, Union - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import numpy as np -import torch -from ase.calculators.calculator import Calculator, all_changes -from ase.stress import full_3x3_to_voigt_6_stress -from e3nn import o3 - -from mace import data -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.modules.utils import extract_invariant -from mace.tools import torch_geometric, torch_tools, utils -from mace.tools.compile import prepare -from mace.tools.scripts_utils import extract_model -import random -from mace.tools.torch_geometric.batch import Batch - -from mace.tools import ( - atomic_numbers_to_indices, - to_one_hot, -) - -import time - - -def get_model_dtype(model: torch.nn.Module) -> torch.dtype: - """Get the dtype of the model""" - mode_dtype = next(model.parameters()).dtype - if mode_dtype == torch.float64: - return "float64" - if mode_dtype == torch.float32: - return "float32" - raise ValueError(f"Unknown dtype {mode_dtype}") - - -class MACECalculator(Calculator): - """MACE ASE Calculator - args: - model_paths: str, path to model or models if a committee is produced - to make a committee use a wild card notation like mace_*.model - device: str, device to run on (cuda or cpu) - energy_units_to_eV: float, conversion factor from model energy units to eV - length_units_to_A: float, conversion factor from model length units to Angstroms - default_dtype: str, default dtype of model - charges_key: str, Array field of atoms object where atomic charges are stored - model_type: str, type of model to load - Options: [MACE, DipoleMACE, EnergyDipoleMACE] - - Dipoles are returned in units of Debye - """ - - def __init__( - self, - model_paths: Union[list, str, None] = None, - models: Union[List[torch.nn.Module], torch.nn.Module, None] = None, - device: str = "cpu", - energy_units_to_eV: float = 1.0, - length_units_to_A: float = 1.0, - default_dtype="", - charges_key="Qs", - model_type="MACE", - compile_mode=None, - fullgraph=True, - enable_cueq=False, - **kwargs, - ): - Calculator.__init__(self, **kwargs) - self.device = device - self.dtype=None - if enable_cueq: - assert model_type == "MACE", "CuEq only supports MACE models" - compile_mode = None - if "model_path" in kwargs: - deprecation_message = ( - "'model_path' argument is deprecated, please use 'model_paths'" - ) - if model_paths is None: - logging.warning(f"{deprecation_message} in the future.") - model_paths = kwargs["model_path"] - else: - raise ValueError( - f"both 'model_path' and 'model_paths' given, {deprecation_message} only." 
- ) - - if (model_paths is None) == (models is None): - raise ValueError( - "Exactly one of 'model_paths' or 'models' must be provided" - ) - - self.results = {} - - self.model_type = model_type - self.compute_atomic_stresses = False - - if model_type == "MACE": - self.implemented_properties = [ - "energy", - "free_energy", - "node_energy", - "forces", - "stress", - ] - if kwargs.get("compute_atomic_stresses", False): - self.implemented_properties.extend(["stresses", "virials"]) - self.compute_atomic_stresses = True - elif model_type == "DipoleMACE": - self.implemented_properties = ["dipole"] - elif model_type == "EnergyDipoleMACE": - self.implemented_properties = [ - "energy", - "free_energy", - "node_energy", - "forces", - "stress", - "dipole", - ] - else: - raise ValueError( - f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" - ) - - if model_paths is not None: - if isinstance(model_paths, str): - # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) - model_paths_glob = glob(model_paths) - - if len(model_paths_glob) == 0: - raise ValueError(f"Couldn't find MACE model files: {model_paths}") - - model_paths = model_paths_glob - elif isinstance(model_paths, Path): - model_paths = [model_paths] - - if len(model_paths) == 0: - raise ValueError("No mace file names supplied") - self.num_models = len(model_paths) - - # Load models from files - self.models = [ - torch.load(f=model_path, map_location=device) - for model_path in model_paths - ] - - elif models is not None: - if not isinstance(models, list): - models = [models] - - if len(models) == 0: - raise ValueError("No models supplied") - - self.models = models - self.num_models = len(models) - - if self.num_models > 1: - print(f"Running committee mace with {self.num_models} models") - - if model_type in ["MACE", "EnergyDipoleMACE"]: - self.implemented_properties.extend( - ["energies", "energy_var", "forces_comm", "stress_var"] - ) - elif model_type == "DipoleMACE": - self.implemented_properties.extend(["dipole_var"]) - - if compile_mode is not None: - print(f"Torch compile is enabled with mode: {compile_mode}") - self.models = [ - torch.compile( - prepare(extract_model)(model=model, map_location=device), - mode=compile_mode, - fullgraph=fullgraph, - ) - for model in self.models - ] - self.use_compile = True - else: - self.use_compile = False - - # Ensure all models are on the same device - for model in self.models: - model.to(device) - - r_maxs = [model.r_max.cpu() for model in self.models] - r_maxs = np.array(r_maxs) - if not np.all(r_maxs == r_maxs[0]): - raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") - self.r_max = float(r_maxs[0]) - - self.device = torch_tools.init_device(device) - self.energy_units_to_eV = energy_units_to_eV - self.length_units_to_A = length_units_to_A - self.z_table = utils.AtomicNumberTable( - [int(z) for z in self.models[0].atomic_numbers] - ) - self.charges_key = charges_key - - try: - self.available_heads: List[str] = self.models[0].heads # type: ignore - except AttributeError: - self.available_heads = ["Default"] - kwarg_head = kwargs.get("head", None) - if kwarg_head is not None: - self.head = kwarg_head - else: - self.head = [head for head in self.available_heads if head.lower() == "default"] - if len(self.head) == 0: - raise ValueError( - "Head keyword was not provided, and no head in the model is 'default'. " - "Please provide a head keyword to specify the head you want to use. 
" - f"Available heads are: {self.available_heads}" - ) - self.head = self.head[0] - - print("Using head", self.head, "out of", self.available_heads) - - model_dtype = get_model_dtype(self.models[0]) - if default_dtype == "": - print( - f"No dtype selected, switching to {model_dtype} to match model dtype." - ) - default_dtype = model_dtype - if model_dtype != default_dtype: - print( - f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." - ) - if default_dtype == "float64": - self.models = [model.double() for model in self.models] - elif default_dtype == "float32": - self.models = [model.float() for model in self.models] - torch_tools.set_default_dtype(default_dtype) - if enable_cueq: - print("Converting models to CuEq for acceleration") - self.models = [ - run_e3nn_to_cueq(model, device=device).to(device) - for model in self.models - ] - for model in self.models: - for param in model.parameters(): - param.requires_grad = False - - self.dtype = torch.float64 if default_dtype == "float64" else torch.float32 - - self.model_time = 0.0 - self.calc_time = 0.0 - - def _create_result_tensors( - self, model_type: str, num_models: int, num_atoms: int - ) -> dict: - """ - Create tensors to store the results of the committee - :param model_type: str, type of model to load - Options: [MACE, DipoleMACE, EnergyDipoleMACE] - :param num_models: int, number of models in the committee - :return: tuple of torch tensors - """ - dict_of_tensors = {} - if model_type in ["MACE", "EnergyDipoleMACE"]: - energies = torch.zeros(num_models, device=self.device) - node_energy = torch.zeros(num_models, num_atoms, device=self.device) - forces = torch.zeros(num_models, num_atoms, 3, device=self.device) - stress = torch.zeros(num_models, 3, 3, device=self.device) - dict_of_tensors.update( - { - "energies": energies, - "node_energy": node_energy, - "forces": forces, - "stress": stress, - } - ) - if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: - dipole = torch.zeros(num_models, 3, device=self.device) - dict_of_tensors.update({"dipole": dipole}) - return dict_of_tensors - - def _atoms_to_batch(self, atoms): - keyspec = data.KeySpecification( - info_keys={}, arrays_keys={"charges": self.charges_key} - ) - config = data.config_from_atoms( - atoms, key_specification=keyspec, head_name=self.head - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.available_heads, - ) - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)).to(self.device) - return batch - - def _clone_batch(self, batch): - batch_clone = batch.clone() - if self.use_compile: - batch_clone["node_attrs"].requires_grad_(True) - batch_clone["positions"].requires_grad_(True) - return batch_clone - - # pylint: disable=dangerous-default-value - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - """ - Calculate properties. 
- :param atoms: ase.Atoms object - :param properties: [str], properties to be computed, used by ASE internally - :param system_changes: [str], system changes since last calculation, used by ASE internally - :return: - """ - # call to base-class to set atoms attribute - calc_start_t = time.perf_counter() - Calculator.calculate(self, atoms) - - batch_base = self._atoms_to_batch(atoms) - - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - batch = self._clone_batch(batch_base) - node_heads = batch["head"][batch["batch"]] - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - num_atoms_arange, node_heads - ] - compute_stress = not self.use_compile - else: - compute_stress = False - - ret_tensors = self._create_result_tensors( - self.model_type, self.num_models, len(atoms) - ) - for i, model in enumerate(self.models): - batch = self._clone_batch(batch_base) - # print(f'@@@File: {__file__}, batch.to_dict(): {batch.to_dict()}') - # set_seed(0) - model_start_t = time.perf_counter() - out = model( - batch.to_dict(), - compute_stress=compute_stress, - training=self.use_compile, - compute_edge_forces=self.compute_atomic_stresses, - compute_atomic_stresses=self.compute_atomic_stresses, - ) - model_end_t = time.perf_counter() - self.model_time += (model_end_t - model_start_t) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # print(f'compute_stress: {compute_stress}') - # for k,v in batch.to_dict().items(): - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - # print(f'@@@File: {__file__}, out: {out}') - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - ret_tensors["energies"][i] = out["energy"].detach() - ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach() - ret_tensors["forces"][i] = out["forces"].detach() - if out["stress"] is not None: - ret_tensors["stress"][i] = out["stress"].detach() - if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: - ret_tensors["dipole"][i] = out["dipole"].detach() - if self.model_type in ["MACE"]: - if out["atomic_stresses"] is not None: - ret_tensors.setdefault("atomic_stresses", []).append( - out["atomic_stresses"].detach() - ) - if out["atomic_virials"] is not None: - ret_tensors.setdefault("atomic_virials", []).append( - out["atomic_virials"].detach() - ) - - self.results = {} - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - self.results["energy"] = ( - torch.mean(ret_tensors["energies"], dim=0).cpu().item() - * self.energy_units_to_eV - ) - self.results["free_energy"] = self.results["energy"] - self.results["node_energy"] = ( - torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() - ) - self.results["forces"] = ( - torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() - * self.energy_units_to_eV - / self.length_units_to_A - ) - if self.num_models > 1: - self.results["energies"] = ( - ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV - ) - self.results["energy_var"] = ( - torch.var(ret_tensors["energies"], dim=0, unbiased=False) - .cpu() - .item() - * self.energy_units_to_eV - ) - self.results["forces_comm"] = ( - ret_tensors["forces"].cpu().numpy() - * self.energy_units_to_eV - / self.length_units_to_A - ) - if out["stress"] is not None: - self.results["stress"] = full_3x3_to_voigt_6_stress( - torch.mean(ret_tensors["stress"], dim=0).cpu().numpy() - * 
self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if self.num_models > 1: - self.results["stress_var"] = full_3x3_to_voigt_6_stress( - torch.var(ret_tensors["stress"], dim=0, unbiased=False) - .cpu() - .numpy() - * self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if "atomic_stresses" in ret_tensors: - self.results["stresses"] = ( - torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0) - .cpu() - .numpy() - * self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if "atomic_virials" in ret_tensors: - self.results["virials"] = ( - torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0) - .cpu() - .numpy() - * self.energy_units_to_eV - ) - if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: - self.results["dipole"] = ( - torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy() - ) - if self.num_models > 1: - self.results["dipole_var"] = ( - torch.var(ret_tensors["dipole"], dim=0, unbiased=False) - .cpu() - .numpy() - ) - - calc_end_t = time.perf_counter() - self.calc_time += (calc_end_t - calc_start_t) - - def get_hessian(self, atoms=None): - if atoms is None and self.atoms is None: - raise ValueError("atoms not set") - if atoms is None: - atoms = self.atoms - if self.model_type != "MACE": - raise NotImplementedError("Only implemented for MACE models") - batch = self._atoms_to_batch(atoms) - hessians = [ - model( - self._clone_batch(batch).to_dict(), - compute_hessian=True, - compute_stress=False, - training=self.use_compile, - )["hessian"] - for model in self.models - ] - hessians = [hessian.detach().cpu().numpy() for hessian in hessians] - if self.num_models == 1: - return hessians[0] - return hessians - - def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): - """Extracts the descriptors from MACE model. 
- :param atoms: ase.Atoms object - :param invariants_only: bool, if True only the invariant descriptors are returned - :param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used - :return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise - """ - if atoms is None and self.atoms is None: - raise ValueError("atoms not set") - if atoms is None: - atoms = self.atoms - if self.model_type != "MACE": - raise NotImplementedError("Only implemented for MACE models") - num_interactions = int(self.models[0].num_interactions) - if num_layers == -1: - num_layers = num_interactions - batch = self._atoms_to_batch(atoms) - descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] - - irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out)) - l_max = irreps_out.lmax - num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 - per_layer_features = [irreps_out.dim for _ in range(num_interactions)] - per_layer_features[-1] = ( - num_invariant_features # Equivariant features not created for the last layer - ) - - if invariants_only: - descriptors = [ - extract_invariant( - descriptor, - num_layers=num_layers, - num_features=num_invariant_features, - l_max=l_max, - ) - for descriptor in descriptors - ] - to_keep = np.sum(per_layer_features[:num_layers]) - descriptors = [ - descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors - ] - - if self.num_models == 1: - return descriptors[0] - return descriptors - - - def predict(self, atoms_list, compute_stress=False): - predictions = {'energy': [], 'forces': []} - - configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads - ) - for config in configs - ], - batch_size=len(atoms_list), - shuffle=False, - drop_last=False, - ) - - # get the first batch of data_loader - batch_base = next(iter(data_loader)).to(self.device) - - # calculate node_e0 - # batch = self._clone_batch(batch_base) - # node_heads = batch["head"][batch["batch"]] - # num_atoms_arange = torch.arange(batch["positions"].shape[0]) - # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - # num_atoms_arange, node_heads - # ] - - # set_seed(0) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? 
- training=self.use_compile, - ) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - predictions["energy"] = out["energy"].unsqueeze(-1).detach() - predictions["forces"] = out["forces"].detach() - if compute_stress: - predictions["stress"] = out["stress"].detach() - - # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') - - return predictions - - def fast_predict(self, gbatch, compute_stress=False): - gbatch.pos = gbatch.pos.to(self.dtype) - gbatch.cell = gbatch.cell.to(self.dtype) - - predictions = {'energy': [], 'forces': []} - batch_base = self.convert_batch(gbatch) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, - training=self.use_compile, - ) - predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64) - predictions["forces"] = out["forces"].detach().to(torch.float64) - if compute_stress: - predictions["stress"] = out["stress"].detach().to(torch.float64) - - return predictions - - - def convert_batch(self, gbatch): - from batchopt import radius_graph_pbc_cuda - # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_mem_effi( - # from batchopt.pbc_graph_legacy import radius_graph_pbc - # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc( - edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda( - gbatch, - radius=4.5, - max_num_neighbors_threshold=float('inf'), - pbc=[True, True, True], - dtype=self.dtype - ) - - tmp = edge_indices[0].clone() - edge_indices[0] = edge_indices[1] - edge_indices[1] = tmp - - # Create a one-hot matrix with number of columns equal to max atomic number + 1 - indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table) - one_hot = to_one_hot( - torch.tensor(indices, dtype=torch.long).unsqueeze(-1), - num_classes=len(self.z_table), - ).to(self.device) - - cbatch = Batch( - positions = gbatch["pos"].clone(), - cell = gbatch["cell"].view(-1, 3), - batch = gbatch["batch"], - ptr = gbatch["ptr"], - edge_index = edge_indices, - unit_shifts = cell_offsets, - node_attrs = one_hot, - ) - - return cbatch - - - def predict_debug(self, atoms_list, gbatch, compute_stress=False): - predictions = {'energy': [], 'forces': []} - - configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads - ) - for config in configs - ], - batch_size=len(atoms_list), - shuffle=False, - drop_last=False, - ) - - # get the first batch of data_loader - # batch_base = next(iter(data_loader)).to(self.device) - batch_base_tmp = next(iter(data_loader)).to(self.device) - batch2 = self.convert_batch(gbatch) - batch_base = Batch( - # positions = batch_base_tmp["positions"], - positions = batch2["positions"], - # node_attrs = batch_base_tmp["node_attrs"], - node_attrs = batch2["node_attrs"], - # cell = batch_base_tmp["cell"], - cell = batch2["cell"], - edge_index = batch2["edge_index"], - unit_shifts = batch2["unit_shifts"], - # batch = batch_base_tmp["batch"], - batch = batch2["batch"], - # ptr = batch_base_tmp["ptr"], - ptr = batch2["ptr"], - ) - - torch.set_printoptions(threshold=float('inf')) - - # print(f'batch2["edge_index"]: {batch2["edge_index"]}') - 
# print(f'batch2["unit_shifts"]: {batch2["unit_shifts"]}') - # print(f'batch_base_tmp["edge_index"]: {batch_base_tmp["edge_index"]}') - # print(f'batch_base_tmp["unit_shifts"]: {batch_base_tmp["unit_shifts"]}') - - # calculate node_e0 - # batch = self._clone_batch(batch_base) - # node_heads = batch["head"][batch["batch"]] - # num_atoms_arange = torch.arange(batch["positions"].shape[0]) - # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - # num_atoms_arange, node_heads - # ] - - # set_seed(0) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? - training=self.use_compile, - ) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.cell: {batch["cell"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # for k,v in batch.to_dict().items(): - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - predictions["energy"] = out["energy"].unsqueeze(-1).detach() - predictions["forces"] = out["forces"].detach() - if compute_stress: - predictions["stress"] = out["stress"].detach() - - # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') - +########################################################################################### +# The ASE Calculator for MACE +# Authors: Ilyes Batatia, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging + +# pylint: disable=wrong-import-position +import os +from glob import glob +from pathlib import Path +from typing import List, Union + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import numpy as np +import torch +from ase.calculators.calculator import Calculator, all_changes +from ase.stress import full_3x3_to_voigt_6_stress +from e3nn import o3 + +from mace import data +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.modules.utils import extract_invariant +from mace.tools import torch_geometric, torch_tools, utils +from mace.tools.compile import prepare +from mace.tools.scripts_utils import extract_model +import random +from mace.tools.torch_geometric.batch import Batch + +from mace.tools import ( + atomic_numbers_to_indices, + to_one_hot, +) + +import time + + +def get_model_dtype(model: torch.nn.Module) -> torch.dtype: + """Get the dtype of the model""" + mode_dtype = next(model.parameters()).dtype + if mode_dtype == torch.float64: + return "float64" + if mode_dtype == torch.float32: + return "float32" + raise ValueError(f"Unknown dtype {mode_dtype}") + + +class MACECalculator(Calculator): + """MACE ASE Calculator + args: + model_paths: str, path to model or models if a committee is produced + to make a committee use a wild card notation like mace_*.model + device: str, device to run on (cuda or cpu) + energy_units_to_eV: float, conversion factor from model energy units to eV + length_units_to_A: float, conversion factor from model length units to Angstroms + default_dtype: str, default dtype of model + charges_key: str, Array field of atoms object where atomic charges are stored + model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + + Dipoles are returned in units of Debye + """ + + def __init__( + self, + model_paths: Union[list, str, None] = None, + models: Union[List[torch.nn.Module], torch.nn.Module, None] = None, 
+ device: str = "cpu", + energy_units_to_eV: float = 1.0, + length_units_to_A: float = 1.0, + default_dtype="", + charges_key="Qs", + model_type="MACE", + compile_mode=None, + fullgraph=True, + enable_cueq=False, + **kwargs, + ): + Calculator.__init__(self, **kwargs) + self.device = device + self.dtype=None + if enable_cueq: + assert model_type == "MACE", "CuEq only supports MACE models" + compile_mode = None + if "model_path" in kwargs: + deprecation_message = ( + "'model_path' argument is deprecated, please use 'model_paths'" + ) + if model_paths is None: + logging.warning(f"{deprecation_message} in the future.") + model_paths = kwargs["model_path"] + else: + raise ValueError( + f"both 'model_path' and 'model_paths' given, {deprecation_message} only." + ) + + if (model_paths is None) == (models is None): + raise ValueError( + "Exactly one of 'model_paths' or 'models' must be provided" + ) + + self.results = {} + + self.model_type = model_type + self.compute_atomic_stresses = False + + if model_type == "MACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + ] + if kwargs.get("compute_atomic_stresses", False): + self.implemented_properties.extend(["stresses", "virials"]) + self.compute_atomic_stresses = True + elif model_type == "DipoleMACE": + self.implemented_properties = ["dipole"] + elif model_type == "EnergyDipoleMACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + "dipole", + ] + else: + raise ValueError( + f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" + ) + + if model_paths is not None: + if isinstance(model_paths, str): + # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) + model_paths_glob = glob(model_paths) + + if len(model_paths_glob) == 0: + raise ValueError(f"Couldn't find MACE model files: {model_paths}") + + model_paths = model_paths_glob + elif isinstance(model_paths, Path): + model_paths = [model_paths] + + if len(model_paths) == 0: + raise ValueError("No mace file names supplied") + self.num_models = len(model_paths) + + # Load models from files + self.models = [ + torch.load(f=model_path, map_location=device) + for model_path in model_paths + ] + + elif models is not None: + if not isinstance(models, list): + models = [models] + + if len(models) == 0: + raise ValueError("No models supplied") + + self.models = models + self.num_models = len(models) + + if self.num_models > 1: + print(f"Running committee mace with {self.num_models} models") + + if model_type in ["MACE", "EnergyDipoleMACE"]: + self.implemented_properties.extend( + ["energies", "energy_var", "forces_comm", "stress_var"] + ) + elif model_type == "DipoleMACE": + self.implemented_properties.extend(["dipole_var"]) + + if compile_mode is not None: + print(f"Torch compile is enabled with mode: {compile_mode}") + self.models = [ + torch.compile( + prepare(extract_model)(model=model, map_location=device), + mode=compile_mode, + fullgraph=fullgraph, + ) + for model in self.models + ] + self.use_compile = True + else: + self.use_compile = False + + # Ensure all models are on the same device + for model in self.models: + model.to(device) + + r_maxs = [model.r_max.cpu() for model in self.models] + r_maxs = np.array(r_maxs) + if not np.all(r_maxs == r_maxs[0]): + raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") + self.r_max = float(r_maxs[0]) + + self.device = torch_tools.init_device(device) + 
self.energy_units_to_eV = energy_units_to_eV + self.length_units_to_A = length_units_to_A + self.z_table = utils.AtomicNumberTable( + [int(z) for z in self.models[0].atomic_numbers] + ) + self.charges_key = charges_key + + try: + self.available_heads: List[str] = self.models[0].heads # type: ignore + except AttributeError: + self.available_heads = ["Default"] + kwarg_head = kwargs.get("head", None) + if kwarg_head is not None: + self.head = kwarg_head + else: + self.head = [head for head in self.available_heads if head.lower() == "default"] + if len(self.head) == 0: + raise ValueError( + "Head keyword was not provided, and no head in the model is 'default'. " + "Please provide a head keyword to specify the head you want to use. " + f"Available heads are: {self.available_heads}" + ) + self.head = self.head[0] + + print("Using head", self.head, "out of", self.available_heads) + + model_dtype = get_model_dtype(self.models[0]) + if default_dtype == "": + print( + f"No dtype selected, switching to {model_dtype} to match model dtype." + ) + default_dtype = model_dtype + if model_dtype != default_dtype: + print( + f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." + ) + if default_dtype == "float64": + self.models = [model.double() for model in self.models] + elif default_dtype == "float32": + self.models = [model.float() for model in self.models] + torch_tools.set_default_dtype(default_dtype) + if enable_cueq: + print("Converting models to CuEq for acceleration") + self.models = [ + run_e3nn_to_cueq(model, device=device).to(device) + for model in self.models + ] + for model in self.models: + for param in model.parameters(): + param.requires_grad = False + + self.dtype = torch.float64 if default_dtype == "float64" else torch.float32 + + self.model_time = 0.0 + self.calc_time = 0.0 + + def _create_result_tensors( + self, model_type: str, num_models: int, num_atoms: int + ) -> dict: + """ + Create tensors to store the results of the committee + :param model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + :param num_models: int, number of models in the committee + :return: tuple of torch tensors + """ + dict_of_tensors = {} + if model_type in ["MACE", "EnergyDipoleMACE"]: + energies = torch.zeros(num_models, device=self.device) + node_energy = torch.zeros(num_models, num_atoms, device=self.device) + forces = torch.zeros(num_models, num_atoms, 3, device=self.device) + stress = torch.zeros(num_models, 3, 3, device=self.device) + dict_of_tensors.update( + { + "energies": energies, + "node_energy": node_energy, + "forces": forces, + "stress": stress, + } + ) + if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: + dipole = torch.zeros(num_models, 3, device=self.device) + dict_of_tensors.update({"dipole": dipole}) + return dict_of_tensors + + def _atoms_to_batch(self, atoms): + keyspec = data.KeySpecification( + info_keys={}, arrays_keys={"charges": self.charges_key} + ) + config = data.config_from_atoms( + atoms, key_specification=keyspec, head_name=self.head + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.available_heads, + ) + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(self.device) + return batch + + def _clone_batch(self, batch): + batch_clone = batch.clone() + if self.use_compile: + batch_clone["node_attrs"].requires_grad_(True) + 
batch_clone["positions"].requires_grad_(True) + return batch_clone + + # pylint: disable=dangerous-default-value + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + """ + Calculate properties. + :param atoms: ase.Atoms object + :param properties: [str], properties to be computed, used by ASE internally + :param system_changes: [str], system changes since last calculation, used by ASE internally + :return: + """ + # call to base-class to set atoms attribute + calc_start_t = time.perf_counter() + Calculator.calculate(self, atoms) + + batch_base = self._atoms_to_batch(atoms) + + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + batch = self._clone_batch(batch_base) + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_heads + ] + compute_stress = not self.use_compile + else: + compute_stress = False + + ret_tensors = self._create_result_tensors( + self.model_type, self.num_models, len(atoms) + ) + for i, model in enumerate(self.models): + batch = self._clone_batch(batch_base) + # print(f'@@@File: {__file__}, batch.to_dict(): {batch.to_dict()}') + # set_seed(0) + model_start_t = time.perf_counter() + out = model( + batch.to_dict(), + compute_stress=compute_stress, + training=self.use_compile, + compute_edge_forces=self.compute_atomic_stresses, + compute_atomic_stresses=self.compute_atomic_stresses, + ) + model_end_t = time.perf_counter() + self.model_time += (model_end_t - model_start_t) + # print(f'&&& batch.positions: {batch["positions"]}') + # print(f'&&& batch.stress: {batch["stress"]}') + # print(f'compute_stress: {compute_stress}') + # for k,v in batch.to_dict().items(): + # print(f'&&& batch.to_dict(): {k} {v}') + # print("=======") + # print(f'&&& out["forces"]: {out["forces"]}') + # print(f'&&& training: {self.use_compile}') + # print(f'@@@File: {__file__}, out: {out}') + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + ret_tensors["energies"][i] = out["energy"].detach() + ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach() + ret_tensors["forces"][i] = out["forces"].detach() + if out["stress"] is not None: + ret_tensors["stress"][i] = out["stress"].detach() + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + ret_tensors["dipole"][i] = out["dipole"].detach() + if self.model_type in ["MACE"]: + if out["atomic_stresses"] is not None: + ret_tensors.setdefault("atomic_stresses", []).append( + out["atomic_stresses"].detach() + ) + if out["atomic_virials"] is not None: + ret_tensors.setdefault("atomic_virials", []).append( + out["atomic_virials"].detach() + ) + + self.results = {} + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + self.results["energy"] = ( + torch.mean(ret_tensors["energies"], dim=0).cpu().item() + * self.energy_units_to_eV + ) + self.results["free_energy"] = self.results["energy"] + self.results["node_energy"] = ( + torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() + ) + self.results["forces"] = ( + torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if self.num_models > 1: + self.results["energies"] = ( + ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV + ) + self.results["energy_var"] = ( + torch.var(ret_tensors["energies"], dim=0, unbiased=False) + .cpu() + .item() + * self.energy_units_to_eV + ) + self.results["forces_comm"] = ( + 
ret_tensors["forces"].cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if out["stress"] is not None: + self.results["stress"] = full_3x3_to_voigt_6_stress( + torch.mean(ret_tensors["stress"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if self.num_models > 1: + self.results["stress_var"] = full_3x3_to_voigt_6_stress( + torch.var(ret_tensors["stress"], dim=0, unbiased=False) + .cpu() + .numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if "atomic_stresses" in ret_tensors: + self.results["stresses"] = ( + torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0) + .cpu() + .numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if "atomic_virials" in ret_tensors: + self.results["virials"] = ( + torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0) + .cpu() + .numpy() + * self.energy_units_to_eV + ) + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + self.results["dipole"] = ( + torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy() + ) + if self.num_models > 1: + self.results["dipole_var"] = ( + torch.var(ret_tensors["dipole"], dim=0, unbiased=False) + .cpu() + .numpy() + ) + + calc_end_t = time.perf_counter() + self.calc_time += (calc_end_t - calc_start_t) + + def get_hessian(self, atoms=None): + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + if self.model_type != "MACE": + raise NotImplementedError("Only implemented for MACE models") + batch = self._atoms_to_batch(atoms) + hessians = [ + model( + self._clone_batch(batch).to_dict(), + compute_hessian=True, + compute_stress=False, + training=self.use_compile, + )["hessian"] + for model in self.models + ] + hessians = [hessian.detach().cpu().numpy() for hessian in hessians] + if self.num_models == 1: + return hessians[0] + return hessians + + def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): + """Extracts the descriptors from MACE model. 
+        :param atoms: ase.Atoms object
+        :param invariants_only: bool, if True only the invariant descriptors are returned
+        :param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used
+        :return: np.ndarray (num_atoms, num_interactions, invariant_features) of
+            invariant descriptors if num_models is 1, or list[np.ndarray] otherwise
+        """
+        if atoms is None and self.atoms is None:
+            raise ValueError("atoms not set")
+        if atoms is None:
+            atoms = self.atoms
+        if self.model_type != "MACE":
+            raise NotImplementedError("Only implemented for MACE models")
+        num_interactions = int(self.models[0].num_interactions)
+        if num_layers == -1:
+            num_layers = num_interactions
+        batch = self._atoms_to_batch(atoms)
+        descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
+
+        irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
+        l_max = irreps_out.lmax
+        num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
+        per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
+        per_layer_features[-1] = (
+            num_invariant_features  # Equivariant features not created for the last layer
+        )
+
+        if invariants_only:
+            descriptors = [
+                extract_invariant(
+                    descriptor,
+                    num_layers=num_layers,
+                    num_features=num_invariant_features,
+                    l_max=l_max,
+                )
+                for descriptor in descriptors
+            ]
+        to_keep = np.sum(per_layer_features[:num_layers])
+        descriptors = [
+            descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
+        ]
+
+        if self.num_models == 1:
+            return descriptors[0]
+        return descriptors
+
+    def predict(self, atoms_list, compute_stress=False):
+        """Batched single-point prediction: builds one torch_geometric batch from a
+        list of ase.Atoms objects and runs a single forward pass of the first model."""
+        predictions = {'energy': [], 'forces': []}
+
+        configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
+        data_loader = torch_geometric.dataloader.DataLoader(
+            dataset=[
+                data.AtomicData.from_config(
+                    config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads
+                )
+                for config in configs
+            ],
+            batch_size=len(atoms_list),
+            shuffle=False,
+            drop_last=False,
+        )
+
+        # batch_size equals len(atoms_list), so the first batch holds all structures
+        batch_base = next(iter(data_loader)).to(self.device)
+
+        out = self.models[0](
+            batch_base.to_dict(),
+            compute_stress=compute_stress,
+            training=self.use_compile,
+        )
+        predictions["energy"] = out["energy"].unsqueeze(-1).detach()
+        predictions["forces"] = out["forces"].detach()
+        if compute_stress:
+            predictions["stress"] = out["stress"].detach()
+
+        return predictions
+
+    def fast_predict(self, gbatch, compute_stress=False):
+        """Like predict(), but skips the DataLoader entirely: the MACE batch is
+        rebuilt directly from an existing torch_geometric batch on the device."""
+        gbatch.pos = gbatch.pos.to(self.dtype)
+        gbatch.cell = gbatch.cell.to(self.dtype)
+
+        predictions = {'energy': [], 'forces': []}
+        batch_base = self.convert_batch(gbatch)
+        out = self.models[0](
+            batch_base.to_dict(),
+            compute_stress=compute_stress,
+            training=self.use_compile,
+        )
+        predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64)
+        predictions["forces"] = out["forces"].detach().to(torch.float64)
+        if compute_stress:
+            predictions["stress"] = out["stress"].detach().to(torch.float64)
+
+        return predictions
+
+    def convert_batch(self, gbatch):
+        """Convert a torch_geometric batch into a MACE Batch, rebuilding the periodic
+        neighbour list with the CUDA kernel from batchopt."""
+        from batchopt import radius_graph_pbc_cuda
+
+        # NOTE: the cutoff radius is hard-coded to 4.5 A here; it must match the
+        # cutoff the model was trained with (self.r_max)
+        edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda(
+            gbatch,
+            radius=4.5,
+            max_num_neighbors_threshold=float('inf'),
+            pbc=[True, True, True],
+            dtype=self.dtype
+        )
+
+        # swap sender and receiver rows to match MACE's edge_index convention
+        tmp = edge_indices[0].clone()
+        edge_indices[0] = edge_indices[1]
+        edge_indices[1] = tmp
+
+        # one-hot node attributes with one column per element in the z_table
+        indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table)
+        one_hot = to_one_hot(
+            torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
+            num_classes=len(self.z_table),
+        ).to(self.device)
+
+        cbatch = Batch(
+            positions=gbatch["pos"].clone(),
+            cell=gbatch["cell"].view(-1, 3),
+            batch=gbatch["batch"],
+            ptr=gbatch["ptr"],
+            edge_index=edge_indices,
+            unit_shifts=cell_offsets,
+            node_attrs=one_hot,
+        )
+
+        return cbatch
+
+    def predict_debug(self, atoms_list, gbatch, compute_stress=False):
+        """Debugging variant of predict(): builds the batch once through the
+        torch_geometric DataLoader and once through convert_batch() so the two
+        neighbour-list paths can be compared."""
+        predictions = {'energy': [], 'forces': []}
+
+        configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list]
+        data_loader = torch_geometric.dataloader.DataLoader(
+            dataset=[
+                data.AtomicData.from_config(
+                    config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads
+                )
+                for config in configs
+            ],
+            batch_size=len(atoms_list),
+            shuffle=False,
+            drop_last=False,
+        )
+
+        batch_base_tmp = next(iter(data_loader)).to(self.device)
+        batch2 = self.convert_batch(gbatch)
+        batch_base = Batch(
+            positions=batch2["positions"],
+            node_attrs=batch2["node_attrs"],
+            cell=batch2["cell"],
+            edge_index=batch2["edge_index"],
+            unit_shifts=batch2["unit_shifts"],
+            batch=batch2["batch"],
+            ptr=batch2["ptr"],
+        )
+
+        torch.set_printoptions(threshold=float('inf'))
+
+        # print(f'batch2["edge_index"]: {batch2["edge_index"]}')
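+        # The commented probes above and below are kept deliberately: enabling them
+        # lets one diff the edge_index / unit_shifts built by the torch_geometric
+        # DataLoader (batch_base_tmp) against those rebuilt by convert_batch (batch2).
+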
# print(f'batch2["unit_shifts"]: {batch2["unit_shifts"]}') + # print(f'batch_base_tmp["edge_index"]: {batch_base_tmp["edge_index"]}') + # print(f'batch_base_tmp["unit_shifts"]: {batch_base_tmp["unit_shifts"]}') + + # calculate node_e0 + # batch = self._clone_batch(batch_base) + # node_heads = batch["head"][batch["batch"]] + # num_atoms_arange = torch.arange(batch["positions"].shape[0]) + # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + # num_atoms_arange, node_heads + # ] + + # set_seed(0) + out = self.models[0]( + batch_base.to_dict(), + compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? + training=self.use_compile, + ) + # print(f'&&& batch.positions: {batch["positions"]}') + # print(f'&&& batch.cell: {batch["cell"]}') + # print(f'&&& batch.stress: {batch["stress"]}') + # for k,v in batch.to_dict().items(): + # print(f'&&& batch.to_dict(): {k} {v}') + # print("=======") + # print(f'&&& out["forces"]: {out["forces"]}') + # print(f'&&& training: {self.use_compile}') + predictions["energy"] = out["energy"].unsqueeze(-1).detach() + predictions["forces"] = out["forces"].detach() + if compute_stress: + predictions["stress"] = out["stress"].detach() + + # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') + return predictions \ No newline at end of file diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a48194f211a8fc749c31e217d020af909043411b..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 3d6064764790e90fb44232574b7b8ba5c19458ad..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc deleted file mode 100644 index d710de37d4ef25fe426f3525f6ce5faf22ec3066..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-313.pyc deleted file mode 100644 index 06e888dc612518d3052866b3b69aa22b3bfac5c7..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc deleted file mode 100644 index 48639ccc5b834099115dce8a9a6079d5895a8e97..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-313.pyc deleted file mode 100644 index 
4129d51e378cd77cf24b48bc1f3caef16603fa1b..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/cli/active_learning_md.py b/mace-bench/3rdparty/mace/mace/cli/active_learning_md.py index 90fc3a8978253e5e4ddc9937808effd80f56f063..9cf4f4a8817bccda23bc411dee3b9fe809681ac5 100644 --- a/mace-bench/3rdparty/mace/mace/cli/active_learning_md.py +++ b/mace-bench/3rdparty/mace/mace/cli/active_learning_md.py @@ -1,193 +1,193 @@ -"""Demonstrates active learning molecular dynamics with constant temperature.""" - -import argparse -import os -import time - -import ase.io -import numpy as np -from ase import units -from ase.md.langevin import Langevin -from ase.md.velocitydistribution import MaxwellBoltzmannDistribution - -from mace.calculators.mace import MACECalculator - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--config", help="path to XYZ configurations", required=True) - parser.add_argument( - "--config_index", help="index of configuration", type=int, default=-1 - ) - parser.add_argument( - "--error_threshold", help="error threshold", type=float, default=0.1 - ) - parser.add_argument("--temperature_K", help="temperature", type=float, default=300) - parser.add_argument("--friction", help="friction", type=float, default=0.01) - parser.add_argument("--timestep", help="timestep", type=float, default=1) - parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) - parser.add_argument( - "--nprint", help="number of steps between prints", type=int, default=10 - ) - parser.add_argument( - "--nsave", help="number of steps between saves", type=int, default=10 - ) - parser.add_argument( - "--ncheckerror", help="number of steps between saves", type=int, default=10 - ) - - parser.add_argument( - "--model", - help="path to model. Use wildcards to add multiple models as committee eg " - "(`mace_*.model` to load mace_1.model, mace_2.model) ", - required=True, - ) - parser.add_argument("--output", help="output path", required=True) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cuda", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--compute_stress", - help="compute stress", - action="store_true", - default=False, - ) - parser.add_argument( - "--info_prefix", - help="prefix for energy, forces and stress keys", - type=str, - default="MACE_", - ) - return parser.parse_args() - - -def printenergy(dyn, start_time=None): # store a reference to atoms in the definition. 
- """Function to print the potential, kinetic and total energy.""" - a = dyn.atoms - epot = a.get_potential_energy() / len(a) - ekin = a.get_kinetic_energy() / len(a) - if start_time is None: - elapsed_time = 0 - else: - elapsed_time = time.time() - start_time - forces_var = np.var(a.calc.results["forces_comm"], axis=0) - print( - "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209 - "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A" - % ( - elapsed_time, - epot, - ekin, - ekin / (1.5 * units.kB), - epot + ekin, - dyn.get_time() / units.fs, - a.calc.results["energy_var"], - np.max(np.linalg.norm(forces_var, axis=1)), - ), - flush=True, - ) - - -def save_config(dyn, fname): - atomsi = dyn.atoms - ens = atomsi.get_potential_energy() - frcs = atomsi.get_forces() - - atomsi.info.update( - { - "mlff_energy": ens, - "time": np.round(dyn.get_time() / units.fs, 5), - "mlff_energy_var": atomsi.calc.results["energy_var"], - } - ) - atomsi.arrays.update( - { - "mlff_forces": frcs, - "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0), - } - ) - - ase.io.write(fname, atomsi, append=True) - - -def stop_error(dyn, threshold, reg=0.2): - atomsi = dyn.atoms - force_var = np.var(atomsi.calc.results["forces_comm"], axis=0) - force = atomsi.get_forces() - ferr = np.sqrt(np.sum(force_var, axis=1)) - ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg) - - if np.max(ferr_rel) > threshold: - print( - "Error too large {:.3}. Stopping t={:.2} fs.".format( # pylint: disable=C0209 - np.max(ferr_rel), dyn.get_time() / units.fs - ), - flush=True, - ) - dyn.max_steps = 0 - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - mace_fname = args.model - atoms_fname = args.config - atoms_index = args.config_index - - mace_calc = MACECalculator( - model_paths=mace_fname, - device=args.device, - default_dtype=args.default_dtype, - ) - - NSTEPS = args.nsteps - - if os.path.exists(args.output): - print("Trajectory exists. Continuing from last step.") - atoms = ase.io.read(args.output, index=-1) - len_save = len(ase.io.read(args.output, ":")) - print("Last step: ", atoms.info["time"], "Number of configs: ", len_save) - NSTEPS -= len_save * args.nsave - else: - atoms = ase.io.read(atoms_fname, index=atoms_index) - MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K) - - atoms.calc = mace_calc - - # We want to run MD with constant energy using the Langevin algorithm - # with a time step of 5 fs, the temperature T and the friction - # coefficient to 0.02 atomic units. 
- dyn = Langevin( - atoms=atoms, - timestep=args.timestep * units.fs, - temperature_K=args.temperature_K, - friction=args.friction, - ) - - dyn.attach(printenergy, interval=args.nsave, dyn=dyn, start_time=time.time()) - dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output) - dyn.attach( - stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold - ) - # Now run the dynamics - dyn.run(NSTEPS) - - -if __name__ == "__main__": - main() +"""Demonstrates active learning molecular dynamics with constant temperature.""" + +import argparse +import os +import time + +import ase.io +import numpy as np +from ase import units +from ase.md.langevin import Langevin +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution + +from mace.calculators.mace import MACECalculator + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--config", help="path to XYZ configurations", required=True) + parser.add_argument( + "--config_index", help="index of configuration", type=int, default=-1 + ) + parser.add_argument( + "--error_threshold", help="error threshold", type=float, default=0.1 + ) + parser.add_argument("--temperature_K", help="temperature", type=float, default=300) + parser.add_argument("--friction", help="friction", type=float, default=0.01) + parser.add_argument("--timestep", help="timestep", type=float, default=1) + parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) + parser.add_argument( + "--nprint", help="number of steps between prints", type=int, default=10 + ) + parser.add_argument( + "--nsave", help="number of steps between saves", type=int, default=10 + ) + parser.add_argument( + "--ncheckerror", help="number of steps between saves", type=int, default=10 + ) + + parser.add_argument( + "--model", + help="path to model. Use wildcards to add multiple models as committee eg " + "(`mace_*.model` to load mace_1.model, mace_2.model) ", + required=True, + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cuda", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + return parser.parse_args() + + +def printenergy(dyn, start_time=None): # store a reference to atoms in the definition. 
+    """Function to print the potential, kinetic and total energy."""
+    a = dyn.atoms
+    epot = a.get_potential_energy() / len(a)
+    ekin = a.get_kinetic_energy() / len(a)
+    if start_time is None:
+        elapsed_time = 0
+    else:
+        elapsed_time = time.time() - start_time
+    forces_var = np.var(a.calc.results["forces_comm"], axis=0)
+    print(
+        "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) "  # pylint: disable=C0209
+        "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A"
+        % (
+            elapsed_time,
+            epot,
+            ekin,
+            ekin / (1.5 * units.kB),
+            epot + ekin,
+            dyn.get_time() / units.fs,
+            a.calc.results["energy_var"],
+            np.max(np.linalg.norm(forces_var, axis=1)),
+        ),
+        flush=True,
+    )
+
+
+def save_config(dyn, fname):
+    atomsi = dyn.atoms
+    ens = atomsi.get_potential_energy()
+    frcs = atomsi.get_forces()
+
+    atomsi.info.update(
+        {
+            "mlff_energy": ens,
+            "time": np.round(dyn.get_time() / units.fs, 5),
+            "mlff_energy_var": atomsi.calc.results["energy_var"],
+        }
+    )
+    atomsi.arrays.update(
+        {
+            "mlff_forces": frcs,
+            "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0),
+        }
+    )
+
+    ase.io.write(fname, atomsi, append=True)
+
+
+def stop_error(dyn, threshold, reg=0.2):
+    atomsi = dyn.atoms
+    force_var = np.var(atomsi.calc.results["forces_comm"], axis=0)
+    force = atomsi.get_forces()
+    ferr = np.sqrt(np.sum(force_var, axis=1))
+    ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg)
+
+    if np.max(ferr_rel) > threshold:
+        print(
+            "Error too large {:.3}. Stopping t={:.2} fs.".format(  # pylint: disable=C0209
+                np.max(ferr_rel), dyn.get_time() / units.fs
+            ),
+            flush=True,
+        )
+        dyn.max_steps = 0
+
+
+def main() -> None:
+    args = parse_args()
+    run(args)
+
+
+def run(args: argparse.Namespace) -> None:
+    mace_fname = args.model
+    atoms_fname = args.config
+    atoms_index = args.config_index
+
+    mace_calc = MACECalculator(
+        model_paths=mace_fname,
+        device=args.device,
+        default_dtype=args.default_dtype,
+    )
+
+    NSTEPS = args.nsteps
+
+    if os.path.exists(args.output):
+        print("Trajectory exists. Continuing from last step.")
+        atoms = ase.io.read(args.output, index=-1)
+        len_save = len(ase.io.read(args.output, ":"))
+        print("Last step: ", atoms.info["time"], "Number of configs: ", len_save)
+        NSTEPS -= len_save * args.nsave
+    else:
+        atoms = ase.io.read(atoms_fname, index=atoms_index)
+        MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K)
+
+    atoms.calc = mace_calc
+
+    # Run constant-temperature MD with the Langevin thermostat, using the timestep,
+    # temperature and friction coefficient requested on the command line.
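+    # NOTE: printenergy below is attached with interval=args.nsave; the --nprint
+    # argument is parsed above but never used in this script.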
+ dyn = Langevin( + atoms=atoms, + timestep=args.timestep * units.fs, + temperature_K=args.temperature_K, + friction=args.friction, + ) + + dyn.attach(printenergy, interval=args.nsave, dyn=dyn, start_time=time.time()) + dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output) + dyn.attach( + stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold + ) + # Now run the dynamics + dyn.run(NSTEPS) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py b/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py index 5aa2056ad112c345cdd74e291b692f7bfd705711..c2399cae1d7032534baabc225a04f2215c08aee2 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py @@ -1,208 +1,208 @@ -import argparse -import logging -import os -from typing import Dict, List, Tuple - -import torch - -from mace.tools.scripts_utils import extract_config_mace_model - - -def get_transfer_keys(num_layers: int) -> List[str]: - """Get list of keys that need to be transferred""" - return [ - "node_embedding.linear.weight", - "radial_embedding.bessel_fn.bessel_weights", - "atomic_energies_fn.atomic_energies", - "readouts.0.linear.weight", - *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], - "scale_shift.scale", - "scale_shift.shift", - *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], - ] + [ - s - for j in range(num_layers) - for s in [ - f"interactions.{j}.linear_up.weight", - *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], - f"interactions.{j}.linear.weight", - f"interactions.{j}.skip_tp.weight", - f"products.{j}.linear.weight", - ] - ] - - -def get_kmax_pairs( - max_L: int, correlation: int, num_layers: int -) -> List[Tuple[int, int]]: - """Determine kmax pairs based on max_L and correlation""" - if correlation == 2: - raise NotImplementedError("Correlation 2 not supported yet") - if correlation == 3: - kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] - kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] - return kmax_pairs - raise NotImplementedError(f"Correlation {correlation} not supported") - - -def transfer_symmetric_contractions( - source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer symmetric contraction weights from CuEq to E3nn format""" - kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) - - for i, kmax in kmax_pairs: - # Get the combined weight tensor from source - wm = source_dict[f"products.{i}.symmetric_contractions.weight"] - - # Get split sizes based on target dimensions - splits = [] - for k in range(kmax + 1): - for suffix in ["_max", ".0", ".1"]: - key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" - target_shape = target_dict[key].shape - splits.append(target_shape[1]) - - # Split the weights using the calculated sizes - weights_split = torch.split(wm, splits, dim=1) - - # Assign back to target dictionary - idx = 0 - for k in range(kmax + 1): - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" - ] = weights_split[idx] - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" - ] = weights_split[idx + 1] - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" - ] = weights_split[idx + 2] - idx += 3 - - -def transfer_weights( - source_model: torch.nn.Module, - 
target_model: torch.nn.Module, - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer weights from CuEq to E3nn format""" - # Get state dicts - source_dict = source_model.state_dict() - target_dict = target_model.state_dict() - - # Transfer main weights - transfer_keys = get_transfer_keys(num_layers) - for key in transfer_keys: - if key in source_dict: # Check if key exists - target_dict[key] = source_dict[key] - else: - logging.warning(f"Key {key} not found in source model") - - # Transfer symmetric contractions - transfer_symmetric_contractions( - source_dict, target_dict, max_L, correlation, num_layers - ) - - # Unsqueeze linear and skip_tp layers - for key in source_dict.keys(): - if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: - target_dict[key] = target_dict[key].squeeze(0) - - # Transfer remaining matching keys - transferred_keys = set(transfer_keys) - remaining_keys = ( - set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - ) - remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} - - if remaining_keys: - for key in remaining_keys: - if source_dict[key].shape == target_dict[key].shape: - logging.debug(f"Transferring additional key: {key}") - target_dict[key] = source_dict[key] - else: - logging.warning( - f"Shape mismatch for key {key}: " - f"source {source_dict[key].shape} vs target {target_dict[key].shape}" - ) - - # Transfer avg_num_neighbors - for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[ - i - ].avg_num_neighbors - - # Load state dict into target model - target_model.load_state_dict(target_dict) - - -def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True): - - # Load CuEq model - if isinstance(input_model, str): - source_model = torch.load(input_model, map_location=device) - else: - source_model = input_model - default_dtype = next(source_model.parameters()).dtype - torch.set_default_dtype(default_dtype) - # Extract configuration - config = extract_config_mace_model(source_model) - - # Get max_L and correlation from config - max_L = config["hidden_irreps"].lmax - correlation = config["correlation"] - - # Remove CuEq config - config.pop("cueq_config", None) - - # Create new model without CuEq config - logging.info("Creating new model without CuEq settings") - target_model = source_model.__class__(**config) - - # Transfer weights with proper remapping - num_layers = config["num_interactions"] - transfer_weights(source_model, target_model, max_L, correlation, num_layers) - - if return_model: - return target_model - - # Save model - if isinstance(input_model, str): - base = os.path.splitext(input_model)[0] - output_model = f"{base}.{output_model}" - logging.warning(f"Saving E3nn model to {output_model}") - torch.save(target_model, output_model) - return None - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_model", help="Path to input CuEq model") - parser.add_argument( - "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" - ) - parser.add_argument("--device", default="cpu", help="Device to use") - parser.add_argument( - "--return_model", - action="store_false", - help="Return model instead of saving to file", - ) - args = parser.parse_args() - - run( - input_model=args.input_model, - output_model=args.output_model, - device=args.device, - return_model=args.return_model, - ) - - -if __name__ == "__main__": - main() +import argparse +import logging +import os +from 
typing import Dict, List, Tuple + +import torch + +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys(num_layers: int) -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(num_layers) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs( + max_L: int, correlation: int, num_layers: int +) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] + kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] + return kmax_pairs + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer symmetric contraction weights from CuEq to E3nn format""" + kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) + + for i, kmax in kmax_pairs: + # Get the combined weight tensor from source + wm = source_dict[f"products.{i}.symmetric_contractions.weight"] + + # Get split sizes based on target dimensions + splits = [] + for k in range(kmax + 1): + for suffix in ["_max", ".0", ".1"]: + key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" + target_shape = target_dict[key].shape + splits.append(target_shape[1]) + + # Split the weights using the calculated sizes + weights_split = torch.split(wm, splits, dim=1) + + # Assign back to target dictionary + idx = 0 + for k in range(kmax + 1): + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" + ] = weights_split[idx] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" + ] = weights_split[idx + 1] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" + ] = weights_split[idx + 2] + idx += 3 + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer weights from CuEq to E3nn format""" + # Get state dicts + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys(num_layers) + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions( + source_dict, target_dict, max_L, correlation, num_layers + ) + + # Unsqueeze linear and skip_tp layers + for key in source_dict.keys(): + if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: + target_dict[key] = target_dict[key].squeeze(0) + + # Transfer remaining matching keys + transferred_keys = 
set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + + +def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True): + + # Load CuEq model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Remove CuEq config + config.pop("cueq_config", None) + + # Create new model without CuEq config + logging.info("Creating new model without CuEq settings") + target_model = source_model.__class__(**config) + + # Transfer weights with proper remapping + num_layers = config["num_interactions"] + transfer_weights(source_model, target_model, max_L, correlation, num_layers) + + if return_model: + return target_model + + # Save model + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving E3nn model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input CuEq model") + parser.add_argument( + "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_device.py b/mace-bench/3rdparty/mace/mace/cli/convert_device.py index 3366bfde1b3e6d1a77f919fccd6e3cba02aa24e2..69735b7cf723f45e74b1573d7cf86cb3828dc539 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_device.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_device.py @@ -1,31 +1,31 @@ -from argparse import ArgumentParser - -import torch - - -def main(): - parser = ArgumentParser() - parser.add_argument( - "--target_device", - "-t", - help="device to convert to, usually 'cpu' or 'cuda'", - default="cpu", - ) - parser.add_argument( - "--output_file", - "-o", - help="name for output model, defaults to model_file.target_device", - ) - parser.add_argument("model_file", help="input model file path") - args = parser.parse_args() - - if args.output_file is None: - args.output_file = args.model_file + "." 
+ args.target_device - - model = torch.load(args.model_file, weights_only=False) - model.to(args.target_device) - torch.save(model, args.output_file) - - -if __name__ == "__main__": - main() +from argparse import ArgumentParser + +import torch + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file, weights_only=False) + model.to(args.target_device) + torch.save(model, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py b/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py index 5e8223353ae57bd0f9de38589be58263eb9ecb21..299291f4b911500e2f3f664ee3e34398d17453be 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py @@ -1,204 +1,204 @@ -import argparse -import logging -import os -from typing import Dict, List, Tuple - -import torch - -from mace.modules.wrapper_ops import CuEquivarianceConfig -from mace.tools.scripts_utils import extract_config_mace_model - - -def get_transfer_keys(num_layers: int) -> List[str]: - """Get list of keys that need to be transferred""" - return [ - "node_embedding.linear.weight", - "radial_embedding.bessel_fn.bessel_weights", - "atomic_energies_fn.atomic_energies", - "readouts.0.linear.weight", - *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], - "scale_shift.scale", - "scale_shift.shift", - *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], - ] + [ - s - for j in range(num_layers) - for s in [ - f"interactions.{j}.linear_up.weight", - *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], - f"interactions.{j}.linear.weight", - f"interactions.{j}.skip_tp.weight", - f"products.{j}.linear.weight", - ] - ] - - -def get_kmax_pairs( - max_L: int, correlation: int, num_layers: int -) -> List[Tuple[int, int]]: - """Determine kmax pairs based on max_L and correlation""" - if correlation == 2: - raise NotImplementedError("Correlation 2 not supported yet") - if correlation == 3: - kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] - kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] - return kmax_pairs - raise NotImplementedError(f"Correlation {correlation} not supported") - - -def transfer_symmetric_contractions( - source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer symmetric contraction weights""" - kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) - - for i, kmax in kmax_pairs: - wm = torch.concatenate( - [ - source_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" - ] - for k in range(kmax + 1) - for j in ["_max", ".0", ".1"] - ], - dim=1, - ) - target_dict[f"products.{i}.symmetric_contractions.weight"] = wm - - -def transfer_weights( - source_model: torch.nn.Module, - target_model: torch.nn.Module, - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer weights with proper remapping""" - # Get source state dict - source_dict = source_model.state_dict() - 
target_dict = target_model.state_dict() - - # Transfer main weights - transfer_keys = get_transfer_keys(num_layers) - for key in transfer_keys: - if key in source_dict: # Check if key exists - target_dict[key] = source_dict[key] - else: - logging.warning(f"Key {key} not found in source model") - - # Transfer symmetric contractions - transfer_symmetric_contractions( - source_dict, target_dict, max_L, correlation, num_layers - ) - - # Unsqueeze linear and skip_tp layers - for key in source_dict.keys(): - if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: - target_dict[key] = target_dict[key].unsqueeze(0) - - transferred_keys = set(transfer_keys) - remaining_keys = ( - set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - ) - remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} - if remaining_keys: - for key in remaining_keys: - if source_dict[key].shape == target_dict[key].shape: - logging.debug(f"Transferring additional key: {key}") - target_dict[key] = source_dict[key] - else: - logging.warning( - f"Shape mismatch for key {key}: " - f"source {source_dict[key].shape} vs target {target_dict[key].shape}" - ) - # Transfer avg_num_neighbors - for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[ - i - ].avg_num_neighbors - - # Load state dict into target model - target_model.load_state_dict(target_dict) - - -def run( - input_model, - output_model="_cueq.model", - device="cpu", - return_model=True, -): - # Setup logging - - # Load original model - # logging.warning(f"Loading model") - # check if input_model is a path or a model - if isinstance(input_model, str): - source_model = torch.load(input_model, map_location=device) - else: - source_model = input_model - default_dtype = next(source_model.parameters()).dtype - torch.set_default_dtype(default_dtype) - # Extract configuration - config = extract_config_mace_model(source_model) - - # Get max_L and correlation from config - max_L = config["hidden_irreps"].lmax - correlation = config["correlation"] - - # Add cuequivariance config - config["cueq_config"] = CuEquivarianceConfig( - enabled=True, - layout="ir_mul", - group="O3_e3nn", - optimize_all=True, - ) - - # Create new model with cuequivariance config - logging.info("Creating new model with cuequivariance settings") - target_model = source_model.__class__(**config).to(device) - - # Transfer weights with proper remapping - num_layers = config["num_interactions"] - transfer_weights(source_model, target_model, max_L, correlation, num_layers) - - if return_model: - return target_model - - if isinstance(input_model, str): - base = os.path.splitext(input_model)[0] - output_model = f"{base}.{output_model}" - logging.warning(f"Saving CuEq model to {output_model}") - torch.save(target_model, output_model) - return None - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_model", help="Path to input MACE model") - parser.add_argument( - "--output_model", - help="Path to output cuequivariance model", - default="cueq_model.pt", - ) - parser.add_argument("--device", default="cpu", help="Device to use") - parser.add_argument( - "--return_model", - action="store_false", - help="Return model instead of saving to file", - ) - args = parser.parse_args() - - run( - input_model=args.input_model, - output_model=args.output_model, - device=args.device, - return_model=args.return_model, - ) - - -if __name__ == "__main__": - main() +import argparse +import logging +import os +from 
typing import Dict, List, Tuple + +import torch + +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys(num_layers: int) -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(num_layers) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs( + max_L: int, correlation: int, num_layers: int +) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] + kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] + return kmax_pairs + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer symmetric contraction weights""" + kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) + + for i, kmax in kmax_pairs: + wm = torch.concatenate( + [ + source_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" + ] + for k in range(kmax + 1) + for j in ["_max", ".0", ".1"] + ], + dim=1, + ) + target_dict[f"products.{i}.symmetric_contractions.weight"] = wm + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer weights with proper remapping""" + # Get source state dict + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys(num_layers) + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions( + source_dict, target_dict, max_L, correlation, num_layers + ) + + # Unsqueeze linear and skip_tp layers + for key in source_dict.keys(): + if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: + target_dict[key] = target_dict[key].unsqueeze(0) + + transferred_keys = set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + 
i + ].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + + +def run( + input_model, + output_model="_cueq.model", + device="cpu", + return_model=True, +): + # Setup logging + + # Load original model + # logging.warning(f"Loading model") + # check if input_model is a path or a model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Add cuequivariance config + config["cueq_config"] = CuEquivarianceConfig( + enabled=True, + layout="ir_mul", + group="O3_e3nn", + optimize_all=True, + ) + + # Create new model with cuequivariance config + logging.info("Creating new model with cuequivariance settings") + target_model = source_model.__class__(**config).to(device) + + # Transfer weights with proper remapping + num_layers = config["num_interactions"] + transfer_weights(source_model, target_model, max_L, correlation, num_layers) + + if return_model: + return target_model + + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving CuEq model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input MACE model") + parser.add_argument( + "--output_model", + help="Path to output cuequivariance model", + default="cueq_model.pt", + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py b/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py index f4ac867ff71e95deef0377214cda6b4d70742351..7af81f25a287cc7fe20e78ecb2e81b2c75fbccba 100644 --- a/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py +++ b/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py @@ -1,114 +1,114 @@ -# pylint: disable=wrong-import-position -import argparse -import copy -import os - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import torch -from e3nn.util import jit - -from mace.calculators import LAMMPS_MACE -from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq - - -def parse_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "model_path", - type=str, - help="Path to the model to be converted to LAMMPS", - ) - parser.add_argument( - "--head", - type=str, - nargs="?", - help="Head of the model to be converted to LAMMPS", - default=None, - ) - parser.add_argument( - "--dtype", - type=str, - nargs="?", - help="Data type of the model to be converted to LAMMPS", - default="float64", - ) - parser.add_argument( - "--format", - type=str, - help="Old libtorch format, or new mliap format", - default="libtorch", - ) - return 
parser.parse_args() - - -def select_head(model): - if hasattr(model, "heads"): - heads = model.heads - else: - heads = [None] - - if len(heads) == 1: - print(f"Only one head found in the model: {heads[0]}. Skipping selection.") - return heads[0] - - print("Available heads in the model:") - for i, head in enumerate(heads): - print(f"{i + 1}: {head}") - - # Ask the user to select a head - selected = input( - f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " - ) - - if selected.isdigit() and 1 <= int(selected) <= len(heads): - return heads[int(selected) - 1] - if selected == "": - print("No head selected. Proceeding without specifying a head.") - return None - print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") - return heads[-1] - - -def main(): - args = parse_args() - model_path = args.model_path # takes model name as command-line input - model = torch.load( - model_path, - map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - ) - if args.dtype == "float64": - model = model.double().to("cpu") - elif args.dtype == "float32": - print("Converting model to float32, this may cause loss of precision.") - model = model.float().to("cpu") - - if args.format == "mliap": - # Enabling cuequivariance by default. TODO: switch? - model = run_e3nn_to_cueq(copy.deepcopy(model)) - model.lammps_mliap = True - - if args.head is None: - head = select_head(model) - else: - head = args.head - print( - f"Selected head: {head} from command line in the list available heads: {model.heads}" - ) - - lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE - lammps_model = ( - lammps_class(model, head=head) if head is not None else lammps_class(model) - ) - if args.format == "mliap": - torch.save(lammps_model, model_path + "-mliap_lammps.pt") - else: - lammps_model_compiled = jit.compile(lammps_model) - lammps_model_compiled.save(model_path + "-lammps.pt") - - -if __name__ == "__main__": - main() +# pylint: disable=wrong-import-position +import argparse +import copy +import os + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import torch +from e3nn.util import jit + +from mace.calculators import LAMMPS_MACE +from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "model_path", + type=str, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default=None, + ) + parser.add_argument( + "--dtype", + type=str, + nargs="?", + help="Data type of the model to be converted to LAMMPS", + default="float64", + ) + parser.add_argument( + "--format", + type=str, + help="Old libtorch format, or new mliap format", + default="libtorch", + ) + return parser.parse_args() + + +def select_head(model): + if hasattr(model, "heads"): + heads = model.heads + else: + heads = [None] + + if len(heads) == 1: + print(f"Only one head found in the model: {heads[0]}. 
Skipping selection.") + return heads[0] + + print("Available heads in the model:") + for i, head in enumerate(heads): + print(f"{i + 1}: {head}") + + # Ask the user to select a head + selected = input( + f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " + ) + + if selected.isdigit() and 1 <= int(selected) <= len(heads): + return heads[int(selected) - 1] + if selected == "": + print("No head selected. Proceeding without specifying a head.") + return None + print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") + return heads[-1] + + +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input + model = torch.load( + model_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + if args.dtype == "float64": + model = model.double().to("cpu") + elif args.dtype == "float32": + print("Converting model to float32, this may cause loss of precision.") + model = model.float().to("cpu") + + if args.format == "mliap": + # Enabling cuequivariance by default. TODO: switch? + model = run_e3nn_to_cueq(copy.deepcopy(model)) + model.lammps_mliap = True + + if args.head is None: + head = select_head(model) + else: + head = args.head + print( + f"Selected head: {head} from command line in the list available heads: {model.heads}" + ) + + lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE + lammps_model = ( + lammps_class(model, head=head) if head is not None else lammps_class(model) + ) + if args.format == "mliap": + torch.save(lammps_model, model_path + "-mliap_lammps.pt") + else: + lammps_model_compiled = jit.compile(lammps_model) + lammps_model_compiled.save(model_path + "-lammps.pt") + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/eval_configs.py b/mace-bench/3rdparty/mace/mace/cli/eval_configs.py index d4ec3c702024925a4a1546d73b5cf579ea9e7a62..d00c54c62eb8b9c5cbd601216d13712a273d3d93 100644 --- a/mace-bench/3rdparty/mace/mace/cli/eval_configs.py +++ b/mace-bench/3rdparty/mace/mace/cli/eval_configs.py @@ -1,165 +1,165 @@ -########################################################################################### -# Script for evaluating configurations contained in an xyz file with a trained model -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse - -import ase.data -import ase.io -import numpy as np -import torch - -from mace import data -from mace.tools import torch_geometric, torch_tools, utils - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--configs", help="path to XYZ configurations", required=True) - parser.add_argument("--model", help="path to model", required=True) - parser.add_argument("--output", help="output path", required=True) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument("--batch_size", help="batch size", type=int, default=64) - parser.add_argument( - "--compute_stress", - help="compute stress", - action="store_true", - default=False, - ) - parser.add_argument( - 
"--return_contributions", - help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE", - action="store_true", - default=False, - ) - parser.add_argument( - "--info_prefix", - help="prefix for energy, forces and stress keys", - type=str, - default="MACE_", - ) - parser.add_argument( - "--head", - help="Model head used for evaluation", - type=str, - required=False, - default=None, - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - torch_tools.set_default_dtype(args.default_dtype) - device = torch_tools.init_device(args.device) - - # Load model - model = torch.load(f=args.model, map_location=args.device) - model = model.to( - args.device - ) # shouldn't be necessary but seems to help with CUDA problems - - for param in model.parameters(): - param.requires_grad = False - - # Load data and prepare input - atoms_list = ase.io.read(args.configs, index=":") - if args.head is not None: - for atoms in atoms_list: - atoms.info["head"] = args.head - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - - z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) - - try: - heads = model.heads - except AttributeError: - heads = None - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=float(model.r_max), heads=heads - ) - for config in configs - ], - batch_size=args.batch_size, - shuffle=False, - drop_last=False, - ) - - # Collect data - energies_list = [] - contributions_list = [] - stresses_list = [] - forces_collection = [] - - for batch in data_loader: - batch = batch.to(device) - output = model(batch.to_dict(), compute_stress=args.compute_stress) - energies_list.append(torch_tools.to_numpy(output["energy"])) - if args.compute_stress: - stresses_list.append(torch_tools.to_numpy(output["stress"])) - - if args.return_contributions: - contributions_list.append(torch_tools.to_numpy(output["contributions"])) - - forces = np.split( - torch_tools.to_numpy(output["forces"]), - indices_or_sections=batch.ptr[1:], - axis=0, - ) - forces_collection.append(forces[:-1]) # drop last as its empty - - energies = np.concatenate(energies_list, axis=0) - forces_list = [ - forces for forces_list in forces_collection for forces in forces_list - ] - assert len(atoms_list) == len(energies) == len(forces_list) - if args.compute_stress: - stresses = np.concatenate(stresses_list, axis=0) - assert len(atoms_list) == stresses.shape[0] - - if args.return_contributions: - contributions = np.concatenate(contributions_list, axis=0) - assert len(atoms_list) == contributions.shape[0] - - # Store data in atoms objects - for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): - atoms.calc = None # crucial - atoms.info[args.info_prefix + "energy"] = energy - atoms.arrays[args.info_prefix + "forces"] = forces - - if args.compute_stress: - atoms.info[args.info_prefix + "stress"] = stresses[i] - - if args.return_contributions: - atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] - - # Write atoms to output path - ase.io.write(args.output, images=atoms_list, format="extxyz") - - -if __name__ == "__main__": - main() +########################################################################################### +# Script for evaluating configurations contained in an xyz file with a trained model +# Authors: Ilyes Batatia, Gregor Simm +# This program is 
distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse + +import ase.data +import ase.io +import numpy as np +import torch + +from mace import data +from mace.tools import torch_geometric, torch_tools, utils + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--configs", help="path to XYZ configurations", required=True) + parser.add_argument("--model", help="path to model", required=True) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=64) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--return_contributions", + help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + parser.add_argument( + "--head", + help="Model head used for evaluation", + type=str, + required=False, + default=None, + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + torch_tools.set_default_dtype(args.default_dtype) + device = torch_tools.init_device(args.device) + + # Load model + model = torch.load(f=args.model, map_location=args.device) + model = model.to( + args.device + ) # shouldn't be necessary but seems to help with CUDA problems + + for param in model.parameters(): + param.requires_grad = False + + # Load data and prepare input + atoms_list = ase.io.read(args.configs, index=":") + if args.head is not None: + for atoms in atoms_list: + atoms.info["head"] = args.head + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + + z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + + try: + heads = model.heads + except AttributeError: + heads = None + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=float(model.r_max), heads=heads + ) + for config in configs + ], + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + ) + + # Collect data + energies_list = [] + contributions_list = [] + stresses_list = [] + forces_collection = [] + + for batch in data_loader: + batch = batch.to(device) + output = model(batch.to_dict(), compute_stress=args.compute_stress) + energies_list.append(torch_tools.to_numpy(output["energy"])) + if args.compute_stress: + stresses_list.append(torch_tools.to_numpy(output["stress"])) + + if args.return_contributions: + contributions_list.append(torch_tools.to_numpy(output["contributions"])) + + forces = np.split( + torch_tools.to_numpy(output["forces"]), + indices_or_sections=batch.ptr[1:], + axis=0, + ) + forces_collection.append(forces[:-1]) # drop last as its empty + + energies = np.concatenate(energies_list, axis=0) + forces_list = [ + forces for forces_list in forces_collection for forces in 
forces_list + ] + assert len(atoms_list) == len(energies) == len(forces_list) + if args.compute_stress: + stresses = np.concatenate(stresses_list, axis=0) + assert len(atoms_list) == stresses.shape[0] + + if args.return_contributions: + contributions = np.concatenate(contributions_list, axis=0) + assert len(atoms_list) == contributions.shape[0] + + # Store data in atoms objects + for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): + atoms.calc = None # crucial + atoms.info[args.info_prefix + "energy"] = energy + atoms.arrays[args.info_prefix + "forces"] = forces + + if args.compute_stress: + atoms.info[args.info_prefix + "stress"] = stresses[i] + + if args.return_contributions: + atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] + + # Write atoms to output path + ase.io.write(args.output, images=atoms_list, format="extxyz") + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py b/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py index 7dcfaba28d0a24f16c989a9b67f66a4092fd2dc0..59a0fb0782b8481967e4ad0b26ce0c3d566548fa 100644 --- a/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py +++ b/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py @@ -1,494 +1,494 @@ -########################################################################################### -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### -from __future__ import annotations - -import argparse -import logging -from dataclasses import dataclass -from enum import Enum -from typing import List, Tuple, Union - -import ase.data -import ase.io -import numpy as np -import torch - -from mace.calculators import MACECalculator, mace_mp - -try: - import fpsample # type: ignore -except ImportError: - pass - - -class FilteringType(Enum): - NONE = "none" - COMBINATIONS = "combinations" - EXCLUSIVE = "exclusive" - INCLUSIVE = "inclusive" - - -class SubselectType(Enum): - FPS = "fps" - RANDOM = "random" - - -@dataclass -class SelectionSettings: - configs_pt: str - output: str - configs_ft: str | None = None - atomic_numbers: List[int] | None = None - num_samples: int | None = None - subselect: SubselectType = SubselectType.FPS - model: str = "small" - descriptors: str | None = None - device: str = "cpu" - default_dtype: str = "float64" - head_pt: str | None = None - head_ft: str | None = None - filtering_type: FilteringType = FilteringType.COMBINATIONS - weight_ft: float = 1.0 - weight_pt: float = 1.0 - seed: int = 42 - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--configs_pt", - help="path to XYZ configurations for the pretraining", - required=True, - ) - parser.add_argument( - "--configs_ft", - help="path or list of paths to XYZ configurations for the finetuning", - required=False, - default=None, - ) - parser.add_argument( - "--num_samples", - help="number of samples to select for the pretraining", - type=int, - required=False, - default=None, - ) - parser.add_argument( - "--subselect", - help="method to subselect the configurations of the pretraining set", - type=SubselectType, - choices=list(SubselectType), - default=SubselectType.FPS, - ) - parser.add_argument( - "--model", help="path to model", default="small", required=False - ) - parser.add_argument("--output", help="output path", 
required=True) - parser.add_argument( - "--descriptors", help="path to descriptors", required=False, default=None - ) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--head_pt", - help="level of head for the pretraining set", - type=str, - default=None, - ) - parser.add_argument( - "--head_ft", - help="level of head for the finetuning set", - type=str, - default=None, - ) - parser.add_argument( - "--filtering_type", - help="filtering type", - type=FilteringType, - choices=list(FilteringType), - default=FilteringType.NONE, - ) - parser.add_argument( - "--weight_ft", - help="weight for the finetuning set", - type=float, - default=1.0, - ) - parser.add_argument( - "--weight_pt", - help="weight for the pretraining set", - type=float, - default=1.0, - ) - parser.add_argument("--seed", help="random seed", type=int, default=42) - return parser.parse_args() - - -def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: - logging.info("Calculating descriptors") - for mol in atoms: - descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) - # average descriptors over atoms for each element - descriptors_dict = { - element: np.mean(descriptors[mol.symbols == element], axis=0) - for element in np.unique(mol.symbols) - } - mol.info["mace_descriptors"] = descriptors_dict - - -def filter_atoms( - atoms: ase.Atoms, - element_subset: List[str], - filtering_type: FilteringType = FilteringType.COMBINATIONS, -) -> bool: - """ - Filters atoms based on the provided filtering type and element subset. - - Parameters: - atoms (ase.Atoms): The atoms object to filter. - element_subset (list): The list of elements to consider during filtering. - filtering_type (FilteringType): The type of filtering to apply. - Can be one of the following `FilteringType` enum members: - - `FilteringType.NONE`: No filtering is applied. - - `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. - - `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise. - - `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. - - Returns: - bool: True if the atoms pass the filter, False otherwise. - """ - if filtering_type == FilteringType.NONE: - return True - if filtering_type == FilteringType.COMBINATIONS: - atom_symbols = np.unique(atoms.symbols) - return all( - x in element_subset for x in atom_symbols - ) # atoms must *only* contain elements in the subset - if filtering_type == FilteringType.EXCLUSIVE: - atom_symbols = set(list(atoms.symbols)) - return atom_symbols == set(element_subset) - if filtering_type == FilteringType.INCLUSIVE: - atom_symbols = np.unique(atoms.symbols) - return all( - x in atom_symbols for x in element_subset - ) # atoms must *at least* contain elements in the subset - raise ValueError( - f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}." 
- ) - - -class FPS: - def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): - self.n_samples = n_samples - self.atoms_list = atoms_list - self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore - self.species_dict = {x: i for i, x in enumerate(self.species)} - # start from a random configuration - self.list_index = [np.random.randint(0, len(atoms_list))] - self.assemble_descriptors() - - def run( - self, - ) -> List[int]: - """ - Run the farthest point sampling algorithm. - """ - descriptor_dataset_reshaped = ( - self.descriptors_dataset.reshape( # pylint: disable=E1121 - (len(self.atoms_list), -1) - ) - ) - logging.info(f"{descriptor_dataset_reshaped.shape}") - logging.info(f"n_samples: {self.n_samples}") - self.list_index = fpsample.fps_npdu_kdtree_sampling( - descriptor_dataset_reshaped, - self.n_samples, - ) - return self.list_index - - def assemble_descriptors(self) -> None: - """ - Assemble the descriptors for all the configurations. - """ - self.descriptors_dataset: np.ndarray = 10e10 * np.ones( - ( - len(self.atoms_list), - len(self.species), - len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), - ), - dtype=np.float32, - ).astype(np.float32) - - for i, atoms in enumerate(self.atoms_list): - descriptors = atoms.info["mace_descriptors"] - for z in descriptors: - self.descriptors_dataset[i, self.species_dict[z]] = np.array( - descriptors[z] - ).astype(np.float32) - - -def _load_calc( - model: str, device: str, default_dtype: str, subselect: SubselectType -) -> Union[MACECalculator, None]: - if subselect == SubselectType.RANDOM: - return None - if model in ["small", "medium", "large"]: - calc = mace_mp(model, device=device, default_dtype=default_dtype) - else: - calc = MACECalculator( - model_paths=model, - device=device, - default_dtype=default_dtype, - ) - return calc - - -def _get_finetuning_elements( - atoms: List[ase.Atoms], atomic_numbers: List[int] | None -) -> List[str]: - if atoms: - logging.debug( - "Using elements from the finetuning configurations for filtering." 
- ) - species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore - elif atomic_numbers is not None and atomic_numbers: - logging.debug("Using the supplied atomic numbers for filtering.") - species = [ase.data.chemical_symbols[z] for z in atomic_numbers] - else: - species = [] - return species - - -def _read_finetuning_configs( - configs_ft: Union[str, list[str], None], -) -> List[ase.Atoms]: - if isinstance(configs_ft, str): - path = configs_ft - return ase.io.read(path, index=":") # type: ignore - if isinstance(configs_ft, list): - assert all(isinstance(x, str) for x in configs_ft) - atoms_list_ft = [] - for path in configs_ft: - atoms_list_ft += ase.io.read(path, index=":") - return atoms_list_ft - if configs_ft is None: - return [] - raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}") - - -def _filter_pretraining_data( - atoms: list[ase.Atoms], - filtering_type: FilteringType, - all_species_ft: List[str], -) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]: - logging.info( - "Filtering configurations based on the finetuning set, " - f"filtering type: {filtering_type}, elements: {all_species_ft}" - ) - passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms] - assert len(passes_filter) == len(atoms), "Filtering failed" - filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes] - remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes] - return filtered_atoms, remaining_atoms, passes_filter - - -def _get_random_configs( - num_samples: int, - atoms: List[ase.Atoms], -) -> list[ase.Atoms]: - if num_samples > len(atoms): - raise ValueError( - f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})" - ) - indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False) - return [atoms[i] for i in indices] - - -def _load_descriptors( - atoms: List[ase.Atoms], - passes_filter: List[bool], - descriptors_path: str | None, - calc: MACECalculator | None, - full_data_length: int, -) -> None: - if descriptors_path is not None: - logging.info(f"Loading descriptors from {descriptors_path}") - descriptors = np.load(descriptors_path, allow_pickle=True) - assert sum(passes_filter) == len(atoms) - if len(descriptors) != full_data_length: - raise ValueError( - f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})" - "Please provide descriptors for all configurations" - ) - required_descriptors = [ - descriptors[i] for i, passes in enumerate(passes_filter) if passes - ] - for i, atoms_ in enumerate(atoms): - atoms_.info["mace_descriptors"] = required_descriptors[i] - else: - logging.info("Calculating descriptors") - if calc is None: - raise ValueError("MACECalculator must be provided to calculate descriptors") - calculate_descriptors(atoms, calc) - - -def _maybe_save_descriptors( - atoms: List[ase.Atoms], - output_path: str, -) -> None: - """ - Save the descriptors if they are present in the atoms objects. - Also, delete the descriptors from the atoms objects. 
- """ - if all("mace_descriptors" in x.info for x in atoms): - descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy") - logging.info(f"Saving descriptors at {descriptor_save_path}") - descriptors_list = [x.info["mace_descriptors"] for x in atoms] - np.save(descriptor_save_path, descriptors_list, allow_pickle=True) - for x in atoms: - del x.info["mace_descriptors"] - - -def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]: - try: - fps_pt = FPS(atoms, num_samples) - idx_pt = fps_pt.run() - logging.info(f"Selected {len(idx_pt)} configurations") - return [atoms[i] for i in idx_pt] - except Exception as e: # pylint: disable=W0703 - logging.error(f"FPS failed, selecting random configurations instead: {e}") - return _get_random_configs(num_samples, atoms) - - -def _subsample_data( - filtered_atoms: List[ase.Atoms], - remaining_atoms: List[ase.Atoms], - passes_filter: List[bool], - num_samples: int | None, - subselect: SubselectType, - descriptors_path: str | None, - calc: MACECalculator | None, -) -> List[ase.Atoms]: - if num_samples is None or num_samples == len(filtered_atoms): - logging.info( - f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations" - ) - return filtered_atoms - if num_samples > len(filtered_atoms): - num_sample_randomly = num_samples - len(filtered_atoms) - logging.info( - f"Number of configurations after filtering {len(filtered_atoms)} " - f"is less than the number of samples {num_samples}, " - f"selecting {num_sample_randomly} random configurations for the rest." - ) - return filtered_atoms + _get_random_configs( - num_sample_randomly, remaining_atoms - ) - if num_samples == 0: - raise ValueError("Number of samples must be greater than 0") - if subselect == SubselectType.FPS: - _load_descriptors( - filtered_atoms, - passes_filter, - descriptors_path, - calc, - full_data_length=len(filtered_atoms) + len(remaining_atoms), - ) - logging.info("Selecting configurations using Farthest Point Sampling") - return _maybe_fps(filtered_atoms, num_samples) - if subselect == SubselectType.RANDOM: - return _get_random_configs(num_samples, filtered_atoms) - raise ValueError(f"Invalid subselect type: {subselect}") - - -def _write_metadata( - atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None -) -> None: - for a in atoms: - a.info["pretrained"] = pretrained - a.info["config_weight"] = config_weight - if head is not None: - a.info["head"] = head - - -def select_samples( - settings: SelectionSettings, -) -> None: - np.random.seed(settings.seed) - torch.manual_seed(settings.seed) - calc = _load_calc( - settings.model, settings.device, settings.default_dtype, settings.subselect - ) - atoms_list_ft = _read_finetuning_configs(settings.configs_ft) - all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers) - - if settings.filtering_type is not FilteringType.NONE and not all_species_ft: - raise ValueError( - "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag." 
- ) - - atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore - filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data( - atoms_list_pt, settings.filtering_type, all_species_ft - ) - - subsampled_atoms = _subsample_data( - filtered_pt_atoms, - remaining_atoms, - passes_filter, - settings.num_samples, - settings.subselect, - settings.descriptors, - calc, - ) - _maybe_save_descriptors(subsampled_atoms, settings.output) - - _write_metadata( - subsampled_atoms, - pretrained=True, - config_weight=settings.weight_pt, - head=settings.head_pt, - ) - _write_metadata( - atoms_list_ft, - pretrained=False, - config_weight=settings.weight_ft, - head=settings.head_ft, - ) - - logging.info("Saving the selected configurations") - ase.io.write(settings.output, subsampled_atoms, format="extxyz") - - logging.info("Saving a combined XYZ file") - atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft - - ase.io.write( - settings.output.replace(".xyz", "_combined.xyz"), - atoms_fps_pt_ft, - format="extxyz", - ) - - -def main(): - args = parse_args() - settings = SelectionSettings(**vars(args)) - select_samples(settings) - - -if __name__ == "__main__": - main() +########################################################################################### +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +from __future__ import annotations + +import argparse +import logging +from dataclasses import dataclass +from enum import Enum +from typing import List, Tuple, Union + +import ase.data +import ase.io +import numpy as np +import torch + +from mace.calculators import MACECalculator, mace_mp + +try: + import fpsample # type: ignore +except ImportError: + pass + + +class FilteringType(Enum): + NONE = "none" + COMBINATIONS = "combinations" + EXCLUSIVE = "exclusive" + INCLUSIVE = "inclusive" + + +class SubselectType(Enum): + FPS = "fps" + RANDOM = "random" + + +@dataclass +class SelectionSettings: + configs_pt: str + output: str + configs_ft: str | None = None + atomic_numbers: List[int] | None = None + num_samples: int | None = None + subselect: SubselectType = SubselectType.FPS + model: str = "small" + descriptors: str | None = None + device: str = "cpu" + default_dtype: str = "float64" + head_pt: str | None = None + head_ft: str | None = None + filtering_type: FilteringType = FilteringType.COMBINATIONS + weight_ft: float = 1.0 + weight_pt: float = 1.0 + seed: int = 42 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--configs_pt", + help="path to XYZ configurations for the pretraining", + required=True, + ) + parser.add_argument( + "--configs_ft", + help="path or list of paths to XYZ configurations for the finetuning", + required=False, + default=None, + ) + parser.add_argument( + "--num_samples", + help="number of samples to select for the pretraining", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--subselect", + help="method to subselect the configurations of the pretraining set", + type=SubselectType, + choices=list(SubselectType), + default=SubselectType.FPS, + ) + parser.add_argument( + "--model", help="path to model", default="small", required=False + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--descriptors", help="path to descriptors", 
required=False, default=None + ) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--head_pt", + help="level of head for the pretraining set", + type=str, + default=None, + ) + parser.add_argument( + "--head_ft", + help="level of head for the finetuning set", + type=str, + default=None, + ) + parser.add_argument( + "--filtering_type", + help="filtering type", + type=FilteringType, + choices=list(FilteringType), + default=FilteringType.NONE, + ) + parser.add_argument( + "--weight_ft", + help="weight for the finetuning set", + type=float, + default=1.0, + ) + parser.add_argument( + "--weight_pt", + help="weight for the pretraining set", + type=float, + default=1.0, + ) + parser.add_argument("--seed", help="random seed", type=int, default=42) + return parser.parse_args() + + +def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: + logging.info("Calculating descriptors") + for mol in atoms: + descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) + # average descriptors over atoms for each element + descriptors_dict = { + element: np.mean(descriptors[mol.symbols == element], axis=0) + for element in np.unique(mol.symbols) + } + mol.info["mace_descriptors"] = descriptors_dict + + +def filter_atoms( + atoms: ase.Atoms, + element_subset: List[str], + filtering_type: FilteringType = FilteringType.COMBINATIONS, +) -> bool: + """ + Filters atoms based on the provided filtering type and element subset. + + Parameters: + atoms (ase.Atoms): The atoms object to filter. + element_subset (list): The list of elements to consider during filtering. + filtering_type (FilteringType): The type of filtering to apply. + Can be one of the following `FilteringType` enum members: + - `FilteringType.NONE`: No filtering is applied. + - `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. + - `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise. + - `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + + Returns: + bool: True if the atoms pass the filter, False otherwise. + """ + if filtering_type == FilteringType.NONE: + return True + if filtering_type == FilteringType.COMBINATIONS: + atom_symbols = np.unique(atoms.symbols) + return all( + x in element_subset for x in atom_symbols + ) # atoms must *only* contain elements in the subset + if filtering_type == FilteringType.EXCLUSIVE: + atom_symbols = set(list(atoms.symbols)) + return atom_symbols == set(element_subset) + if filtering_type == FilteringType.INCLUSIVE: + atom_symbols = np.unique(atoms.symbols) + return all( + x in atom_symbols for x in element_subset + ) # atoms must *at least* contain elements in the subset + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}." 
+ ) + + +class FPS: + def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): + self.n_samples = n_samples + self.atoms_list = atoms_list + self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore + self.species_dict = {x: i for i, x in enumerate(self.species)} + # start from a random configuration + self.list_index = [np.random.randint(0, len(atoms_list))] + self.assemble_descriptors() + + def run( + self, + ) -> List[int]: + """ + Run the farthest point sampling algorithm. + """ + descriptor_dataset_reshaped = ( + self.descriptors_dataset.reshape( # pylint: disable=E1121 + (len(self.atoms_list), -1) + ) + ) + logging.info(f"{descriptor_dataset_reshaped.shape}") + logging.info(f"n_samples: {self.n_samples}") + self.list_index = fpsample.fps_npdu_kdtree_sampling( + descriptor_dataset_reshaped, + self.n_samples, + ) + return self.list_index + + def assemble_descriptors(self) -> None: + """ + Assemble the descriptors for all the configurations. + """ + self.descriptors_dataset: np.ndarray = 10e10 * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), + ), + dtype=np.float32, + ).astype(np.float32) + + for i, atoms in enumerate(self.atoms_list): + descriptors = atoms.info["mace_descriptors"] + for z in descriptors: + self.descriptors_dataset[i, self.species_dict[z]] = np.array( + descriptors[z] + ).astype(np.float32) + + +def _load_calc( + model: str, device: str, default_dtype: str, subselect: SubselectType +) -> Union[MACECalculator, None]: + if subselect == SubselectType.RANDOM: + return None + if model in ["small", "medium", "large"]: + calc = mace_mp(model, device=device, default_dtype=default_dtype) + else: + calc = MACECalculator( + model_paths=model, + device=device, + default_dtype=default_dtype, + ) + return calc + + +def _get_finetuning_elements( + atoms: List[ase.Atoms], atomic_numbers: List[int] | None +) -> List[str]: + if atoms: + logging.debug( + "Using elements from the finetuning configurations for filtering." 
+ ) + species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore + elif atomic_numbers is not None and atomic_numbers: + logging.debug("Using the supplied atomic numbers for filtering.") + species = [ase.data.chemical_symbols[z] for z in atomic_numbers] + else: + species = [] + return species + + +def _read_finetuning_configs( + configs_ft: Union[str, list[str], None], +) -> List[ase.Atoms]: + if isinstance(configs_ft, str): + path = configs_ft + return ase.io.read(path, index=":") # type: ignore + if isinstance(configs_ft, list): + assert all(isinstance(x, str) for x in configs_ft) + atoms_list_ft = [] + for path in configs_ft: + atoms_list_ft += ase.io.read(path, index=":") + return atoms_list_ft + if configs_ft is None: + return [] + raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}") + + +def _filter_pretraining_data( + atoms: list[ase.Atoms], + filtering_type: FilteringType, + all_species_ft: List[str], +) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]: + logging.info( + "Filtering configurations based on the finetuning set, " + f"filtering type: {filtering_type}, elements: {all_species_ft}" + ) + passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms] + assert len(passes_filter) == len(atoms), "Filtering failed" + filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes] + remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes] + return filtered_atoms, remaining_atoms, passes_filter + + +def _get_random_configs( + num_samples: int, + atoms: List[ase.Atoms], +) -> list[ase.Atoms]: + if num_samples > len(atoms): + raise ValueError( + f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})" + ) + indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False) + return [atoms[i] for i in indices] + + +def _load_descriptors( + atoms: List[ase.Atoms], + passes_filter: List[bool], + descriptors_path: str | None, + calc: MACECalculator | None, + full_data_length: int, +) -> None: + if descriptors_path is not None: + logging.info(f"Loading descriptors from {descriptors_path}") + descriptors = np.load(descriptors_path, allow_pickle=True) + assert sum(passes_filter) == len(atoms) + if len(descriptors) != full_data_length: + raise ValueError( + f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})" + "Please provide descriptors for all configurations" + ) + required_descriptors = [ + descriptors[i] for i, passes in enumerate(passes_filter) if passes + ] + for i, atoms_ in enumerate(atoms): + atoms_.info["mace_descriptors"] = required_descriptors[i] + else: + logging.info("Calculating descriptors") + if calc is None: + raise ValueError("MACECalculator must be provided to calculate descriptors") + calculate_descriptors(atoms, calc) + + +def _maybe_save_descriptors( + atoms: List[ase.Atoms], + output_path: str, +) -> None: + """ + Save the descriptors if they are present in the atoms objects. + Also, delete the descriptors from the atoms objects. 
+ """ + if all("mace_descriptors" in x.info for x in atoms): + descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy") + logging.info(f"Saving descriptors at {descriptor_save_path}") + descriptors_list = [x.info["mace_descriptors"] for x in atoms] + np.save(descriptor_save_path, descriptors_list, allow_pickle=True) + for x in atoms: + del x.info["mace_descriptors"] + + +def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]: + try: + fps_pt = FPS(atoms, num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + return [atoms[i] for i in idx_pt] + except Exception as e: # pylint: disable=W0703 + logging.error(f"FPS failed, selecting random configurations instead: {e}") + return _get_random_configs(num_samples, atoms) + + +def _subsample_data( + filtered_atoms: List[ase.Atoms], + remaining_atoms: List[ase.Atoms], + passes_filter: List[bool], + num_samples: int | None, + subselect: SubselectType, + descriptors_path: str | None, + calc: MACECalculator | None, +) -> List[ase.Atoms]: + if num_samples is None or num_samples == len(filtered_atoms): + logging.info( + f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations" + ) + return filtered_atoms + if num_samples > len(filtered_atoms): + num_sample_randomly = num_samples - len(filtered_atoms) + logging.info( + f"Number of configurations after filtering {len(filtered_atoms)} " + f"is less than the number of samples {num_samples}, " + f"selecting {num_sample_randomly} random configurations for the rest." + ) + return filtered_atoms + _get_random_configs( + num_sample_randomly, remaining_atoms + ) + if num_samples == 0: + raise ValueError("Number of samples must be greater than 0") + if subselect == SubselectType.FPS: + _load_descriptors( + filtered_atoms, + passes_filter, + descriptors_path, + calc, + full_data_length=len(filtered_atoms) + len(remaining_atoms), + ) + logging.info("Selecting configurations using Farthest Point Sampling") + return _maybe_fps(filtered_atoms, num_samples) + if subselect == SubselectType.RANDOM: + return _get_random_configs(num_samples, filtered_atoms) + raise ValueError(f"Invalid subselect type: {subselect}") + + +def _write_metadata( + atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None +) -> None: + for a in atoms: + a.info["pretrained"] = pretrained + a.info["config_weight"] = config_weight + if head is not None: + a.info["head"] = head + + +def select_samples( + settings: SelectionSettings, +) -> None: + np.random.seed(settings.seed) + torch.manual_seed(settings.seed) + calc = _load_calc( + settings.model, settings.device, settings.default_dtype, settings.subselect + ) + atoms_list_ft = _read_finetuning_configs(settings.configs_ft) + all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers) + + if settings.filtering_type is not FilteringType.NONE and not all_species_ft: + raise ValueError( + "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag." 
+ ) + + atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore + filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data( + atoms_list_pt, settings.filtering_type, all_species_ft + ) + + subsampled_atoms = _subsample_data( + filtered_pt_atoms, + remaining_atoms, + passes_filter, + settings.num_samples, + settings.subselect, + settings.descriptors, + calc, + ) + _maybe_save_descriptors(subsampled_atoms, settings.output) + + _write_metadata( + subsampled_atoms, + pretrained=True, + config_weight=settings.weight_pt, + head=settings.head_pt, + ) + _write_metadata( + atoms_list_ft, + pretrained=False, + config_weight=settings.weight_ft, + head=settings.head_ft, + ) + + logging.info("Saving the selected configurations") + ase.io.write(settings.output, subsampled_atoms, format="extxyz") + + logging.info("Saving a combined XYZ file") + atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft + + ase.io.write( + settings.output.replace(".xyz", "_combined.xyz"), + atoms_fps_pt_ft, + format="extxyz", + ) + + +def main(): + args = parse_args() + settings = SelectionSettings(**vars(args)) + select_samples(settings) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/plot_train.py b/mace-bench/3rdparty/mace/mace/cli/plot_train.py index 4e27372acf80e10a474175cde263567159e84d3d..238bd095f57ef7f9eaf2c046972051b82506bb60 100644 --- a/mace-bench/3rdparty/mace/mace/cli/plot_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/plot_train.py @@ -1,342 +1,342 @@ -import argparse -import dataclasses -import glob -import json -import os -import re -from typing import List - -import matplotlib.pyplot as plt -import pandas as pd - -plt.rcParams.update({"font.size": 8}) -plt.style.use("seaborn-v0_8-paper") - - -colors = [ - "#1f77b4", # muted blue - "#d62728", # brick red - "#ff7f0e", # safety orange - "#2ca02c", # cooked asparagus green - "#9467bd", # muted purple - "#8c564b", # chestnut brown - "#e377c2", # raspberry yogurt pink - "#7f7f7f", # middle gray - "#bcbd22", # curry yellow-green - "#17becf", # blue-teal -] - - -@dataclasses.dataclass -class RunInfo: - name: str - seed: int - - -name_re = re.compile(r"(?P.+)_run-(?P\d+)_train.txt") - - -def parse_path(path: str) -> RunInfo: - match = name_re.match(os.path.basename(path)) - if not match: - raise RuntimeError(f"Cannot parse {path}") - - return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) - - -def parse_training_results(path: str) -> List[dict]: - run_info = parse_path(path) - results = [] - with open(path, mode="r", encoding="utf-8") as f: - for line in f: - d = json.loads(line) - d["name"] = run_info.name - d["seed"] = run_info.seed - results.append(d) - - return results - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Plot mace training statistics", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--path", help="Path to results file (.txt) or directory.", required=True - ) - parser.add_argument( - "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False - ) - parser.add_argument( - "--start_stage_two", - "--start_swa", - help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. 
If None then assumed tag not used in training.", - default=None, - type=int, - required=False, - dest="start_swa", - ) - parser.add_argument( - "--linear", - help="Whether to plot linear instead of log scales.", - default=False, - required=False, - action="store_true", - ) - parser.add_argument( - "--error_bars", - help="Whether to plot standard deviations.", - default=False, - required=False, - action="store_true", - ) - parser.add_argument( - "--keys", - help="Comma-separated list of keys to plot.", - default="rmse_e,rmse_f", - type=str, - required=False, - ) - - parser.add_argument( - "--output_format", - help="What file type to save plot as", - default="png", - type=str, - required=False, - ) - - parser.add_argument( - "--heads", - help="Comma-separated name of the heads used for multihead training", - default=None, - type=str, - required=False, - ) - - return parser.parse_args() - - -def plot( - data: pd.DataFrame, - min_epoch: int, - output_path: str, - output_format: str, - linear: bool, - start_swa: int, - error_bars: bool, - keys: str, - heads: str, -) -> None: - """ - Plots train,validation loss and errors as a function of epoch. - min_epoch: minimum epoch to plot. - output_path: path to save the plot. - output_format: format to save the plot. - start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins. - error_bars: whether to plot standard deviation of loss. - linear: whether to plot in linear scale or logscale (default). - keys: Values to plot. - heads: Heads used for multihead training. - """ - - labels = { - "mae_e": "MAE E [meV]", - "mae_e_per_atom": "MAE E/atom [meV]", - "rmse_e": "RMSE E [meV]", - "rmse_e_per_atom": "RMSE E/atom [meV]", - "q95_e": "Q95 E [meV]", - "mae_f": "MAE F [meV / A]", - "rel_mae_f": "Relative MAE F [meV / A]", - "rmse_f": "RMSE F [meV / A]", - "rel_rmse_f": "Relative RMSE F [meV / A]", - "q95_f": "Q95 F [meV / A]", - "mae_stress": "MAE Stress", - "rmse_stress": "RMSE Stress [meV / A^3]", - "rmse_virials_per_atom": " RMSE virials/atom [meV]", - "mae_virials": "MAE Virials [meV]", - "rmse_mu_per_atom": "RMSE MU/atom [mDebye]", - } - - data = data[data["epoch"] > min_epoch] - if heads is None: - data = ( - data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index() - ) - - valid_data = data[data["mode"] == "eval"] - valid_data_dict = {"default": valid_data} - train_data = data[data["mode"] == "opt"] - else: - heads = heads.split(",") - # Separate eval and opt data - valid_data = ( - data[data["mode"] == "eval"] - .groupby(["name", "mode", "epoch", "head"]) - .agg(["mean", "std"]) - .reset_index() - ) - train_data = ( - data[data["mode"] == "opt"] - .groupby(["name", "mode", "epoch"]) - .agg(["mean", "std"]) - .reset_index() - ) - valid_data_dict = { - head: valid_data[valid_data["head"] == head] for head in heads - } - - for head, valid_data in valid_data_dict.items(): - fig, axes = plt.subplots( - nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True - ) - - # ---- Plot loss ---- - ax = axes[0] - ax.plot( - train_data["epoch"], - train_data["loss"]["mean"], - color=colors[1], - linewidth=1, - ) - ax.set_ylabel("Training Loss", color=colors[1]) - ax.set_yscale("log") - - ax2 = ax.twinx() - ax2.plot( - valid_data["epoch"], - valid_data["loss"]["mean"], - color=colors[0], - linewidth=1, - ) - ax2.set_ylabel("Validation Loss", color=colors[0]) - - if not linear: - ax.set_yscale("log") - ax2.set_yscale("log") - - if error_bars: - ax.fill_between( - train_data["epoch"], - train_data["loss"]["mean"] - 
train_data["loss"]["std"], - train_data["loss"]["mean"] + train_data["loss"]["std"], - alpha=0.3, - color=colors[1], - ) - ax.fill_between( - valid_data["epoch"], - valid_data["loss"]["mean"] - valid_data["loss"]["std"], - valid_data["loss"]["mean"] + valid_data["loss"]["std"], - alpha=0.3, - color=colors[0], - ) - - if start_swa is not None: - ax.axvline( - start_swa, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - - ax.set_xlabel("Epoch") - ax.set_ylabel("Loss") - ax.legend(loc="upper right", fontsize=4) - ax.grid(True, linestyle="--", alpha=0.5) - - # ---- Plot selected keys ---- - ax = axes[1] - twin_axes = [] - for i, key in enumerate(keys.split(",")): - color = colors[(i + 3)] - label = labels.get(key, key) - - if i == 0: - main_ax = ax - else: - main_ax = ax.twinx() - main_ax.spines.right.set_position(("outward", 40 * (i - 1))) - twin_axes.append(main_ax) - - main_ax.plot( - valid_data["epoch"], - valid_data[key]["mean"] * 1e3, - color=color, - label=label, - linewidth=1, - ) - - if error_bars: - main_ax.fill_between( - valid_data["epoch"], - (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3, - (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, - alpha=0.3, - color=color, - ) - - main_ax.set_ylabel(label, color=color) - main_ax.tick_params(axis="y", colors=color) - - if start_swa is not None: - ax.axvline( - start_swa, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - - ax.set_xlabel("Epoch") - ax.set_xlim(left=min_epoch) - ax.grid(True, linestyle="--", alpha=0.5) - - fig.savefig( - f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight" - ) - plt.close(fig) - - -def get_paths(path: str) -> List[str]: - if os.path.isfile(path): - return [path] - paths = glob.glob(os.path.join(path, "*_train.txt")) - - if len(paths) == 0: - raise RuntimeError(f"Cannot find results in '{path}'") - - return paths - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - data = pd.DataFrame( - results - for path in get_paths(args.path) - for results in parse_training_results(path) - ) - - for name, group in data.groupby("name"): - plot( - group, - min_epoch=args.min_epoch, - output_path=name, - output_format=args.output_format, - linear=args.linear, - start_swa=args.start_swa, - error_bars=args.error_bars, - keys=args.keys, - heads=args.heads, - ) - - -if __name__ == "__main__": - main() +import argparse +import dataclasses +import glob +import json +import os +import re +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd + +plt.rcParams.update({"font.size": 8}) +plt.style.use("seaborn-v0_8-paper") + + +colors = [ + "#1f77b4", # muted blue + "#d62728", # brick red + "#ff7f0e", # safety orange + "#2ca02c", # cooked asparagus green + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#7f7f7f", # middle gray + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal +] + + +@dataclasses.dataclass +class RunInfo: + name: str + seed: int + + +name_re = re.compile(r"(?P.+)_run-(?P\d+)_train.txt") + + +def parse_path(path: str) -> RunInfo: + match = name_re.match(os.path.basename(path)) + if not match: + raise RuntimeError(f"Cannot parse {path}") + + return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) + + +def parse_training_results(path: str) -> List[dict]: + run_info = parse_path(path) + results = [] + with open(path, mode="r", 
encoding="utf-8") as f: + for line in f: + d = json.loads(line) + d["name"] = run_info.name + d["seed"] = run_info.seed + results.append(d) + + return results + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--path", help="Path to results file (.txt) or directory.", required=True + ) + parser.add_argument( + "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False + ) + parser.add_argument( + "--start_stage_two", + "--start_swa", + help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.", + default=None, + type=int, + required=False, + dest="start_swa", + ) + parser.add_argument( + "--linear", + help="Whether to plot linear instead of log scales.", + default=False, + required=False, + action="store_true", + ) + parser.add_argument( + "--error_bars", + help="Whether to plot standard deviations.", + default=False, + required=False, + action="store_true", + ) + parser.add_argument( + "--keys", + help="Comma-separated list of keys to plot.", + default="rmse_e,rmse_f", + type=str, + required=False, + ) + + parser.add_argument( + "--output_format", + help="What file type to save plot as", + default="png", + type=str, + required=False, + ) + + parser.add_argument( + "--heads", + help="Comma-separated name of the heads used for multihead training", + default=None, + type=str, + required=False, + ) + + return parser.parse_args() + + +def plot( + data: pd.DataFrame, + min_epoch: int, + output_path: str, + output_format: str, + linear: bool, + start_swa: int, + error_bars: bool, + keys: str, + heads: str, +) -> None: + """ + Plots train,validation loss and errors as a function of epoch. + min_epoch: minimum epoch to plot. + output_path: path to save the plot. + output_format: format to save the plot. + start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins. + error_bars: whether to plot standard deviation of loss. + linear: whether to plot in linear scale or logscale (default). + keys: Values to plot. + heads: Heads used for multihead training. 
+ """ + + labels = { + "mae_e": "MAE E [meV]", + "mae_e_per_atom": "MAE E/atom [meV]", + "rmse_e": "RMSE E [meV]", + "rmse_e_per_atom": "RMSE E/atom [meV]", + "q95_e": "Q95 E [meV]", + "mae_f": "MAE F [meV / A]", + "rel_mae_f": "Relative MAE F [meV / A]", + "rmse_f": "RMSE F [meV / A]", + "rel_rmse_f": "Relative RMSE F [meV / A]", + "q95_f": "Q95 F [meV / A]", + "mae_stress": "MAE Stress", + "rmse_stress": "RMSE Stress [meV / A^3]", + "rmse_virials_per_atom": " RMSE virials/atom [meV]", + "mae_virials": "MAE Virials [meV]", + "rmse_mu_per_atom": "RMSE MU/atom [mDebye]", + } + + data = data[data["epoch"] > min_epoch] + if heads is None: + data = ( + data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index() + ) + + valid_data = data[data["mode"] == "eval"] + valid_data_dict = {"default": valid_data} + train_data = data[data["mode"] == "opt"] + else: + heads = heads.split(",") + # Separate eval and opt data + valid_data = ( + data[data["mode"] == "eval"] + .groupby(["name", "mode", "epoch", "head"]) + .agg(["mean", "std"]) + .reset_index() + ) + train_data = ( + data[data["mode"] == "opt"] + .groupby(["name", "mode", "epoch"]) + .agg(["mean", "std"]) + .reset_index() + ) + valid_data_dict = { + head: valid_data[valid_data["head"] == head] for head in heads + } + + for head, valid_data in valid_data_dict.items(): + fig, axes = plt.subplots( + nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True + ) + + # ---- Plot loss ---- + ax = axes[0] + ax.plot( + train_data["epoch"], + train_data["loss"]["mean"], + color=colors[1], + linewidth=1, + ) + ax.set_ylabel("Training Loss", color=colors[1]) + ax.set_yscale("log") + + ax2 = ax.twinx() + ax2.plot( + valid_data["epoch"], + valid_data["loss"]["mean"], + color=colors[0], + linewidth=1, + ) + ax2.set_ylabel("Validation Loss", color=colors[0]) + + if not linear: + ax.set_yscale("log") + ax2.set_yscale("log") + + if error_bars: + ax.fill_between( + train_data["epoch"], + train_data["loss"]["mean"] - train_data["loss"]["std"], + train_data["loss"]["mean"] + train_data["loss"]["std"], + alpha=0.3, + color=colors[1], + ) + ax.fill_between( + valid_data["epoch"], + valid_data["loss"]["mean"] - valid_data["loss"]["std"], + valid_data["loss"]["mean"] + valid_data["loss"]["std"], + alpha=0.3, + color=colors[0], + ) + + if start_swa is not None: + ax.axvline( + start_swa, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + + ax.set_xlabel("Epoch") + ax.set_ylabel("Loss") + ax.legend(loc="upper right", fontsize=4) + ax.grid(True, linestyle="--", alpha=0.5) + + # ---- Plot selected keys ---- + ax = axes[1] + twin_axes = [] + for i, key in enumerate(keys.split(",")): + color = colors[(i + 3)] + label = labels.get(key, key) + + if i == 0: + main_ax = ax + else: + main_ax = ax.twinx() + main_ax.spines.right.set_position(("outward", 40 * (i - 1))) + twin_axes.append(main_ax) + + main_ax.plot( + valid_data["epoch"], + valid_data[key]["mean"] * 1e3, + color=color, + label=label, + linewidth=1, + ) + + if error_bars: + main_ax.fill_between( + valid_data["epoch"], + (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3, + (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, + alpha=0.3, + color=color, + ) + + main_ax.set_ylabel(label, color=color) + main_ax.tick_params(axis="y", colors=color) + + if start_swa is not None: + ax.axvline( + start_swa, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + + ax.set_xlabel("Epoch") + 
ax.set_xlim(left=min_epoch) + ax.grid(True, linestyle="--", alpha=0.5) + + fig.savefig( + f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight" + ) + plt.close(fig) + + +def get_paths(path: str) -> List[str]: + if os.path.isfile(path): + return [path] + paths = glob.glob(os.path.join(path, "*_train.txt")) + + if len(paths) == 0: + raise RuntimeError(f"Cannot find results in '{path}'") + + return paths + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + data = pd.DataFrame( + results + for path in get_paths(args.path) + for results in parse_training_results(path) + ) + + for name, group in data.groupby("name"): + plot( + group, + min_epoch=args.min_epoch, + output_path=name, + output_format=args.output_format, + linear=args.linear, + start_swa=args.start_swa, + error_bars=args.error_bars, + keys=args.keys, + heads=args.heads, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py b/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py index 11ba17b98041d9b587a691d00882d3fd0bb07af2..64f1740221b332ff22c258e6bfadc97089f4bf03 100644 --- a/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py +++ b/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py @@ -1,300 +1,300 @@ -# This file loads an xyz dataset and prepares -# new hdf5 file that is ready for training with on-the-fly dataloading - -import argparse -import ast -import json -import logging -import multiprocessing as mp -import os -import random -from functools import partial -from glob import glob -from typing import List, Tuple - -import h5py -import numpy as np -import tqdm - -from mace import data, tools -from mace.data import KeySpecification, update_keyspec_from_kwargs -from mace.data.utils import save_configurations_as_HDF5 -from mace.modules import compute_statistics -from mace.tools import torch_geometric -from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz -from mace.tools.utils import AtomicNumberTable - - -def compute_stats_target( - file: str, - z_table: AtomicNumberTable, - r_max: float, - atomic_energies: Tuple, - batch_size: int, -): - train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max) - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - ) - - avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies) - output = [avg_num_neighbors, mean, std] - return output - - -def pool_compute_stats(inputs: List): - path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs - - with mp.Pool(processes=num_process) as pool: - re = [ - pool.apply_async( - compute_stats_target, - args=( - file, - z_table, - r_max, - atomic_energies, - batch_size, - ), - ) - for file in glob(path_to_files + "/*") - ] - - pool.close() - pool.join() - - results = [r.get() for r in tqdm.tqdm(re)] - - if not results: - raise ValueError( - "No results were computed. Check if the input files exist and are readable." 
- ) - - # Separate avg_num_neighbors, mean, and std - avg_num_neighbors = np.mean([r[0] for r in results]) - means = np.array([r[1] for r in results]) - stds = np.array([r[2] for r in results]) - - # Compute averages - mean = np.mean(means, axis=0).item() - std = np.mean(stds, axis=0).item() - - return avg_num_neighbors, mean, std - - -def split_array(a: np.ndarray, max_size: int): - drop_last = False - if len(a) % 2 == 1: - a = np.append(a, a[-1]) - drop_last = True - factors = get_prime_factors(len(a)) - max_factor = 1 - for i in range(1, len(factors) + 1): - for j in range(0, len(factors) - i + 1): - if np.prod(factors[j : j + i]) <= max_size: - test = np.prod(factors[j : j + i]) - max_factor = max(test, max_factor) - return np.array_split(a, max_factor), drop_last - - -def get_prime_factors(n: int): - factors = [] - for i in range(2, n + 1): - while n % i == 0: - factors.append(i) - n = n / i - return factors - - -# Define Task for Multiprocessiing -def multi_train_hdf5(process, args, split_train, drop_last): - with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_train[process], process, f) - - -def multi_valid_hdf5(process, args, split_valid, drop_last): - with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_valid[process], process, f) - - -def multi_test_hdf5(process, name, args, split_test, drop_last): - with h5py.File( - args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" - ) as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_test[process], process, f) - - -def main() -> None: - """ - This script loads an xyz dataset and prepares - new hdf5 file that is ready for training with on-the-fly dataloading - """ - args = tools.build_preprocess_arg_parser().parse_args() - run(args) - - -def run(args: argparse.Namespace): - """ - This script loads an xyz dataset and prepares - new hdf5 file that is ready for training with on-the-fly dataloading - """ - - # currently support only command line property_key syntax - args.key_specification = KeySpecification() - update_keyspec_from_kwargs(args.key_specification, vars(args)) - - # Setup - tools.set_seeds(args.seed) - random.seed(args.seed) - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)-8s %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler()], - ) - - try: - config_type_weights = ast.literal_eval(args.config_type_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} - - folders = ["train", "val", "test"] - for sub_dir in folders: - if not os.path.exists(args.h5_prefix + sub_dir): - os.makedirs(args.h5_prefix + sub_dir) - - # Data preparation - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - key_specification=args.key_specification, - head_name=None, - ) - - # Atomic number table - # yapf: disable - if args.atomic_numbers is None: - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config 
in configs - for z in config.atomic_numbers - ) - else: - logging.info("Using atomic numbers from command line argument") - zs_list = ast.literal_eval(args.atomic_numbers) - assert isinstance(zs_list, list) - z_table = tools.get_atomic_number_table_from_zs(zs_list) - - logging.info("Preparing training set") - if args.shuffle: - random.shuffle(collections.train) - - # split collections.train into batches and save them to hdf5 - split_train = np.array_split(collections.train,args.num_process) - drop_last = False - if len(collections.train) % 2 == 1: - drop_last = True - - multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last) - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_train_hdf5_, args=[i]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - if args.compute_statistics: - logging.info("Computing statistics") - if len(atomic_energies_dict) == 0: - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) - - # Remove atomic energies if element not in z_table - removed_atomic_energies = {} - for z in list(atomic_energies_dict): - if z not in z_table.zs: - removed_atomic_energies[z] = atomic_energies_dict.pop(z) - if len(removed_atomic_energies) > 0: - logging.warning("Atomic energies for elements not present in the atomic number table have been removed.") - logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}") - logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.") - - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) - logging.info(f"Atomic Energies: {atomic_energies.tolist()}") - _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] - avg_num_neighbors, mean, std=pool_compute_stats(_inputs) - logging.info(f"Average number of neighbors: {avg_num_neighbors}") - logging.info(f"Mean: {mean}") - logging.info(f"Standard deviation: {std}") - - # save the statistics as a json - statistics = { - "atomic_energies": str(atomic_energies_dict), - "avg_num_neighbors": avg_num_neighbors, - "mean": mean, - "std": std, - "atomic_numbers": str([int(z) for z in z_table.zs]), - "r_max": args.r_max, - } - - with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 - json.dump(statistics, f) - - logging.info("Preparing validation set") - if args.shuffle: - random.shuffle(collections.valid) - split_valid = np.array_split(collections.valid, args.num_process) - drop_last = False - if len(collections.valid) % 2 == 1: - drop_last = True - - multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last) - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_valid_hdf5_, args=[i]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - if args.test_file is not None: - logging.info("Preparing test sets") - for name, subset in collections.tests: - drop_last = False - if len(subset) % 2 == 1: - drop_last = True - split_test = np.array_split(subset, args.num_process) - multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last) - - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_test_hdf5_, args=[i, name]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - -if __name__ == "__main__": - main() +# 
This file loads an xyz dataset and prepares a
+# new HDF5 file that is ready for training with on-the-fly data loading
+
+import argparse
+import ast
+import json
+import logging
+import multiprocessing as mp
+import os
+import random
+from functools import partial
+from glob import glob
+from typing import List, Tuple
+
+import h5py
+import numpy as np
+import tqdm
+
+from mace import data, tools
+from mace.data import KeySpecification, update_keyspec_from_kwargs
+from mace.data.utils import save_configurations_as_HDF5
+from mace.modules import compute_statistics
+from mace.tools import torch_geometric
+from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz
+from mace.tools.utils import AtomicNumberTable
+
+
+def compute_stats_target(
+    file: str,
+    z_table: AtomicNumberTable,
+    r_max: float,
+    atomic_energies: Tuple,
+    batch_size: int,
+):
+    train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max)
+    train_loader = torch_geometric.dataloader.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        drop_last=False,
+    )
+
+    avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies)
+    output = [avg_num_neighbors, mean, std]
+    return output
+
+
+def pool_compute_stats(inputs: List):
+    path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs
+
+    with mp.Pool(processes=num_process) as pool:
+        re = [
+            pool.apply_async(
+                compute_stats_target,
+                args=(
+                    file,
+                    z_table,
+                    r_max,
+                    atomic_energies,
+                    batch_size,
+                ),
+            )
+            for file in glob(path_to_files + "/*")
+        ]
+
+        pool.close()
+        pool.join()
+
+    results = [r.get() for r in tqdm.tqdm(re)]
+
+    if not results:
+        raise ValueError(
+            "No results were computed. Check if the input files exist and are readable."
+        )
+
+    # Separate avg_num_neighbors, mean, and std
+    avg_num_neighbors = np.mean([r[0] for r in results])
+    means = np.array([r[1] for r in results])
+    stds = np.array([r[2] for r in results])
+
+    # Compute averages
+    mean = np.mean(means, axis=0).item()
+    std = np.mean(stds, axis=0).item()
+
+    return avg_num_neighbors, mean, std
+
+
+# Split `a` into equal parts: the part count is the largest product of contiguous
+# prime factors of len(a) that is <= max_size; odd-length inputs are padded with
+# the last element, which is flagged via drop_last.
+def split_array(a: np.ndarray, max_size: int):
+    drop_last = False
+    if len(a) % 2 == 1:
+        a = np.append(a, a[-1])
+        drop_last = True
+    factors = get_prime_factors(len(a))
+    max_factor = 1
+    for i in range(1, len(factors) + 1):
+        for j in range(0, len(factors) - i + 1):
+            if np.prod(factors[j : j + i]) <= max_size:
+                test = np.prod(factors[j : j + i])
+                max_factor = max(test, max_factor)
+    return np.array_split(a, max_factor), drop_last
+
+
+def get_prime_factors(n: int):
+    factors = []
+    for i in range(2, n + 1):
+        while n % i == 0:
+            factors.append(i)
+            n = n // i
+    return factors
+
+
+# Define tasks for multiprocessing
+def multi_train_hdf5(process, args, split_train, drop_last):
+    with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_train[process], process, f)
+
+
+def multi_valid_hdf5(process, args, split_valid, drop_last):
+    with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_valid[process], process, f)
+
+
+def multi_test_hdf5(process, name, args, split_test, drop_last):
+    with h5py.File(
+        args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
+    ) as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_test[process], process, f)
+
+
+def main() -> None:
+    """
+    This script loads an xyz dataset and prepares a
+    new HDF5 file that is ready for training with on-the-fly data loading
+    """
+    args = tools.build_preprocess_arg_parser().parse_args()
+    run(args)
+
+
+def run(args: argparse.Namespace):
+    """
+    This script loads an xyz dataset and prepares a
+    new HDF5 file that is ready for training with on-the-fly data loading
+    """
+
+    # currently support only command line property_key syntax
+    args.key_specification = KeySpecification()
+    update_keyspec_from_kwargs(args.key_specification, vars(args))
+
+    # Setup
+    tools.set_seeds(args.seed)
+    random.seed(args.seed)
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)-8s %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[logging.StreamHandler()],
+    )
+
+    try:
+        config_type_weights = ast.literal_eval(args.config_type_weights)
+        assert isinstance(config_type_weights, dict)
+    except Exception as e:  # pylint: disable=W0703
+        logging.warning(
+            f"Config type weights not specified correctly ({e}), using Default"
+        )
+        config_type_weights = {"Default": 1.0}
+
+    folders = ["train", "val", "test"]
+    for sub_dir in folders:
+        if not os.path.exists(args.h5_prefix + sub_dir):
+            os.makedirs(args.h5_prefix + sub_dir)
+
+    # Data preparation
+    collections, atomic_energies_dict = get_dataset_from_xyz(
+        work_dir=args.work_dir,
+        train_path=args.train_file,
+        valid_path=args.valid_file,
+        valid_fraction=args.valid_fraction,
+        config_type_weights=config_type_weights,
+        test_path=args.test_file,
+        seed=args.seed,
+        key_specification=args.key_specification,
+        head_name=None,
+    )
+
+    # Atomic number table
+    # yapf: disable
+    if args.atomic_numbers is None:
+        z_table = tools.get_atomic_number_table_from_zs(
+            z
+            for configs in (collections.train, collections.valid)
+            for config in configs
+            for z in config.atomic_numbers
+        )
+    else:
+        logging.info("Using atomic numbers from command line argument")
+        zs_list = ast.literal_eval(args.atomic_numbers)
+        assert isinstance(zs_list, list)
+        z_table = tools.get_atomic_number_table_from_zs(zs_list)
+
+    logging.info("Preparing training set")
+    if args.shuffle:
+        random.shuffle(collections.train)
+
+    # split collections.train into batches and save them to hdf5
+    split_train = np.array_split(collections.train, args.num_process)
+    drop_last = False
+    if len(collections.train) % 2 == 1:
+        drop_last = True
+
+    multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
+    processes = []
+    for i in range(args.num_process):
+        p = mp.Process(target=multi_train_hdf5_, args=[i])
+        p.start()
+        processes.append(p)
+
+    for i in processes:
+        i.join()
+
+    if args.compute_statistics:
+        logging.info("Computing statistics")
+        if len(atomic_energies_dict) == 0:
+            atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
+
+        # Remove atomic energies if element not in z_table
+        removed_atomic_energies = {}
+        for z in list(atomic_energies_dict):
+            if z not in z_table.zs:
+                removed_atomic_energies[z] = atomic_energies_dict.pop(z)
+        if len(removed_atomic_energies) > 0:
+            logging.warning("Atomic energies for elements not present in the atomic number table have been removed.")
+            logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}")
+            logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.")
+
+        atomic_energies: np.ndarray = np.array(
+            [atomic_energies_dict[z] for z in z_table.zs]
+        )
+        logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
+        _inputs = [args.h5_prefix + "train", z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
+        avg_num_neighbors, mean, std = pool_compute_stats(_inputs)
+        logging.info(f"Average number of neighbors: {avg_num_neighbors}")
+        logging.info(f"Mean: {mean}")
+        logging.info(f"Standard deviation: {std}")
+
+        # save the statistics as a json
+        statistics = {
+            "atomic_energies": str(atomic_energies_dict),
+            "avg_num_neighbors": avg_num_neighbors,
+            "mean": mean,
+            "std": std,
+            "atomic_numbers": str([int(z) for z in z_table.zs]),
+            "r_max": args.r_max,
+        }
+
+        with open(args.h5_prefix + "statistics.json", "w") as f:  # pylint: disable=W1514
+            json.dump(statistics, f)
+
+    logging.info("Preparing validation set")
+    if args.shuffle:
+        random.shuffle(collections.valid)
+    split_valid = np.array_split(collections.valid, args.num_process)
+    drop_last = False
+    if len(collections.valid) % 2 == 1:
+        drop_last = True
+
+    multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
+    processes = []
+    for i in range(args.num_process):
+        p = mp.Process(target=multi_valid_hdf5_, args=[i])
+        p.start()
+        processes.append(p)
+
+    for i in processes:
+        i.join()
+
+    if args.test_file is not None:
+        logging.info("Preparing test sets")
+        for name, subset in collections.tests:
+            drop_last = False
+            if len(subset) % 2 == 1:
+                drop_last = True
+            split_test = np.array_split(subset, args.num_process)
+            multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)
+
+            processes = []
+            for i in range(args.num_process):
+                p = mp.Process(target=multi_test_hdf5_, args=[i, name])
+                p.start()
+                processes.append(p)
+
+            for i in processes:
+                i.join()
+
+
+if __name__ == "__main__":
+    main()
diff
--git a/mace-bench/3rdparty/mace/mace/cli/run_train.py b/mace-bench/3rdparty/mace/mace/cli/run_train.py index 00a23fbb644c51c003700f3f850bfc7d4ea7308a..977ec45157174c8cd6a465fb5aafed43756540d3 100644 --- a/mace-bench/3rdparty/mace/mace/cli/run_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/run_train.py @@ -1,1007 +1,1007 @@ -########################################################################################### -# Training script for MACE -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import ast -import glob -import json -import logging -import os -from copy import deepcopy -from pathlib import Path -from typing import List, Optional - -import torch.distributed -import torch.nn.functional -from e3nn.util import jit -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import LBFGS -from torch.utils.data import ConcatDataset -from torch_ema import ExponentialMovingAverage - -import mace -from mace import data, tools -from mace.calculators.foundations_models import mace_mp, mace_off -from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.cli.visualise_train import TrainingPlotter -from mace.data import KeySpecification, update_keyspec_from_kwargs -from mace.tools import torch_geometric -from mace.tools.model_script_utils import configure_model -from mace.tools.multihead_tools import ( - HeadConfig, - assemble_mp_data, - dict_head_to_dataclass, - prepare_default_head, - prepare_pt_head, -) -from mace.tools.run_train_utils import ( - combine_datasets, - load_dataset_for_path, - normalize_file_paths, -) -from mace.tools.scripts_utils import ( - LRScheduler, - SubsetCollection, - check_path_ase_read, - convert_to_json_format, - dict_to_array, - extract_config_mace_model, - get_atomic_energies, - get_avg_num_neighbors, - get_config_type_weights, - get_dataset_from_xyz, - get_files_with_suffix, - get_loss_fn, - get_optimizer, - get_params_options, - get_swa, - print_git_commit, - remove_pt_head, - setup_wandb, -) -from mace.tools.slurm_distributed import DistributedEnvironment -from mace.tools.tables_utils import create_error_table -from mace.tools.utils import AtomicNumberTable - - -def main() -> None: - """ - This script runs the training/fine tuning for mace - """ - args = tools.build_default_arg_parser().parse_args() - run(args) - - -def run(args) -> None: - """ - This script runs the training/fine tuning for mace - """ - tag = tools.get_tag(name=args.name, seed=args.seed) - args, input_log_messages = tools.check_args(args) - - # default keyspec to update using heads dictionary - args.key_specification = KeySpecification() - update_keyspec_from_kwargs(args.key_specification, vars(args)) - - if args.device == "xpu": - try: - import intel_extension_for_pytorch as ipex - except ImportError as e: - raise ImportError( - "Error: Intel extension for PyTorch not found, but XPU device was specified" - ) from e - if args.distributed: - try: - distr_env = DistributedEnvironment() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to initialize distributed environment: {e}") - return - world_size = distr_env.world_size - local_rank = distr_env.local_rank - rank = distr_env.rank - if rank == 0: - print(distr_env) - torch.distributed.init_process_group(backend="nccl") - else: - rank = int(0) - - # Setup - 
tools.set_seeds(args.seed) - tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) - logging.info("===========VERIFYING SETTINGS===========") - for message, loglevel in input_log_messages: - logging.log(level=loglevel, msg=message) - - if args.distributed: - torch.cuda.set_device(local_rank) - logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") - logging.info(f"Processes: {world_size}") - - try: - logging.info(f"MACE version: {mace.__version__}") - except AttributeError: - logging.info("Cannot find MACE version, please install MACE via pip") - logging.debug(f"Configuration: {args}") - - tools.set_default_dtype(args.default_dtype) - device = tools.init_device(args.device) - commit = print_git_commit() - model_foundation: Optional[torch.nn.Module] = None - foundation_model_avg_num_neighbors = 0 - if args.foundation_model is not None: - if args.foundation_model in ["small", "medium", "large"]: - logging.info( - f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." - ) - calc = mace_mp( - model=args.foundation_model, - device=args.device, - default_dtype=args.default_dtype, - ) - model_foundation = calc.models[0] - elif args.foundation_model in ["small_off", "medium_off", "large_off"]: - model_type = args.foundation_model.split("_")[0] - logging.info( - f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." - ) - calc = mace_off( - model=model_type, - device=args.device, - default_dtype=args.default_dtype, - ) - model_foundation = calc.models[0] - else: - model_foundation = torch.load( - args.foundation_model, map_location=args.device - ) - logging.info( - f"Using foundation model {args.foundation_model} as initial checkpoint." - ) - args.r_max = model_foundation.r_max.item() - foundation_model_avg_num_neighbors = model_foundation.interactions[ - 0 - ].avg_num_neighbors - if ( - args.foundation_model not in ["small", "medium", "large"] - and args.pt_train_file is None - ): - logging.warning( - "Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file." - ) - args.multiheads_finetuning = False - if args.multiheads_finetuning: - assert ( - args.E0s != "average" - ), "average atomic energies cannot be used for multiheads finetuning" - # check that the foundation model has a single head, if not, use the first head - if not args.force_mh_ft_lr: - logging.info( - "Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." - ) - args.lr = 0.0001 - args.ema = True - args.ema_decay = 0.99999 - logging.info( - "Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True" - ) - if hasattr(model_foundation, "heads"): - if len(model_foundation.heads) > 1: - logging.warning( - "Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head." 
- ) - model_foundation = remove_pt_head( - model_foundation, args.foundation_head - ) - else: - args.multiheads_finetuning = False - - if args.heads is not None: - args.heads = ast.literal_eval(args.heads) - for _, head_dict in args.heads.items(): - # priority is global args < head property_key values < head info_keys+arrays_keys - head_keyspec = deepcopy(args.key_specification) - update_keyspec_from_kwargs(head_keyspec, head_dict) - head_keyspec.update( - info_keys=head_dict.get("info_keys", {}), - arrays_keys=head_dict.get("arrays_keys", {}), - ) - head_dict["key_specification"] = head_keyspec - else: - args.heads = prepare_default_head(args) - if args.multiheads_finetuning: - pt_keyspec = ( - args.heads["pt_head"]["key_specification"] - if "pt_head" in args.heads - else deepcopy(args.key_specification) - ) - args.heads["pt_head"] = prepare_pt_head( - args, pt_keyspec, foundation_model_avg_num_neighbors - ) - - logging.info("===========LOADING INPUT DATA===========") - heads = list(args.heads.keys()) - logging.info(f"Using heads: {heads}") - logging.info("Using the key specifications to parse data:") - for name, head_dict in args.heads.items(): - head_keyspec = head_dict["key_specification"] - logging.info(f"{name}: {head_keyspec}") - - head_configs: List[HeadConfig] = [] - for head, head_args in args.heads.items(): - logging.info(f"============= Processing head {head} ===========") - head_config = dict_head_to_dataclass(head_args, head, args) - - # Handle train_file and valid_file - normalize to lists - if hasattr(head_config, "train_file") and head_config.train_file is not None: - head_config.train_file = normalize_file_paths(head_config.train_file) - if hasattr(head_config, "valid_file") and head_config.valid_file is not None: - head_config.valid_file = normalize_file_paths(head_config.valid_file) - if hasattr(head_config, "test_file") and head_config.test_file is not None: - head_config.test_file = normalize_file_paths(head_config.test_file) - - if ( - head_config.statistics_file is not None - and head_config.head_name != "pt_head" - ): - with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 - statistics = json.load(f) - logging.info("Using statistics json file") - head_config.atomic_numbers = statistics["atomic_numbers"] - head_config.mean = statistics["mean"] - head_config.std = statistics["std"] - head_config.avg_num_neighbors = statistics["avg_num_neighbors"] - head_config.compute_avg_num_neighbors = False - if isinstance(statistics["atomic_energies"], str) and statistics[ - "atomic_energies" - ].endswith(".json"): - with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: - atomic_energies = json.load(f) - head_config.E0s = atomic_energies - head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) - else: - head_config.E0s = statistics["atomic_energies"] - head_config.atomic_energies_dict = ast.literal_eval( - statistics["atomic_energies"] - ) - if head_config.train_file == ["mp"]: - assert ( - head_config.head_name == "pt_head" - ), "Only pt_head should use mp as train_file" - logging.info( - "Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script." 
- ) - collections = assemble_mp_data(args, head_config, tag) - head_config.collections = collections - elif any(check_path_ase_read(f) for f in head_config.train_file): - train_files_ase_list = [ - f for f in head_config.train_file if check_path_ase_read(f) - ] - valid_files_ase_list = None - test_files_ase_list = None - if head_config.valid_file: - valid_files_ase_list = [ - f for f in head_config.valid_file if check_path_ase_read(f) - ] - if head_config.test_file: - test_files_ase_list = [ - f for f in head_config.test_file if check_path_ase_read(f) - ] - config_type_weights = get_config_type_weights( - head_config.config_type_weights - ) - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=train_files_ase_list, - valid_path=valid_files_ase_list, - valid_fraction=head_config.valid_fraction, - config_type_weights=config_type_weights, - test_path=test_files_ase_list, - seed=args.seed, - key_specification=head_config.key_specification, - head_name=head_config.head_name, - keep_isolated_atoms=head_config.keep_isolated_atoms, - ) - head_config.collections = SubsetCollection( - train=collections.train, - valid=collections.valid, - tests=collections.tests, - ) - head_config.atomic_energies_dict = atomic_energies_dict - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," - ) - head_configs.append(head_config) - - if all( - check_path_ase_read(head_config.train_file[0]) for head_config in head_configs - ): - size_collections_train = sum( - len(head_config.collections.train) for head_config in head_configs - ) - size_collections_valid = sum( - len(head_config.collections.valid) for head_config in head_configs - ) - if size_collections_train < args.batch_size: - logging.error( - f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" - ) - if size_collections_valid < args.valid_batch_size: - logging.warning( - f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" - ) - - if args.multiheads_finetuning: - logging.info( - "==================Using multiheads finetuning mode==================" - ) - args.loss = "universal" - - all_ase_readable = all( - all(check_path_ase_read(f) for f in head_config.train_file) - for head_config in head_configs - ) - head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs) - head_config_pt = next(head_config_pt, None) - assert head_config_pt is not None, "Pretraining head not found" - if all_ase_readable: - ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) - if ratio_pt_ft < 0.1: - logging.warning( - f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " - f"increasing the number of configurations in the fine-tuning heads by {int(0.1 / ratio_pt_ft)}" - ) - for head_config in head_configs: - if head_config.head_name == "pt_head": - continue - head_config.collections.train += ( - head_config.collections.train * int(0.1 / ratio_pt_ft) - ) - logging.info( - f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" - ) - else: - logging.debug( - "Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check" - ) - - # Atomic number table - # yapf: 
disable - for head_config in head_configs: - if head_config.atomic_numbers is None: - assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input" - z_table_head = tools.get_atomic_number_table_from_zs( - z - for configs in (head_config.collections.train, head_config.collections.valid) - for config in configs - for z in config.atomic_numbers - ) - head_config.atomic_numbers = z_table_head.zs - head_config.z_table = z_table_head - else: - if head_config.statistics_file is None: - logging.info("Using atomic numbers from command line argument") - else: - logging.info("Using atomic numbers from statistics file") - zs_list = ast.literal_eval(head_config.atomic_numbers) - assert isinstance(zs_list, list) - z_table_head = tools.AtomicNumberTable(zs_list) - head_config.atomic_numbers = zs_list - head_config.z_table = z_table_head - # yapf: enable - all_atomic_numbers = set() - for head_config in head_configs: - all_atomic_numbers.update(head_config.atomic_numbers) - z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) - if args.foundation_model_elements and model_foundation: - z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) - logging.info(f"Atomic Numbers used: {z_table.zs}") - - # Atomic energies - atomic_energies_dict = {} - for head_config in head_configs: - if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: - assert head_config.E0s is not None, "Atomic energies must be provided" - if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": - atomic_energies_dict[head_config.head_name] = get_atomic_energies( - head_config.E0s, head_config.collections.train, head_config.z_table - ) - elif head_config.E0s.lower() == "foundation": - assert args.foundation_model is not None - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies - if foundation_atomic_energies.ndim > 1: - foundation_atomic_energies = foundation_atomic_energies.squeeze() - if foundation_atomic_energies.ndim == 2: - foundation_atomic_energies = foundation_atomic_energies[0] - logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") - atomic_energies_dict[head_config.head_name] = { - z: foundation_atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - for z in z_table.zs - } - else: - atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) - else: - atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict - - # Atomic energies for multiheads finetuning - if args.multiheads_finetuning: - assert ( - model_foundation is not None - ), "Model foundation must be provided for multiheads finetuning" - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies - if foundation_atomic_energies.ndim > 1: - foundation_atomic_energies = foundation_atomic_energies.squeeze() - if foundation_atomic_energies.ndim == 2: - foundation_atomic_energies = foundation_atomic_energies[0] - logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") - atomic_energies_dict["pt_head"] = { - z: foundation_atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - 
for z in z_table.zs - } - heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0) - # Padding atomic energies if keeping all elements of the foundation model - if args.foundation_model_elements and model_foundation: - atomic_energies_dict_padded = {} - for head_name, head_energies in atomic_energies_dict.items(): - energy_head_padded = {} - for z in z_table.zs: - energy_head_padded[z] = head_energies.get(z, 0.0) - atomic_energies_dict_padded[head_name] = energy_head_padded - atomic_energies_dict = atomic_energies_dict_padded - - if args.model == "AtomicDipolesMACE": - atomic_energies = None - dipole_only = True - args.compute_dipole = True - args.compute_energy = False - args.compute_forces = False - args.compute_virials = False - args.compute_stress = False - else: - dipole_only = False - if args.model == "EnergyDipolesMACE": - args.compute_dipole = True - args.compute_energy = True - args.compute_forces = True - args.compute_virials = False - args.compute_stress = False - else: - args.compute_energy = True - args.compute_dipole = False - # atomic_energies: np.ndarray = np.array( - # [atomic_energies_dict[z] for z in z_table.zs] - # ) - atomic_energies = dict_to_array(atomic_energies_dict, heads) - for head_config in head_configs: - try: - logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") - except KeyError as e: - raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e - - # Load datasets for each head, supporting multiple files per head - valid_sets = {head: [] for head in heads} - train_sets = {head: [] for head in heads} - - for head_config in head_configs: - train_datasets = [] - - logging.info(f"Processing datasets for head '{head_config.head_name}'") - ase_files = [f for f in head_config.train_file if check_path_ase_read(f)] - non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)] - - if ase_files: - dataset = load_dataset_for_path( - file_path=ase_files, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - collection=head_config.collections.train, - ) - train_datasets.append(dataset) - logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}") - - for file in non_ase_files: - dataset = load_dataset_for_path( - file_path=file, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - ) - train_datasets.append(dataset) - logging.debug(f"Successfully loaded dataset from non-ASE file: {file}") - - if not train_datasets: - raise ValueError(f"No valid training datasets found for head {head_config.head_name}") - - train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name) - - if head_config.valid_file: - valid_datasets = [] - - valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)] - valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)] - - if valid_ase_files: - valid_dataset = load_dataset_for_path( - file_path=valid_ase_files, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - collection=head_config.collections.valid, - ) - valid_datasets.append(valid_dataset) - logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}") - for valid_file in valid_non_ase_files: - valid_dataset = 
load_dataset_for_path( - file_path=valid_file, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - ) - valid_datasets.append(valid_dataset) - logging.debug(f"Successfully loaded validation dataset from {valid_file}") - - # Combine validation datasets - if valid_datasets: - valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid") - logging.info(f"Combined validation datasets for {head_config.head_name}") - - # If no valid file is provided but collection exist, use the validation set from the collection - if head_config.valid_file is None and head_config.collections.valid: - valid_sets[head_config.head_name] = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in head_config.collections.valid - ] - if not valid_sets[head_config.head_name]: - raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction") - - # Create data loader for this head - if isinstance(train_sets[head_config.head_name], list): - dataset_size = len(train_sets[head_config.head_name]) - else: - dataset_size = len(train_sets[head_config.head_name]) - logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}") - - train_loader_head = torch_geometric.dataloader.DataLoader( - dataset=train_sets[head_config.head_name], - batch_size=args.batch_size, - shuffle=True, - drop_last=(not args.lbfgs), - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - head_config.train_loader = train_loader_head - - # concatenate all the trainsets - train_set = ConcatDataset([train_sets[head] for head in heads]) - train_sampler, valid_sampler = None, None - if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=(not args.lbfgs), - seed=args.seed, - ) - valid_samplers = {} - for head, valid_set in valid_sets.items(): - valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - valid_samplers[head] = valid_sampler - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_set, - batch_size=args.batch_size, - sampler=train_sampler, - shuffle=(train_sampler is None), - drop_last=(train_sampler is None and not args.lbfgs), - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - valid_loaders = {heads[i]: None for i in range(len(heads))} - if not isinstance(valid_sets, dict): - valid_sets = {"Default": valid_sets} - for head, valid_set in valid_sets.items(): - valid_loaders[head] = torch_geometric.dataloader.DataLoader( - dataset=valid_set, - batch_size=args.valid_batch_size, - sampler=valid_samplers[head] if args.distributed else None, - shuffle=False, - drop_last=False, - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - - loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) - args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) - - # Model - model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs) - model.to(device) - - logging.debug(model) - logging.info(f"Total number of parameters: 
{tools.count_parameters(model)}") - logging.info("") - logging.info("===========OPTIMIZER INFORMATION===========") - logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") - logging.info(f"Batch size: {args.batch_size}") - if args.ema: - logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") - logging.info( - f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}" - ) - logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - logging.info(loss_fn) - - # Cueq - if args.enable_cueq: - logging.info("Converting model to CUEQ for accelerated training") - assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"] - model = run_e3nn_to_cueq(deepcopy(model), device=device) - # Optimizer - param_options = get_params_options(args, model) - optimizer: torch.optim.Optimizer - optimizer = get_optimizer(args, param_options) - if args.device == "xpu": - logging.info("Optimzing model and optimzier for XPU") - model, optimizer = ipex.optimize(model, optimizer=optimizer) - logger = tools.MetricsLogger( - directory=args.results_dir, tag=tag + "_train" - ) # pylint: disable=E1123 - - lr_scheduler = LRScheduler(optimizer, args) - - swa: Optional[tools.SWAContainer] = None - swas = [False] - if args.swa: - swa, swas = get_swa(args, model, optimizer, swas, dipole_only) - - checkpoint_handler = tools.CheckpointHandler( - directory=args.checkpoints_dir, - tag=tag, - keep=args.keep_checkpoints, - swa_start=args.start_swa, - ) - - start_epoch = 0 - restart_lbfgs = False - opt_start_epoch = None - if args.restart_latest: - try: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=True, - device=device, - ) - except Exception: # pylint: disable=W0703 - try: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - device=device, - ) - except Exception: # pylint: disable=W0703 - restart_lbfgs = True - if opt_start_epoch is not None: - start_epoch = opt_start_epoch - - ema: Optional[ExponentialMovingAverage] = None - if args.ema: - ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) - else: - for group in optimizer.param_groups: - group["lr"] = args.lr - - if args.lbfgs: - logging.info("Switching optimizer to LBFGS") - optimizer = LBFGS(model.parameters(), - history_size=200, - max_iter=20, - line_search_fn="strong_wolfe") - if restart_lbfgs: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - device=device, - ) - if opt_start_epoch is not None: - start_epoch = opt_start_epoch - - if args.wandb: - setup_wandb(args) - if args.distributed: - distributed_model = DDP(model, device_ids=[local_rank]) - else: - distributed_model = None - - - train_valid_data_loader = {} - for head_config in head_configs: - data_loader_name = "train_" + head_config.head_name - train_valid_data_loader[data_loader_name] = head_config.train_loader - for head, valid_loader in valid_loaders.items(): - data_load_name = "valid_" + head - train_valid_data_loader[data_load_name] = valid_loader - - if args.plot and args.plot_frequency > 0: - try: - plotter = TrainingPlotter( - results_dir=logger.path, - heads=heads, - table_type=args.error_table, - train_valid_data=train_valid_data_loader, - test_data={}, - output_args=output_args, - device=device, - plot_frequency=args.plot_frequency, - distributed=args.distributed, - 
swa_start=swa.start if swa else None - ) - except Exception as e: # pylint: disable=W0718 - logging.debug(f"Creating Plotter failed: {e}") - else: - plotter = None - - if args.dry_run: - logging.info("DRY RUN mode enabled. Stopping now.") - return - - - tools.train( - model=model, - loss_fn=loss_fn, - train_loader=train_loader, - valid_loaders=valid_loaders, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - checkpoint_handler=checkpoint_handler, - eval_interval=args.eval_interval, - start_epoch=start_epoch, - max_num_epochs=args.max_num_epochs, - logger=logger, - patience=args.patience, - save_all_checkpoints=args.save_all_checkpoints, - output_args=output_args, - device=device, - swa=swa, - ema=ema, - max_grad_norm=args.clip_grad, - log_errors=args.error_table, - log_wandb=args.wandb, - distributed=args.distributed, - distributed_model=distributed_model, - plotter=plotter, - train_sampler=train_sampler, - rank=rank, - ) - - logging.info("") - logging.info("===========RESULTS===========") - - train_valid_data_loader = {} - for head_config in head_configs: - data_loader_name = "train_" + head_config.head_name - train_valid_data_loader[data_loader_name] = head_config.train_loader - for head, valid_loader in valid_loaders.items(): - data_load_name = "valid_" + head - train_valid_data_loader[data_load_name] = valid_loader - test_sets = {} - stop_first_test = False - test_data_loader = {} - if all( - head_config.test_file == head_configs[0].test_file - for head_config in head_configs - ) and head_configs[0].test_file is not None: - stop_first_test = True - if all( - head_config.test_dir == head_configs[0].test_dir - for head_config in head_configs - ) and head_configs[0].test_dir is not None: - stop_first_test = True - for head_config in head_configs: - if all(check_path_ase_read(f) for f in head_config.train_file): - for name, subset in head_config.collections.tests: - test_sets[name] = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in subset - ] - if head_config.test_dir is not None: - if not args.multi_processed_test: - test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") - for test_file in test_files: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name - ) - else: - test_folders = glob(head_config.test_dir + "/*") - for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name - ) - for test_name, test_set in test_sets.items(): - test_sampler = None - if args.distributed: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - try: - drop_last = test_set.drop_last - except AttributeError as e: # pylint: disable=W0612 - drop_last = False - test_loader = torch_geometric.dataloader.DataLoader( - test_set, - batch_size=args.valid_batch_size, - shuffle=(test_sampler is None), - drop_last=drop_last, - num_workers=args.num_workers, - pin_memory=args.pin_memory, - ) - test_data_loader[test_name] = test_loader - if stop_first_test: - break - - for swa_eval in swas: - epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=swa_eval, - device=device, - ) - 
model.to(device) - if args.distributed: - distributed_model = DDP(model, device_ids=[local_rank]) - model_to_evaluate = model if not args.distributed else distributed_model - if swa_eval: - logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") - else: - logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") - - if rank == 0: - # Save entire model - if swa_eval: - model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") - else: - model_path = Path(args.checkpoints_dir) / (tag + ".model") - logging.info(f"Saving model to {model_path}") - model_to_save = deepcopy(model) - if args.enable_cueq: - print("RUNING CUEQ TO E3NN") - print("swa_eval", swa_eval) - model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) - if args.save_cpu: - model_to_save = model_to_save.to("cpu") - torch.save(model_to_save, model_path) - extra_files = { - "commit.txt": commit.encode("utf-8") if commit is not None else b"", - "config.yaml": json.dumps( - convert_to_json_format(extract_config_mace_model(model)) - ), - } - if swa_eval: - torch.save( - model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model") - ) - try: - path_complied = Path(args.model_dir) / ( - args.name + "_stagetwo_compiled.model" - ) - logging.info(f"Compiling model, saving metadata {path_complied}") - model_compiled = jit.compile(deepcopy(model_to_save)) - torch.jit.save( - model_compiled, - path_complied, - _extra_files=extra_files, - ) - except Exception as e: # pylint: disable=W0718 - pass - else: - torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model")) - try: - path_complied = Path(args.model_dir) / ( - args.name + "_compiled.model" - ) - logging.info(f"Compiling model, saving metadata to {path_complied}") - model_compiled = jit.compile(deepcopy(model_to_save)) - torch.jit.save( - model_compiled, - path_complied, - _extra_files=extra_files, - ) - except Exception as e: # pylint: disable=W0718 - pass - - logging.info("Computing metrics for training, validation, and test sets") - for param in model.parameters(): - param.requires_grad = False - skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else [] - if skip_heads: - logging.info(f"Skipping evaluation for heads: {skip_heads}") - table_train_valid = create_error_table( - table_type=args.error_table, - all_data_loaders=train_valid_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - skip_heads=skip_heads, - ) - logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - - if test_data_loader: - table_test = create_error_table( - table_type=args.error_table, - all_data_loaders=test_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - logging.info("Error-table on TEST:\n" + str(table_test)) - if args.plot: - try: - plotter = TrainingPlotter( - results_dir=logger.path, - heads=heads, - table_type=args.error_table, - train_valid_data=train_valid_data_loader, - test_data=test_data_loader, - output_args=output_args, - device=device, - plot_frequency=args.plot_frequency, - distributed=args.distributed, - swa_start=swa.start if swa else None - ) - plotter.plot(epoch, model_to_evaluate, rank) - except Exception as e: # pylint: disable=W0718 - logging.debug(f"Plotting failed: {e}") - - if args.distributed: - torch.distributed.barrier() - - logging.info("Done") 
- if args.distributed: - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - main() +########################################################################################### +# Training script for MACE +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import ast +import glob +import json +import logging +import os +from copy import deepcopy +from pathlib import Path +from typing import List, Optional + +import torch.distributed +import torch.nn.functional +from e3nn.util import jit +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import LBFGS +from torch.utils.data import ConcatDataset +from torch_ema import ExponentialMovingAverage + +import mace +from mace import data, tools +from mace.calculators.foundations_models import mace_mp, mace_off +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.cli.visualise_train import TrainingPlotter +from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.tools import torch_geometric +from mace.tools.model_script_utils import configure_model +from mace.tools.multihead_tools import ( + HeadConfig, + assemble_mp_data, + dict_head_to_dataclass, + prepare_default_head, + prepare_pt_head, +) +from mace.tools.run_train_utils import ( + combine_datasets, + load_dataset_for_path, + normalize_file_paths, +) +from mace.tools.scripts_utils import ( + LRScheduler, + SubsetCollection, + check_path_ase_read, + convert_to_json_format, + dict_to_array, + extract_config_mace_model, + get_atomic_energies, + get_avg_num_neighbors, + get_config_type_weights, + get_dataset_from_xyz, + get_files_with_suffix, + get_loss_fn, + get_optimizer, + get_params_options, + get_swa, + print_git_commit, + remove_pt_head, + setup_wandb, +) +from mace.tools.slurm_distributed import DistributedEnvironment +from mace.tools.tables_utils import create_error_table +from mace.tools.utils import AtomicNumberTable + + +def main() -> None: + """ + This script runs the training/fine tuning for mace + """ + args = tools.build_default_arg_parser().parse_args() + run(args) + + +def run(args) -> None: + """ + This script runs the training/fine tuning for mace + """ + tag = tools.get_tag(name=args.name, seed=args.seed) + args, input_log_messages = tools.check_args(args) + + # default keyspec to update using heads dictionary + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + + if args.device == "xpu": + try: + import intel_extension_for_pytorch as ipex + except ImportError as e: + raise ImportError( + "Error: Intel extension for PyTorch not found, but XPU device was specified" + ) from e + if args.distributed: + try: + distr_env = DistributedEnvironment() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to initialize distributed environment: {e}") + return + world_size = distr_env.world_size + local_rank = distr_env.local_rank + rank = distr_env.rank + if rank == 0: + print(distr_env) + torch.distributed.init_process_group(backend="nccl") + else: + rank = int(0) + + # Setup + tools.set_seeds(args.seed) + tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========VERIFYING SETTINGS===========") + for message, loglevel in input_log_messages: + 
logging.log(level=loglevel, msg=message) + + if args.distributed: + torch.cuda.set_device(local_rank) + logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") + logging.info(f"Processes: {world_size}") + + try: + logging.info(f"MACE version: {mace.__version__}") + except AttributeError: + logging.info("Cannot find MACE version, please install MACE via pip") + logging.debug(f"Configuration: {args}") + + tools.set_default_dtype(args.default_dtype) + device = tools.init_device(args.device) + commit = print_git_commit() + model_foundation: Optional[torch.nn.Module] = None + foundation_model_avg_num_neighbors = 0 + if args.foundation_model is not None: + if args.foundation_model in ["small", "medium", "large"]: + logging.info( + f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." + ) + calc = mace_mp( + model=args.foundation_model, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + elif args.foundation_model in ["small_off", "medium_off", "large_off"]: + model_type = args.foundation_model.split("_")[0] + logging.info( + f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." + ) + calc = mace_off( + model=model_type, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + else: + model_foundation = torch.load( + args.foundation_model, map_location=args.device + ) + logging.info( + f"Using foundation model {args.foundation_model} as initial checkpoint." + ) + args.r_max = model_foundation.r_max.item() + foundation_model_avg_num_neighbors = model_foundation.interactions[ + 0 + ].avg_num_neighbors + if ( + args.foundation_model not in ["small", "medium", "large"] + and args.pt_train_file is None + ): + logging.warning( + "Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file." + ) + args.multiheads_finetuning = False + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" + # check that the foundation model has a single head, if not, use the first head + if not args.force_mh_ft_lr: + logging.info( + "Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." + ) + args.lr = 0.0001 + args.ema = True + args.ema_decay = 0.99999 + logging.info( + "Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True" + ) + if hasattr(model_foundation, "heads"): + if len(model_foundation.heads) > 1: + logging.warning( + "Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head." 
+ ) + model_foundation = remove_pt_head( + model_foundation, args.foundation_head + ) + else: + args.multiheads_finetuning = False + + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + for _, head_dict in args.heads.items(): + # priority is global args < head property_key values < head info_keys+arrays_keys + head_keyspec = deepcopy(args.key_specification) + update_keyspec_from_kwargs(head_keyspec, head_dict) + head_keyspec.update( + info_keys=head_dict.get("info_keys", {}), + arrays_keys=head_dict.get("arrays_keys", {}), + ) + head_dict["key_specification"] = head_keyspec + else: + args.heads = prepare_default_head(args) + if args.multiheads_finetuning: + pt_keyspec = ( + args.heads["pt_head"]["key_specification"] + if "pt_head" in args.heads + else deepcopy(args.key_specification) + ) + args.heads["pt_head"] = prepare_pt_head( + args, pt_keyspec, foundation_model_avg_num_neighbors + ) + + logging.info("===========LOADING INPUT DATA===========") + heads = list(args.heads.keys()) + logging.info(f"Using heads: {heads}") + logging.info("Using the key specifications to parse data:") + for name, head_dict in args.heads.items(): + head_keyspec = head_dict["key_specification"] + logging.info(f"{name}: {head_keyspec}") + + head_configs: List[HeadConfig] = [] + for head, head_args in args.heads.items(): + logging.info(f"============= Processing head {head} ===========") + head_config = dict_head_to_dataclass(head_args, head, args) + + # Handle train_file and valid_file - normalize to lists + if hasattr(head_config, "train_file") and head_config.train_file is not None: + head_config.train_file = normalize_file_paths(head_config.train_file) + if hasattr(head_config, "valid_file") and head_config.valid_file is not None: + head_config.valid_file = normalize_file_paths(head_config.valid_file) + if hasattr(head_config, "test_file") and head_config.test_file is not None: + head_config.test_file = normalize_file_paths(head_config.test_file) + + if ( + head_config.statistics_file is not None + and head_config.head_name != "pt_head" + ): + with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 + statistics = json.load(f) + logging.info("Using statistics json file") + head_config.atomic_numbers = statistics["atomic_numbers"] + head_config.mean = statistics["mean"] + head_config.std = statistics["std"] + head_config.avg_num_neighbors = statistics["avg_num_neighbors"] + head_config.compute_avg_num_neighbors = False + if isinstance(statistics["atomic_energies"], str) and statistics[ + "atomic_energies" + ].endswith(".json"): + with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: + atomic_energies = json.load(f) + head_config.E0s = atomic_energies + head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) + else: + head_config.E0s = statistics["atomic_energies"] + head_config.atomic_energies_dict = ast.literal_eval( + statistics["atomic_energies"] + ) + if head_config.train_file == ["mp"]: + assert ( + head_config.head_name == "pt_head" + ), "Only pt_head should use mp as train_file" + logging.info( + "Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script." 
+ ) + collections = assemble_mp_data(args, head_config, tag) + head_config.collections = collections + elif any(check_path_ase_read(f) for f in head_config.train_file): + train_files_ase_list = [ + f for f in head_config.train_file if check_path_ase_read(f) + ] + valid_files_ase_list = None + test_files_ase_list = None + if head_config.valid_file: + valid_files_ase_list = [ + f for f in head_config.valid_file if check_path_ase_read(f) + ] + if head_config.test_file: + test_files_ase_list = [ + f for f in head_config.test_file if check_path_ase_read(f) + ] + config_type_weights = get_config_type_weights( + head_config.config_type_weights + ) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=train_files_ase_list, + valid_path=valid_files_ase_list, + valid_fraction=head_config.valid_fraction, + config_type_weights=config_type_weights, + test_path=test_files_ase_list, + seed=args.seed, + key_specification=head_config.key_specification, + head_name=head_config.head_name, + keep_isolated_atoms=head_config.keep_isolated_atoms, + ) + head_config.collections = SubsetCollection( + train=collections.train, + valid=collections.valid, + tests=collections.tests, + ) + head_config.atomic_energies_dict = atomic_energies_dict + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," + ) + head_configs.append(head_config) + + if all( + check_path_ase_read(head_config.train_file[0]) for head_config in head_configs + ): + size_collections_train = sum( + len(head_config.collections.train) for head_config in head_configs + ) + size_collections_valid = sum( + len(head_config.collections.valid) for head_config in head_configs + ) + if size_collections_train < args.batch_size: + logging.error( + f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" + ) + if size_collections_valid < args.valid_batch_size: + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" + ) + + if args.multiheads_finetuning: + logging.info( + "==================Using multiheads finetuning mode==================" + ) + args.loss = "universal" + + all_ase_readable = all( + all(check_path_ase_read(f) for f in head_config.train_file) + for head_config in head_configs + ) + head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs) + head_config_pt = next(head_config_pt, None) + assert head_config_pt is not None, "Pretraining head not found" + if all_ase_readable: + ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) + if ratio_pt_ft < 0.1: + logging.warning( + f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " + f"increasing the number of configurations in the fine-tuning heads by {int(0.1 / ratio_pt_ft)}" + ) + for head_config in head_configs: + if head_config.head_name == "pt_head": + continue + head_config.collections.train += ( + head_config.collections.train * int(0.1 / ratio_pt_ft) + ) + logging.info( + f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" + ) + else: + logging.debug( + "Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check" + ) + + # Atomic number table + # yapf: 
disable + for head_config in head_configs: + if head_config.atomic_numbers is None: + assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input" + z_table_head = tools.get_atomic_number_table_from_zs( + z + for configs in (head_config.collections.train, head_config.collections.valid) + for config in configs + for z in config.atomic_numbers + ) + head_config.atomic_numbers = z_table_head.zs + head_config.z_table = z_table_head + else: + if head_config.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(head_config.atomic_numbers) + assert isinstance(zs_list, list) + z_table_head = tools.AtomicNumberTable(zs_list) + head_config.atomic_numbers = zs_list + head_config.z_table = z_table_head + # yapf: enable + all_atomic_numbers = set() + for head_config in head_configs: + all_atomic_numbers.update(head_config.atomic_numbers) + z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) + if args.foundation_model_elements and model_foundation: + z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) + logging.info(f"Atomic Numbers used: {z_table.zs}") + + # Atomic energies + atomic_energies_dict = {} + for head_config in head_configs: + if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + assert head_config.E0s is not None, "Atomic energies must be provided" + if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": + atomic_energies_dict[head_config.head_name] = get_atomic_energies( + head_config.E0s, head_config.collections.train, head_config.z_table + ) + elif head_config.E0s.lower() == "foundation": + assert args.foundation_model is not None + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict[head_config.head_name] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + else: + atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) + else: + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + + # Atomic energies for multiheads finetuning + if args.multiheads_finetuning: + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict["pt_head"] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + 
+            for z in z_table.zs
+        }
+    heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0)
+    # Padding atomic energies if keeping all elements of the foundation model
+    if args.foundation_model_elements and model_foundation:
+        atomic_energies_dict_padded = {}
+        for head_name, head_energies in atomic_energies_dict.items():
+            energy_head_padded = {}
+            for z in z_table.zs:
+                energy_head_padded[z] = head_energies.get(z, 0.0)
+            atomic_energies_dict_padded[head_name] = energy_head_padded
+        atomic_energies_dict = atomic_energies_dict_padded
+
+    if args.model == "AtomicDipolesMACE":
+        atomic_energies = None
+        dipole_only = True
+        args.compute_dipole = True
+        args.compute_energy = False
+        args.compute_forces = False
+        args.compute_virials = False
+        args.compute_stress = False
+    else:
+        dipole_only = False
+        if args.model == "EnergyDipolesMACE":
+            args.compute_dipole = True
+            args.compute_energy = True
+            args.compute_forces = True
+            args.compute_virials = False
+            args.compute_stress = False
+        else:
+            args.compute_energy = True
+            args.compute_dipole = False
+        # atomic_energies: np.ndarray = np.array(
+        #     [atomic_energies_dict[z] for z in z_table.zs]
+        # )
+        atomic_energies = dict_to_array(atomic_energies_dict, heads)
+        for head_config in head_configs:
+            try:
+                logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
+            except KeyError as e:
+                raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e
+
+    # Load datasets for each head, supporting multiple files per head
+    valid_sets = {head: [] for head in heads}
+    train_sets = {head: [] for head in heads}
+
+    for head_config in head_configs:
+        train_datasets = []
+
+        logging.info(f"Processing datasets for head '{head_config.head_name}'")
+        ase_files = [f for f in head_config.train_file if check_path_ase_read(f)]
+        non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)]
+
+        if ase_files:
+            dataset = load_dataset_for_path(
+                file_path=ase_files,
+                r_max=args.r_max,
+                z_table=z_table,
+                head_config=head_config,
+                heads=heads,
+                collection=head_config.collections.train,
+            )
+            train_datasets.append(dataset)
+            logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}")
+
+        for file in non_ase_files:
+            dataset = load_dataset_for_path(
+                file_path=file,
+                r_max=args.r_max,
+                z_table=z_table,
+                head_config=head_config,
+                heads=heads,
+            )
+            train_datasets.append(dataset)
+            logging.debug(f"Successfully loaded dataset from non-ASE file: {file}")
+
+        if not train_datasets:
+            raise ValueError(f"No valid training datasets found for head {head_config.head_name}")
+
+        train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name)
+
+        if head_config.valid_file:
+            valid_datasets = []
+
+            valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)]
+            valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)]
+
+            if valid_ase_files:
+                valid_dataset = load_dataset_for_path(
+                    file_path=valid_ase_files,
+                    r_max=args.r_max,
+                    z_table=z_table,
+                    head_config=head_config,
+                    heads=heads,
+                    collection=head_config.collections.valid,
+                )
+                valid_datasets.append(valid_dataset)
+                logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}")
+            for valid_file in valid_non_ase_files:
+                valid_dataset = load_dataset_for_path(
+                    file_path=valid_file,
+                    r_max=args.r_max,
+                    z_table=z_table,
+                    head_config=head_config,
+                    heads=heads,
+                )
+                valid_datasets.append(valid_dataset)
+                logging.debug(f"Successfully loaded validation dataset from {valid_file}")
+
+            # Combine validation datasets
+            if valid_datasets:
+                valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid")
+                logging.info(f"Combined validation datasets for {head_config.head_name}")
+
+        # If no valid file is provided but collections exist, use the validation set from the collections
+        if head_config.valid_file is None and head_config.collections.valid:
+            valid_sets[head_config.head_name] = [
+                data.AtomicData.from_config(
+                    config, z_table=z_table, cutoff=args.r_max, heads=heads
+                )
+                for config in head_config.collections.valid
+            ]
+        if not valid_sets[head_config.head_name]:
+            raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction")
+
+        # Create data loader for this head (both list- and Dataset-backed
+        # train sets support len(), so no case split is needed)
+        dataset_size = len(train_sets[head_config.head_name])
+        logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}")
+
+        train_loader_head = torch_geometric.dataloader.DataLoader(
+            dataset=train_sets[head_config.head_name],
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=(not args.lbfgs),
+            pin_memory=args.pin_memory,
+            num_workers=args.num_workers,
+            generator=torch.Generator().manual_seed(args.seed),
+        )
+        head_config.train_loader = train_loader_head
+
+    # Concatenate all the train sets
+    train_set = ConcatDataset([train_sets[head] for head in heads])
+    train_sampler, valid_sampler = None, None
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_set,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=True,
+            drop_last=(not args.lbfgs),
+            seed=args.seed,
+        )
+        valid_samplers = {}
+        for head, valid_set in valid_sets.items():
+            valid_sampler = torch.utils.data.distributed.DistributedSampler(
+                valid_set,
+                num_replicas=world_size,
+                rank=rank,
+                shuffle=True,
+                drop_last=True,
+                seed=args.seed,
+            )
+            valid_samplers[head] = valid_sampler
+    train_loader = torch_geometric.dataloader.DataLoader(
+        dataset=train_set,
+        batch_size=args.batch_size,
+        sampler=train_sampler,
+        shuffle=(train_sampler is None),
+        drop_last=(train_sampler is None and not args.lbfgs),
+        pin_memory=args.pin_memory,
+        num_workers=args.num_workers,
+        generator=torch.Generator().manual_seed(args.seed),
+    )
+    valid_loaders = {heads[i]: None for i in range(len(heads))}
+    if not isinstance(valid_sets, dict):
+        valid_sets = {"Default": valid_sets}
+    for head, valid_set in valid_sets.items():
+        valid_loaders[head] = torch_geometric.dataloader.DataLoader(
+            dataset=valid_set,
+            batch_size=args.valid_batch_size,
+            sampler=valid_samplers[head] if args.distributed else None,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=args.pin_memory,
+            num_workers=args.num_workers,
+            generator=torch.Generator().manual_seed(args.seed),
+        )
+
+    loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
+    args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)
+
+    # Model
+    model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs)
+    model.to(device)
+
+    logging.debug(model)
+    logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
+    logging.info("")
+    logging.info("===========OPTIMIZER INFORMATION===========")
+    logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
+    logging.info(f"Batch size: {args.batch_size}")
+    if args.ema:
+        logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
+    logging.info(
+        f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}"
+    )
+    logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
+    logging.info(loss_fn)
+
+    # Cueq
+    if args.enable_cueq:
+        logging.info("Converting model to CUEQ for accelerated training")
+        assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"]
+        model = run_e3nn_to_cueq(deepcopy(model), device=device)
+    # Optimizer
+    param_options = get_params_options(args, model)
+    optimizer: torch.optim.Optimizer
+    optimizer = get_optimizer(args, param_options)
+    if args.device == "xpu":
+        logging.info("Optimizing model and optimizer for XPU")
+        model, optimizer = ipex.optimize(model, optimizer=optimizer)
+    logger = tools.MetricsLogger(
+        directory=args.results_dir, tag=tag + "_train"
+    )  # pylint: disable=E1123
+
+    lr_scheduler = LRScheduler(optimizer, args)
+
+    swa: Optional[tools.SWAContainer] = None
+    swas = [False]
+    if args.swa:
+        swa, swas = get_swa(args, model, optimizer, swas, dipole_only)
+
+    checkpoint_handler = tools.CheckpointHandler(
+        directory=args.checkpoints_dir,
+        tag=tag,
+        keep=args.keep_checkpoints,
+        swa_start=args.start_swa,
+    )
+
+    start_epoch = 0
+    restart_lbfgs = False
+    opt_start_epoch = None
+    if args.restart_latest:
+        try:
+            opt_start_epoch = checkpoint_handler.load_latest(
+                state=tools.CheckpointState(model, optimizer, lr_scheduler),
+                swa=True,
+                device=device,
+            )
+        except Exception:  # pylint: disable=W0703
+            try:
+                opt_start_epoch = checkpoint_handler.load_latest(
+                    state=tools.CheckpointState(model, optimizer, lr_scheduler),
+                    swa=False,
+                    device=device,
+                )
+            except Exception:  # pylint: disable=W0703
+                restart_lbfgs = True
+        if opt_start_epoch is not None:
+            start_epoch = opt_start_epoch
+
+    ema: Optional[ExponentialMovingAverage] = None
+    if args.ema:
+        ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
+    else:
+        for group in optimizer.param_groups:
+            group["lr"] = args.lr
+
+    if args.lbfgs:
+        logging.info("Switching optimizer to LBFGS")
+        optimizer = LBFGS(
+            model.parameters(),
+            history_size=200,
+            max_iter=20,
+            line_search_fn="strong_wolfe",
+        )
+        if restart_lbfgs:
+            opt_start_epoch = checkpoint_handler.load_latest(
+                state=tools.CheckpointState(model, optimizer, lr_scheduler),
+                swa=False,
+                device=device,
+            )
+            if opt_start_epoch is not None:
+                start_epoch = opt_start_epoch
+
+    if args.wandb:
+        setup_wandb(args)
+    if args.distributed:
+        distributed_model = DDP(model, device_ids=[local_rank])
+    else:
+        distributed_model = None
+
+    train_valid_data_loader = {}
+    for head_config in head_configs:
+        data_loader_name = "train_" + head_config.head_name
+        train_valid_data_loader[data_loader_name] = head_config.train_loader
+    for head, valid_loader in valid_loaders.items():
+        data_load_name = "valid_" + head
+        train_valid_data_loader[data_load_name] = valid_loader
+
+    if args.plot and args.plot_frequency > 0:
+        try:
+            plotter = TrainingPlotter(
+                results_dir=logger.path,
+                heads=heads,
+                table_type=args.error_table,
+                train_valid_data=train_valid_data_loader,
+                test_data={},
+                output_args=output_args,
+                device=device,
+                plot_frequency=args.plot_frequency,
+                distributed=args.distributed,
+                swa_start=swa.start if swa else None,
+            )
+        except Exception as e:  # pylint: disable=W0718
+            logging.debug(f"Creating Plotter failed: {e}")
+    else:
+        plotter = None
+
+    if args.dry_run:
+        logging.info("DRY RUN mode enabled. Stopping now.")
+        return
+
+    tools.train(
+        model=model,
+        loss_fn=loss_fn,
+        train_loader=train_loader,
+        valid_loaders=valid_loaders,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        checkpoint_handler=checkpoint_handler,
+        eval_interval=args.eval_interval,
+        start_epoch=start_epoch,
+        max_num_epochs=args.max_num_epochs,
+        logger=logger,
+        patience=args.patience,
+        save_all_checkpoints=args.save_all_checkpoints,
+        output_args=output_args,
+        device=device,
+        swa=swa,
+        ema=ema,
+        max_grad_norm=args.clip_grad,
+        log_errors=args.error_table,
+        log_wandb=args.wandb,
+        distributed=args.distributed,
+        distributed_model=distributed_model,
+        plotter=plotter,
+        train_sampler=train_sampler,
+        rank=rank,
+    )
+
+    logging.info("")
+    logging.info("===========RESULTS===========")
+
+    train_valid_data_loader = {}
+    for head_config in head_configs:
+        data_loader_name = "train_" + head_config.head_name
+        train_valid_data_loader[data_loader_name] = head_config.train_loader
+    for head, valid_loader in valid_loaders.items():
+        data_load_name = "valid_" + head
+        train_valid_data_loader[data_load_name] = valid_loader
+    test_sets = {}
+    stop_first_test = False
+    test_data_loader = {}
+    if all(
+        head_config.test_file == head_configs[0].test_file
+        for head_config in head_configs
+    ) and head_configs[0].test_file is not None:
+        stop_first_test = True
+    if all(
+        head_config.test_dir == head_configs[0].test_dir
+        for head_config in head_configs
+    ) and head_configs[0].test_dir is not None:
+        stop_first_test = True
+    for head_config in head_configs:
+        if all(check_path_ase_read(f) for f in head_config.train_file):
+            for name, subset in head_config.collections.tests:
+                test_sets[name] = [
+                    data.AtomicData.from_config(
+                        config, z_table=z_table, cutoff=args.r_max, heads=heads
+                    )
+                    for config in subset
+                ]
+        if head_config.test_dir is not None:
+            if not args.multi_processed_test:
+                test_files = get_files_with_suffix(head_config.test_dir, "_test.h5")
+                for test_file in test_files:
+                    name = os.path.splitext(os.path.basename(test_file))[0]
+                    test_sets[name] = data.HDF5Dataset(
+                        test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
+                    )
+            else:
+                test_folders = glob(head_config.test_dir + "/*")
+                for folder in test_folders:
+                    # Derive the dataset name from the shard folder itself
+                    name = os.path.splitext(os.path.basename(folder))[0]
+                    test_sets[name] = data.dataset_from_sharded_hdf5(
+                        folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
+                    )
+        for test_name, test_set in test_sets.items():
+            test_sampler = None
+            if args.distributed:
+                test_sampler = torch.utils.data.distributed.DistributedSampler(
+                    test_set,
+                    num_replicas=world_size,
+                    rank=rank,
+                    shuffle=True,
+                    drop_last=True,
+                    seed=args.seed,
+                )
+            try:
+                drop_last = test_set.drop_last
+            except AttributeError:
+                drop_last = False
+            test_loader = torch_geometric.dataloader.DataLoader(
+                test_set,
+                batch_size=args.valid_batch_size,
+                shuffle=(test_sampler is None),
+                drop_last=drop_last,
+                num_workers=args.num_workers,
+                pin_memory=args.pin_memory,
+            )
+            test_data_loader[test_name] = test_loader
+        if stop_first_test:
+            break
+
+    for swa_eval in swas:
+        epoch = checkpoint_handler.load_latest(
+            state=tools.CheckpointState(model, optimizer, lr_scheduler),
+            swa=swa_eval,
+            device=device,
+        )
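+        # Reload the latest checkpoint of this stage (plain stage one, or the
+        # SWA stage-two average when enabled) before evaluation and export.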
+        model.to(device)
+        if args.distributed:
+            distributed_model = DDP(model, device_ids=[local_rank])
+        model_to_evaluate = model if not args.distributed else distributed_model
+        if swa_eval:
+            logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
+        else:
+            logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation")
+
+        if rank == 0:
+            # Save entire model
+            if swa_eval:
+                model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
+            else:
+                model_path = Path(args.checkpoints_dir) / (tag + ".model")
+            logging.info(f"Saving model to {model_path}")
+            model_to_save = deepcopy(model)
+            if args.enable_cueq:
+                logging.info("Converting CUEQ model back to e3nn for saving (swa_eval=%s)", swa_eval)
+                model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device)
+            if args.save_cpu:
+                model_to_save = model_to_save.to("cpu")
+            torch.save(model_to_save, model_path)
+            extra_files = {
+                "commit.txt": commit.encode("utf-8") if commit is not None else b"",
+                "config.yaml": json.dumps(
+                    convert_to_json_format(extract_config_mace_model(model))
+                ),
+            }
+            if swa_eval:
+                torch.save(
+                    model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model")
+                )
+                try:
+                    path_compiled = Path(args.model_dir) / (
+                        args.name + "_stagetwo_compiled.model"
+                    )
+                    logging.info(f"Compiling model, saving metadata to {path_compiled}")
+                    model_compiled = jit.compile(deepcopy(model_to_save))
+                    torch.jit.save(
+                        model_compiled,
+                        path_compiled,
+                        _extra_files=extra_files,
+                    )
+                except Exception as e:  # pylint: disable=W0718
+                    logging.debug(f"Model compilation failed: {e}")
+            else:
+                torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model"))
+                try:
+                    path_compiled = Path(args.model_dir) / (
+                        args.name + "_compiled.model"
+                    )
+                    logging.info(f"Compiling model, saving metadata to {path_compiled}")
+                    model_compiled = jit.compile(deepcopy(model_to_save))
+                    torch.jit.save(
+                        model_compiled,
+                        path_compiled,
+                        _extra_files=extra_files,
+                    )
+                except Exception as e:  # pylint: disable=W0718
+                    logging.debug(f"Model compilation failed: {e}")
+
+        logging.info("Computing metrics for training, validation, and test sets")
+        for param in model.parameters():
+            param.requires_grad = False
+        skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else []
+        if skip_heads:
+            logging.info(f"Skipping evaluation for heads: {skip_heads}")
+        table_train_valid = create_error_table(
+            table_type=args.error_table,
+            all_data_loaders=train_valid_data_loader,
+            model=model_to_evaluate,
+            loss_fn=loss_fn,
+            output_args=output_args,
+            log_wandb=args.wandb,
+            device=device,
+            distributed=args.distributed,
+            skip_heads=skip_heads,
+        )
+        logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid))
+
+        if test_data_loader:
+            table_test = create_error_table(
+                table_type=args.error_table,
+                all_data_loaders=test_data_loader,
+                model=model_to_evaluate,
+                loss_fn=loss_fn,
+                output_args=output_args,
+                log_wandb=args.wandb,
+                device=device,
+                distributed=args.distributed,
+            )
+            logging.info("Error-table on TEST:\n" + str(table_test))
+        if args.plot:
+            try:
+                plotter = TrainingPlotter(
+                    results_dir=logger.path,
+                    heads=heads,
+                    table_type=args.error_table,
+                    train_valid_data=train_valid_data_loader,
+                    test_data=test_data_loader,
+                    output_args=output_args,
+                    device=device,
+                    plot_frequency=args.plot_frequency,
+                    distributed=args.distributed,
+                    swa_start=swa.start if swa else None,
+                )
+                plotter.plot(epoch, model_to_evaluate, rank)
+            except Exception as e:  # pylint: disable=W0718
+                logging.debug(f"Plotting failed: {e}")
+
+        if args.distributed:
+            torch.distributed.barrier()
+
+    logging.info("Done")
+ if args.distributed: + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/select_head.py b/mace-bench/3rdparty/mace/mace/cli/select_head.py index adc568dfe0b780c87490fe92c8b05dd75c0c5fdd..305ab52feb3cce118b2702a95c941ea95ff457d4 100644 --- a/mace-bench/3rdparty/mace/mace/cli/select_head.py +++ b/mace-bench/3rdparty/mace/mace/cli/select_head.py @@ -1,60 +1,60 @@ -from argparse import ArgumentParser - -import torch - -from mace.tools.scripts_utils import remove_pt_head - - -def main(): - parser = ArgumentParser() - grp = parser.add_mutually_exclusive_group() - grp.add_argument( - "--head_name", - "-n", - help="name of the head to extract", - default=None, - ) - grp.add_argument( - "--list_heads", - "-l", - action="store_true", - help="list names of the heads", - ) - parser.add_argument( - "--target_device", - "-d", - help="target device, defaults to model's current device", - ) - parser.add_argument( - "--output_file", - "-o", - help="name for output model, defaults to model.head_name, followed by .target_device if specified", - ) - parser.add_argument("model_file", help="input model file path") - args = parser.parse_args() - - model = torch.load(args.model_file, map_location=args.target_device) - torch.set_default_dtype(next(model.parameters()).dtype) - - if args.list_heads: - print("Available heads:") - print("\n".join([" " + h for h in model.heads])) - else: - - if args.output_file is None: - args.output_file = ( - args.model_file - + "." - + args.head_name - + ("." + args.target_device if (args.target_device is not None) else "") - ) - - model_single = remove_pt_head(model, args.head_name) - if args.target_device is not None: - target_device = str(next(model.parameters()).device) - model_single.to(target_device) - torch.save(model_single, args.output_file) - - -if __name__ == "__main__": - main() +from argparse import ArgumentParser + +import torch + +from mace.tools.scripts_utils import remove_pt_head + + +def main(): + parser = ArgumentParser() + grp = parser.add_mutually_exclusive_group() + grp.add_argument( + "--head_name", + "-n", + help="name of the head to extract", + default=None, + ) + grp.add_argument( + "--list_heads", + "-l", + action="store_true", + help="list names of the heads", + ) + parser.add_argument( + "--target_device", + "-d", + help="target device, defaults to model's current device", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model.head_name, followed by .target_device if specified", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + model = torch.load(args.model_file, map_location=args.target_device) + torch.set_default_dtype(next(model.parameters()).dtype) + + if args.list_heads: + print("Available heads:") + print("\n".join([" " + h for h in model.heads])) + else: + + if args.output_file is None: + args.output_file = ( + args.model_file + + "." + + args.head_name + + ("." 
+ args.target_device if (args.target_device is not None) else "") + ) + + model_single = remove_pt_head(model, args.head_name) + if args.target_device is not None: + target_device = str(next(model.parameters()).device) + model_single.to(target_device) + torch.save(model_single, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/visualise_train.py b/mace-bench/3rdparty/mace/mace/cli/visualise_train.py index 679d8630652e8858794ad53242edc1cde7397ddf..c380c70dd7599681298936015c986266c690abbe 100644 --- a/mace-bench/3rdparty/mace/mace/cli/visualise_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/visualise_train.py @@ -1,640 +1,640 @@ -import json -import logging -from typing import Dict, List, Optional - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import torch.distributed -from torchmetrics import Metric - -plt.rcParams.update({"font.size": 8}) -mpl_logger = logging.getLogger("matplotlib") -mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above - -colors = [ - "#1f77b4", # muted blue - "#d62728", # brick red - "#7f7f7f", # middle gray - "#2ca02c", # cooked asparagus green - "#ff7f0e", # safety orange - "#9467bd", # muted purple - "#8c564b", # chestnut brown - "#e377c2", # raspberry yogurt pink - "#bcbd22", # curry yellow-green - "#17becf", # blue-teal -] - -error_type = { - "TotalRMSE": ( - [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomRMSE": ( - [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomRMSEstressvirials": ( - [ - ("rmse_e_per_atom", "RMSE E/atom [meV]"), - ("rmse_f", "RMSE F [meV / A]"), - ("rmse_stress", "RMSE Stress [meV / A^3]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("stress", "Stress [eV / A^3]"), - ], - ), - "PerAtomMAEstressvirials": ( - [ - ("mae_e_per_atom", "MAE E/atom [meV]"), - ("mae_f", "MAE F [meV / A]"), - ("mae_stress", "MAE Stress [meV / A^3]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("stress", "Stress [eV / A^3]"), - ], - ), - "TotalMAE": ( - [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomMAE": ( - [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "DipoleRMSE": ( - [ - ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), - ("rel_rmse_f", "Relative MU RMSE [%]"), - ], - [("dipole", "Dipole per atom [Debye]")], - ), - "DipoleMAE": ( - [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")], - [("dipole", "Dipole per atom [Debye]")], - ), - "EnergyDipoleRMSE": ( - [ - ("rmse_e_per_atom", "RMSE E/atom [meV]"), - ("rmse_f", "RMSE F [meV / A]"), - ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("dipole", "Dipole per atom [Debye]"), - ], - ), -} - - -class TrainingPlotter: - def __init__( - self, - results_dir: str, - heads: List[str], - table_type: str, - train_valid_data: Dict, - test_data: Dict, - output_args: str, - device: str, - plot_frequency: int, - distributed: bool = False, - swa_start: Optional[int] = None, - ): - self.results_dir = results_dir - self.heads = heads - self.table_type = 
table_type - self.train_valid_data = train_valid_data - self.test_data = test_data - self.output_args = output_args - self.device = device - self.plot_frequency = plot_frequency - self.distributed = distributed - self.swa_start = swa_start - - def plot(self, model_epoch: str, model: torch.nn.Module, rank: int) -> None: - - # All ranks process data through model_inference - train_valid_dict = model_inference( - self.train_valid_data, - model, - self.output_args, - self.device, - self.distributed, - ) - test_dict = model_inference( - self.test_data, model, self.output_args, self.device, self.distributed - ) - - # Only rank 0 creates and saves plots - if rank != 0: - return - - data = pd.DataFrame( - results for results in parse_training_results(self.results_dir) - ) - labels, quantities = error_type[self.table_type] - - for head in self.heads: - fig = plt.figure(layout="constrained", figsize=(10, 6)) - fig.suptitle( - f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16 - ) - - subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05) - axsTop = subfigs[0].subplots(1, 2, sharey=False) - axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False) - - plot_epoch_dependence(axsTop, data, head, model_epoch, labels) - - # Use the pre-computed results for plotting - plot_inference_from_results( - axsBottom, train_valid_dict, test_dict, head, quantities - ) - - if self.swa_start is not None: - # Add vertical lines to both axes - for ax in axsTop: - ax.axvline( - self.swa_start, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - stage = "stage_two" if self.swa_start < model_epoch else "stage_one" - else: - stage = "stage_one" - axsTop[0].legend(loc="best") - # Save the figure using the appropriate stage in the filename - filename = f"{self.results_dir[:-4]}_{head}_{stage}.png" - - fig.savefig(filename, dpi=300, bbox_inches="tight") - plt.close(fig) - - -def parse_training_results(path: str) -> List[dict]: - results = [] - with open(path, mode="r", encoding="utf-8") as f: - for line in f: - try: - d = json.loads(line.strip()) # Ensure it's valid JSON - results.append(d) - except json.JSONDecodeError: - print( - f"Skipping invalid line: {line.strip()}" - ) # Handle non-JSON lines gracefully - return results - - -def plot_epoch_dependence( - axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: str, labels: List[str] -) -> None: - - valid_data = ( - data[data["mode"] == "eval"] - .groupby(["mode", "epoch", "head"]) - .agg(["mean", "std"]) - .reset_index() - ) - valid_data = valid_data[valid_data["head"] == head] - train_data = ( - data[data["mode"] == "opt"] - .groupby(["mode", "epoch"]) - .agg(["mean", "std"]) - .reset_index() - ) - - # ---- Plot loss ---- - ax = axes[0] - ax.plot( - train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1 - ) - ax.set_ylabel("Training Loss", color=colors[1]) - ax.set_yscale("log") - - ax2 = ax.twinx() - ax2.plot( - valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1 - ) - ax2.set_ylabel("Validation Loss", color=colors[0]) - ax2.set_yscale("log") - - ax.axvline( - model_epoch, - color="black", - linestyle="solid", - linewidth=1, - alpha=0.8, - label="Loaded Model", - ) - ax.set_xlabel("Epoch") - ax.grid(True, linestyle="--", alpha=0.5) - - # ---- Plot selected keys ---- - ax = axes[1] - twin_axes = [] - for i, label in enumerate(labels): - color = colors[(i + 3)] - key, axis_label = label - if i == 0: - main_ax = ax - else: - main_ax = 
ax.twinx() - main_ax.spines.right.set_position(("outward", 60 * (i - 1))) - twin_axes.append(main_ax) - - main_ax.plot( - valid_data["epoch"], - valid_data[key]["mean"] * 1e3, - color=color, - label=label, - linewidth=1, - ) - main_ax.set_yscale("log") - main_ax.set_ylabel(axis_label, color=color) - main_ax.tick_params(axis="y", colors=color) - ax.axvline( - model_epoch, - color="black", - linestyle="solid", - linewidth=1, - alpha=0.8, - label="Loaded Model", - ) - ax.set_xlabel("Epoch") - ax.grid(True, linestyle="--", alpha=0.5) - - -# INFERENCE========= - - -def plot_inference_from_results( - axes: np.ndarray, - train_valid_dict: dict, - test_dict: dict, - head: str, - quantities: List[str], -) -> None: - - for ax, quantity in zip(axes, quantities): - key, label = quantity - - # Store legend handles to avoid duplicates - legend_labels = {} - - # Plot train/valid data (each entry keeps its own name) - for name, result in train_valid_dict.items(): - if "train" in name: - fixed_color_train_valid = colors[1] - marker = "x" - else: - fixed_color_train_valid = colors[0] - marker = "+" - if head not in name: - continue - - # Initialize scatter to None - scatter = None - - if key == "energy" and "energy" in result: - scatter = ax.scatter( - result["energy"]["reference_per_atom"], - result["energy"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "force" and "forces" in result: - scatter = ax.scatter( - result["forces"]["reference"], - result["forces"]["predicted"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "stress" and "stress" in result: - scatter = ax.scatter( - result["stress"]["reference"], - result["stress"]["predicted"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "virials" and "virials" in result: - scatter = ax.scatter( - result["virials"]["reference_per_atom"], - result["virials"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "dipole" and "dipole" in result: - scatter = ax.scatter( - result["dipole"]["reference_per_atom"], - result["dipole"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - # Add each train/valid dataset's name to the legend if scatter was assigned - if scatter is not None: - legend_labels[name] = scatter - - fixed_color_test = colors[2] # Color for test dataset - - # Plot test data (single legend entry) - for name, result in test_dict.items(): - # Initialize scatter to None to avoid possibly used before assignment - scatter = None - - if key == "energy" and "energy" in result: - scatter = ax.scatter( - result["energy"]["reference_per_atom"], - result["energy"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "force" and "forces" in result: - scatter = ax.scatter( - result["forces"]["reference"], - result["forces"]["predicted"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "stress" and "stress" in result: - scatter = ax.scatter( - result["stress"]["reference"], - result["stress"]["predicted"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "virials" and "virials" in result: - scatter = ax.scatter( - result["virials"]["reference_per_atom"], - result["virials"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "dipole" and "dipole" in result: - scatter = ax.scatter( - 
result["dipole"]["reference_per_atom"], - result["dipole"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - # Only add to legend_labels if scatter was assigned - if scatter is not None: - legend_labels["Test"] = scatter - - # Add diagonal line for guide - min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) - max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) - ax.plot( - [min_val, max_val], - [min_val, max_val], - linestyle="--", - color="black", - alpha=0.7, - ) - - # Set legend with unique entries (Test + individual train/valid names) - if legend_labels: - ax.legend( - handles=legend_labels.values(), labels=legend_labels.keys(), loc="best" - ) - ax.set_xlabel(f"Reference {label}") - ax.set_ylabel(f"MACE {label}") - ax.grid(True, linestyle="--", alpha=0.5) - - -def model_inference( - all_data_loaders: dict, - model: torch.nn.Module, - output_args: Dict[str, bool], - device: str, - distributed: bool = False, -): - - for param in model.parameters(): - param.requires_grad = False - - results_dict = {} - - for name in all_data_loaders: - data_loader = all_data_loaders[name] - logging.debug(f"Running inference on {name} dataset") - scatter_metric = InferenceMetric().to(device) - - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args.get("forces", False), - compute_virials=output_args.get("virials", False), - compute_stress=output_args.get("stress", False), - ) - - results = scatter_metric(batch, output) - - if distributed: - torch.distributed.barrier() - - results = scatter_metric.compute() - results_dict[name] = results - scatter_metric.reset() - - del data_loader - - for param in model.parameters(): - param.requires_grad = True - - return results_dict - - -def to_numpy(tensor: torch.Tensor) -> np.ndarray: - return tensor.cpu().detach().numpy() - - -class InferenceMetric(Metric): - """Metric class for collecting reference and predicted values for scatterplot visualization.""" - - def __init__(self): - super().__init__() - # Raw values - self.add_state("ref_energies", default=[], dist_reduce_fx="cat") - self.add_state("pred_energies", default=[], dist_reduce_fx="cat") - self.add_state("ref_forces", default=[], dist_reduce_fx="cat") - self.add_state("pred_forces", default=[], dist_reduce_fx="cat") - self.add_state("ref_stress", default=[], dist_reduce_fx="cat") - self.add_state("pred_stress", default=[], dist_reduce_fx="cat") - self.add_state("ref_virials", default=[], dist_reduce_fx="cat") - self.add_state("pred_virials", default=[], dist_reduce_fx="cat") - self.add_state("ref_dipole", default=[], dist_reduce_fx="cat") - self.add_state("pred_dipole", default=[], dist_reduce_fx="cat") - - # Per-atom normalized values - self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat") - - # Store atom counts for each configuration - self.add_state("atom_counts", default=[], dist_reduce_fx="cat") - - # Counters - self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum") - 
self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum") - - def update(self, batch, output): # pylint: disable=arguments-differ - """Update metric states with new batch data.""" - # Calculate number of atoms per configuration - atoms_per_config = batch.ptr[1:] - batch.ptr[:-1] - self.atom_counts.append(atoms_per_config) - - # Energy - if output.get("energy") is not None and batch.energy is not None: - self.n_energy += 1.0 - self.ref_energies.append(batch.energy) - self.pred_energies.append(output["energy"]) - # Per-atom normalization - self.ref_energies_per_atom.append(batch.energy / atoms_per_config) - self.pred_energies_per_atom.append(output["energy"] / atoms_per_config) - - # Forces - if output.get("forces") is not None and batch.forces is not None: - self.n_forces += 1.0 - self.ref_forces.append(batch.forces) - self.pred_forces.append(output["forces"]) - - # Stress - if output.get("stress") is not None and batch.stress is not None: - self.n_stress += 1.0 - self.ref_stress.append(batch.stress) - self.pred_stress.append(output["stress"]) - - # Virials - if output.get("virials") is not None and batch.virials is not None: - self.n_virials += 1.0 - self.ref_virials.append(batch.virials) - self.pred_virials.append(output["virials"]) - # Per-atom normalization - atoms_per_config_3d = atoms_per_config.view(-1, 1, 1) - self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d) - self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d) - - # Dipole - if output.get("dipole") is not None and batch.dipole is not None: - self.n_dipole += 1.0 - self.ref_dipole.append(batch.dipole) - self.pred_dipole.append(output["dipole"]) - atoms_per_config_3d = atoms_per_config.view(-1, 1) - self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_3d) - self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_3d) - - def _process_data(self, ref_list, pred_list): - # Handle different possible states of ref_list and pred_list in distributed mode - - # Check if this is a list type object - if isinstance(ref_list, (list, tuple)): - if len(ref_list) == 0: - return None, None - ref = torch.cat(ref_list).reshape(-1) - pred = torch.cat(pred_list).reshape(-1) - # Handle case where ref_list is already a tensor (happens after reset in distributed mode) - elif isinstance(ref_list, torch.Tensor): - ref = ref_list.reshape(-1) - pred = pred_list.reshape(-1) - # Handle other possible types - else: - return None, None - return to_numpy(ref), to_numpy(pred) - - def compute(self): - """Compute final results for scatterplot.""" - results = {} - - # Process energies - if self.n_energy: - ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies) - ref_e_pa, pred_e_pa = self._process_data( - self.ref_energies_per_atom, self.pred_energies_per_atom - ) - results["energy"] = { - "reference": ref_e, - "predicted": pred_e, - "reference_per_atom": ref_e_pa, - "predicted_per_atom": pred_e_pa, - } - - # Process forces - if self.n_forces: - ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces) - results["forces"] = { - "reference": ref_f, - "predicted": pred_f, - } - - # Process stress - if self.n_stress: - ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress) - results["stress"] = { - "reference": ref_s, - "predicted": pred_s, - } - - # Process virials - if self.n_virials: - 
ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials) - ref_v_pa, pred_v_pa = self._process_data( - self.ref_virials_per_atom, self.pred_virials_per_atom - ) - results["virials"] = { - "reference": ref_v, - "predicted": pred_v, - "reference_per_atom": ref_v_pa, - "predicted_per_atom": pred_v_pa, - } - - # Process dipoles - if self.n_dipole: - ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole) - ref_d_pa, pred_d_pa = self._process_data( - self.ref_dipole_per_atom, self.pred_dipole_per_atom - ) - results["dipole"] = { - "reference": ref_d, - "predicted": pred_d, - "reference_per_atom": ref_d_pa, - "predicted_per_atom": pred_d_pa, - } - return results +import json +import logging +from typing import Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.distributed +from torchmetrics import Metric + +plt.rcParams.update({"font.size": 8}) +mpl_logger = logging.getLogger("matplotlib") +mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above + +colors = [ + "#1f77b4", # muted blue + "#d62728", # brick red + "#7f7f7f", # middle gray + "#2ca02c", # cooked asparagus green + "#ff7f0e", # safety orange + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal +] + +error_type = { + "TotalRMSE": ( + [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomRMSE": ( + [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomRMSEstressvirials": ( + [ + ("rmse_e_per_atom", "RMSE E/atom [meV]"), + ("rmse_f", "RMSE F [meV / A]"), + ("rmse_stress", "RMSE Stress [meV / A^3]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("stress", "Stress [eV / A^3]"), + ], + ), + "PerAtomMAEstressvirials": ( + [ + ("mae_e_per_atom", "MAE E/atom [meV]"), + ("mae_f", "MAE F [meV / A]"), + ("mae_stress", "MAE Stress [meV / A^3]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("stress", "Stress [eV / A^3]"), + ], + ), + "TotalMAE": ( + [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomMAE": ( + [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "DipoleRMSE": ( + [ + ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), + ("rel_rmse_f", "Relative MU RMSE [%]"), + ], + [("dipole", "Dipole per atom [Debye]")], + ), + "DipoleMAE": ( + [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")], + [("dipole", "Dipole per atom [Debye]")], + ), + "EnergyDipoleRMSE": ( + [ + ("rmse_e_per_atom", "RMSE E/atom [meV]"), + ("rmse_f", "RMSE F [meV / A]"), + ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("dipole", "Dipole per atom [Debye]"), + ], + ), +} + + +class TrainingPlotter: + def __init__( + self, + results_dir: str, + heads: List[str], + table_type: str, + train_valid_data: Dict, + test_data: Dict, + output_args: str, + device: str, + plot_frequency: int, + distributed: bool = False, + swa_start: Optional[int] = None, + ): + self.results_dir = results_dir + self.heads = heads + self.table_type = 
table_type + self.train_valid_data = train_valid_data + self.test_data = test_data + self.output_args = output_args + self.device = device + self.plot_frequency = plot_frequency + self.distributed = distributed + self.swa_start = swa_start + + def plot(self, model_epoch: str, model: torch.nn.Module, rank: int) -> None: + + # All ranks process data through model_inference + train_valid_dict = model_inference( + self.train_valid_data, + model, + self.output_args, + self.device, + self.distributed, + ) + test_dict = model_inference( + self.test_data, model, self.output_args, self.device, self.distributed + ) + + # Only rank 0 creates and saves plots + if rank != 0: + return + + data = pd.DataFrame( + results for results in parse_training_results(self.results_dir) + ) + labels, quantities = error_type[self.table_type] + + for head in self.heads: + fig = plt.figure(layout="constrained", figsize=(10, 6)) + fig.suptitle( + f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16 + ) + + subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05) + axsTop = subfigs[0].subplots(1, 2, sharey=False) + axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False) + + plot_epoch_dependence(axsTop, data, head, model_epoch, labels) + + # Use the pre-computed results for plotting + plot_inference_from_results( + axsBottom, train_valid_dict, test_dict, head, quantities + ) + + if self.swa_start is not None: + # Add vertical lines to both axes + for ax in axsTop: + ax.axvline( + self.swa_start, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + stage = "stage_two" if self.swa_start < model_epoch else "stage_one" + else: + stage = "stage_one" + axsTop[0].legend(loc="best") + # Save the figure using the appropriate stage in the filename + filename = f"{self.results_dir[:-4]}_{head}_{stage}.png" + + fig.savefig(filename, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def parse_training_results(path: str) -> List[dict]: + results = [] + with open(path, mode="r", encoding="utf-8") as f: + for line in f: + try: + d = json.loads(line.strip()) # Ensure it's valid JSON + results.append(d) + except json.JSONDecodeError: + print( + f"Skipping invalid line: {line.strip()}" + ) # Handle non-JSON lines gracefully + return results + + +def plot_epoch_dependence( + axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: str, labels: List[str] +) -> None: + + valid_data = ( + data[data["mode"] == "eval"] + .groupby(["mode", "epoch", "head"]) + .agg(["mean", "std"]) + .reset_index() + ) + valid_data = valid_data[valid_data["head"] == head] + train_data = ( + data[data["mode"] == "opt"] + .groupby(["mode", "epoch"]) + .agg(["mean", "std"]) + .reset_index() + ) + + # ---- Plot loss ---- + ax = axes[0] + ax.plot( + train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1 + ) + ax.set_ylabel("Training Loss", color=colors[1]) + ax.set_yscale("log") + + ax2 = ax.twinx() + ax2.plot( + valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1 + ) + ax2.set_ylabel("Validation Loss", color=colors[0]) + ax2.set_yscale("log") + + ax.axvline( + model_epoch, + color="black", + linestyle="solid", + linewidth=1, + alpha=0.8, + label="Loaded Model", + ) + ax.set_xlabel("Epoch") + ax.grid(True, linestyle="--", alpha=0.5) + + # ---- Plot selected keys ---- + ax = axes[1] + twin_axes = [] + for i, label in enumerate(labels): + color = colors[(i + 3)] + key, axis_label = label + if i == 0: + main_ax = ax + else: + main_ax = 
ax.twinx() + main_ax.spines.right.set_position(("outward", 60 * (i - 1))) + twin_axes.append(main_ax) + + main_ax.plot( + valid_data["epoch"], + valid_data[key]["mean"] * 1e3, + color=color, + label=label, + linewidth=1, + ) + main_ax.set_yscale("log") + main_ax.set_ylabel(axis_label, color=color) + main_ax.tick_params(axis="y", colors=color) + ax.axvline( + model_epoch, + color="black", + linestyle="solid", + linewidth=1, + alpha=0.8, + label="Loaded Model", + ) + ax.set_xlabel("Epoch") + ax.grid(True, linestyle="--", alpha=0.5) + + +# INFERENCE========= + + +def plot_inference_from_results( + axes: np.ndarray, + train_valid_dict: dict, + test_dict: dict, + head: str, + quantities: List[str], +) -> None: + + for ax, quantity in zip(axes, quantities): + key, label = quantity + + # Store legend handles to avoid duplicates + legend_labels = {} + + # Plot train/valid data (each entry keeps its own name) + for name, result in train_valid_dict.items(): + if "train" in name: + fixed_color_train_valid = colors[1] + marker = "x" + else: + fixed_color_train_valid = colors[0] + marker = "+" + if head not in name: + continue + + # Initialize scatter to None + scatter = None + + if key == "energy" and "energy" in result: + scatter = ax.scatter( + result["energy"]["reference_per_atom"], + result["energy"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "force" and "forces" in result: + scatter = ax.scatter( + result["forces"]["reference"], + result["forces"]["predicted"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "stress" and "stress" in result: + scatter = ax.scatter( + result["stress"]["reference"], + result["stress"]["predicted"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "virials" and "virials" in result: + scatter = ax.scatter( + result["virials"]["reference_per_atom"], + result["virials"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "dipole" and "dipole" in result: + scatter = ax.scatter( + result["dipole"]["reference_per_atom"], + result["dipole"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + # Add each train/valid dataset's name to the legend if scatter was assigned + if scatter is not None: + legend_labels[name] = scatter + + fixed_color_test = colors[2] # Color for test dataset + + # Plot test data (single legend entry) + for name, result in test_dict.items(): + # Initialize scatter to None to avoid possibly used before assignment + scatter = None + + if key == "energy" and "energy" in result: + scatter = ax.scatter( + result["energy"]["reference_per_atom"], + result["energy"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "force" and "forces" in result: + scatter = ax.scatter( + result["forces"]["reference"], + result["forces"]["predicted"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "stress" and "stress" in result: + scatter = ax.scatter( + result["stress"]["reference"], + result["stress"]["predicted"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "virials" and "virials" in result: + scatter = ax.scatter( + result["virials"]["reference_per_atom"], + result["virials"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "dipole" and "dipole" in result: + scatter = ax.scatter( + 
result["dipole"]["reference_per_atom"], + result["dipole"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + # Only add to legend_labels if scatter was assigned + if scatter is not None: + legend_labels["Test"] = scatter + + # Add diagonal line for guide + min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) + max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) + ax.plot( + [min_val, max_val], + [min_val, max_val], + linestyle="--", + color="black", + alpha=0.7, + ) + + # Set legend with unique entries (Test + individual train/valid names) + if legend_labels: + ax.legend( + handles=legend_labels.values(), labels=legend_labels.keys(), loc="best" + ) + ax.set_xlabel(f"Reference {label}") + ax.set_ylabel(f"MACE {label}") + ax.grid(True, linestyle="--", alpha=0.5) + + +def model_inference( + all_data_loaders: dict, + model: torch.nn.Module, + output_args: Dict[str, bool], + device: str, + distributed: bool = False, +): + + for param in model.parameters(): + param.requires_grad = False + + results_dict = {} + + for name in all_data_loaders: + data_loader = all_data_loaders[name] + logging.debug(f"Running inference on {name} dataset") + scatter_metric = InferenceMetric().to(device) + + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args.get("forces", False), + compute_virials=output_args.get("virials", False), + compute_stress=output_args.get("stress", False), + ) + + results = scatter_metric(batch, output) + + if distributed: + torch.distributed.barrier() + + results = scatter_metric.compute() + results_dict[name] = results + scatter_metric.reset() + + del data_loader + + for param in model.parameters(): + param.requires_grad = True + + return results_dict + + +def to_numpy(tensor: torch.Tensor) -> np.ndarray: + return tensor.cpu().detach().numpy() + + +class InferenceMetric(Metric): + """Metric class for collecting reference and predicted values for scatterplot visualization.""" + + def __init__(self): + super().__init__() + # Raw values + self.add_state("ref_energies", default=[], dist_reduce_fx="cat") + self.add_state("pred_energies", default=[], dist_reduce_fx="cat") + self.add_state("ref_forces", default=[], dist_reduce_fx="cat") + self.add_state("pred_forces", default=[], dist_reduce_fx="cat") + self.add_state("ref_stress", default=[], dist_reduce_fx="cat") + self.add_state("pred_stress", default=[], dist_reduce_fx="cat") + self.add_state("ref_virials", default=[], dist_reduce_fx="cat") + self.add_state("pred_virials", default=[], dist_reduce_fx="cat") + self.add_state("ref_dipole", default=[], dist_reduce_fx="cat") + self.add_state("pred_dipole", default=[], dist_reduce_fx="cat") + + # Per-atom normalized values + self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat") + + # Store atom counts for each configuration + self.add_state("atom_counts", default=[], dist_reduce_fx="cat") + + # Counters + self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum") + 
self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, batch, output): # pylint: disable=arguments-differ + """Update metric states with new batch data.""" + # Calculate number of atoms per configuration + atoms_per_config = batch.ptr[1:] - batch.ptr[:-1] + self.atom_counts.append(atoms_per_config) + + # Energy + if output.get("energy") is not None and batch.energy is not None: + self.n_energy += 1.0 + self.ref_energies.append(batch.energy) + self.pred_energies.append(output["energy"]) + # Per-atom normalization + self.ref_energies_per_atom.append(batch.energy / atoms_per_config) + self.pred_energies_per_atom.append(output["energy"] / atoms_per_config) + + # Forces + if output.get("forces") is not None and batch.forces is not None: + self.n_forces += 1.0 + self.ref_forces.append(batch.forces) + self.pred_forces.append(output["forces"]) + + # Stress + if output.get("stress") is not None and batch.stress is not None: + self.n_stress += 1.0 + self.ref_stress.append(batch.stress) + self.pred_stress.append(output["stress"]) + + # Virials + if output.get("virials") is not None and batch.virials is not None: + self.n_virials += 1.0 + self.ref_virials.append(batch.virials) + self.pred_virials.append(output["virials"]) + # Per-atom normalization + atoms_per_config_3d = atoms_per_config.view(-1, 1, 1) + self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d) + self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d) + + # Dipole + if output.get("dipole") is not None and batch.dipole is not None: + self.n_dipole += 1.0 + self.ref_dipole.append(batch.dipole) + self.pred_dipole.append(output["dipole"]) + atoms_per_config_3d = atoms_per_config.view(-1, 1) + self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_3d) + self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_3d) + + def _process_data(self, ref_list, pred_list): + # Handle different possible states of ref_list and pred_list in distributed mode + + # Check if this is a list type object + if isinstance(ref_list, (list, tuple)): + if len(ref_list) == 0: + return None, None + ref = torch.cat(ref_list).reshape(-1) + pred = torch.cat(pred_list).reshape(-1) + # Handle case where ref_list is already a tensor (happens after reset in distributed mode) + elif isinstance(ref_list, torch.Tensor): + ref = ref_list.reshape(-1) + pred = pred_list.reshape(-1) + # Handle other possible types + else: + return None, None + return to_numpy(ref), to_numpy(pred) + + def compute(self): + """Compute final results for scatterplot.""" + results = {} + + # Process energies + if self.n_energy: + ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies) + ref_e_pa, pred_e_pa = self._process_data( + self.ref_energies_per_atom, self.pred_energies_per_atom + ) + results["energy"] = { + "reference": ref_e, + "predicted": pred_e, + "reference_per_atom": ref_e_pa, + "predicted_per_atom": pred_e_pa, + } + + # Process forces + if self.n_forces: + ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces) + results["forces"] = { + "reference": ref_f, + "predicted": pred_f, + } + + # Process stress + if self.n_stress: + ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress) + results["stress"] = { + "reference": ref_s, + "predicted": pred_s, + } + + # Process virials + if self.n_virials: + 
ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials) + ref_v_pa, pred_v_pa = self._process_data( + self.ref_virials_per_atom, self.pred_virials_per_atom + ) + results["virials"] = { + "reference": ref_v, + "predicted": pred_v, + "reference_per_atom": ref_v_pa, + "predicted_per_atom": pred_v_pa, + } + + # Process dipoles + if self.n_dipole: + ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole) + ref_d_pa, pred_d_pa = self._process_data( + self.ref_dipole_per_atom, self.pred_dipole_per_atom + ) + results["dipole"] = { + "reference": ref_d, + "predicted": pred_d, + "reference_per_atom": ref_d_pa, + "predicted_per_atom": pred_d_pa, + } + return results diff --git a/mace-bench/3rdparty/mace/mace/data/__init__.py b/mace-bench/3rdparty/mace/mace/data/__init__.py index ad58cca679185605c8f31faf97af8541c420362e..8629cf521b395a4cfd2899b5f89a17d45dfa2749 100644 --- a/mace-bench/3rdparty/mace/mace/data/__init__.py +++ b/mace-bench/3rdparty/mace/mace/data/__init__.py @@ -1,40 +1,40 @@ -from .atomic_data import AtomicData -from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 -from .lmdb_dataset import LMDBDataset -from .neighborhood import get_neighborhood -from .utils import ( - Configuration, - Configurations, - KeySpecification, - compute_average_E0s, - config_from_atoms, - config_from_atoms_list, - load_from_xyz, - random_train_valid_split, - save_AtomicData_to_HDF5, - save_configurations_as_HDF5, - save_dataset_as_HDF5, - test_config_types, - update_keyspec_from_kwargs, -) - -__all__ = [ - "get_neighborhood", - "Configuration", - "Configurations", - "random_train_valid_split", - "load_from_xyz", - "test_config_types", - "config_from_atoms", - "config_from_atoms_list", - "AtomicData", - "compute_average_E0s", - "save_dataset_as_HDF5", - "HDF5Dataset", - "dataset_from_sharded_hdf5", - "save_AtomicData_to_HDF5", - "save_configurations_as_HDF5", - "KeySpecification", - "update_keyspec_from_kwargs", - "LMDBDataset", -] +from .atomic_data import AtomicData +from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 +from .lmdb_dataset import LMDBDataset +from .neighborhood import get_neighborhood +from .utils import ( + Configuration, + Configurations, + KeySpecification, + compute_average_E0s, + config_from_atoms, + config_from_atoms_list, + load_from_xyz, + random_train_valid_split, + save_AtomicData_to_HDF5, + save_configurations_as_HDF5, + save_dataset_as_HDF5, + test_config_types, + update_keyspec_from_kwargs, +) + +__all__ = [ + "get_neighborhood", + "Configuration", + "Configurations", + "random_train_valid_split", + "load_from_xyz", + "test_config_types", + "config_from_atoms", + "config_from_atoms_list", + "AtomicData", + "compute_average_E0s", + "save_dataset_as_HDF5", + "HDF5Dataset", + "dataset_from_sharded_hdf5", + "save_AtomicData_to_HDF5", + "save_configurations_as_HDF5", + "KeySpecification", + "update_keyspec_from_kwargs", + "LMDBDataset", +] diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 13866301413816fc28ea594c20e7a3a84a5072f6..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 
05e4fcdd98125d5aaf6339dd00cb35603a7f8573..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index ebfab693d457d33ebbae7721e564bdef5ae0f23f..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/data/atomic_data.py b/mace-bench/3rdparty/mace/mace/data/atomic_data.py index 26ebe1b7ca55e52eb535e487b047f2a21d7293e3..14dd39df3263633bbbe279427dfcbde64dd16097 100644 --- a/mace-bench/3rdparty/mace/mace/data/atomic_data.py +++ b/mace-bench/3rdparty/mace/mace/data/atomic_data.py @@ -1,300 +1,300 @@ -########################################################################################### -# Atomic Data Class for handling molecules as graphs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from copy import deepcopy -from typing import Optional, Sequence - -import torch.utils.data - -from mace.tools import ( - AtomicNumberTable, - atomic_numbers_to_indices, - to_one_hot, - torch_geometric, - voigt_to_matrix, -) - -from .neighborhood import get_neighborhood -from .utils import Configuration - - -class AtomicData(torch_geometric.data.Data): - num_graphs: torch.Tensor - batch: torch.Tensor - edge_index: torch.Tensor - node_attrs: torch.Tensor - edge_vectors: torch.Tensor - edge_lengths: torch.Tensor - positions: torch.Tensor - shifts: torch.Tensor - unit_shifts: torch.Tensor - cell: torch.Tensor - forces: torch.Tensor - energy: torch.Tensor - stress: torch.Tensor - virials: torch.Tensor - dipole: torch.Tensor - charges: torch.Tensor - weight: torch.Tensor - energy_weight: torch.Tensor - forces_weight: torch.Tensor - stress_weight: torch.Tensor - virials_weight: torch.Tensor - dipole_weight: torch.Tensor - charges_weight: torch.Tensor - - def __init__( - self, - edge_index: torch.Tensor, # [2, n_edges] - node_attrs: torch.Tensor, # [n_nodes, n_node_feats] - positions: torch.Tensor, # [n_nodes, 3] - shifts: torch.Tensor, # [n_edges, 3], - unit_shifts: torch.Tensor, # [n_edges, 3] - cell: Optional[torch.Tensor], # [3,3] - weight: Optional[torch.Tensor], # [,] - head: Optional[torch.Tensor], # [,] - energy_weight: Optional[torch.Tensor], # [,] - forces_weight: Optional[torch.Tensor], # [,] - stress_weight: Optional[torch.Tensor], # [,] - virials_weight: Optional[torch.Tensor], # [,] - dipole_weight: Optional[torch.Tensor], # [,] - charges_weight: Optional[torch.Tensor], # [,] - forces: Optional[torch.Tensor], # [n_nodes, 3] - energy: Optional[torch.Tensor], # [, ] - stress: Optional[torch.Tensor], # [1,3,3] - virials: Optional[torch.Tensor], # [1,3,3] - dipole: Optional[torch.Tensor], # [, 3] - charges: Optional[torch.Tensor], # [n_nodes, ] - ): - # Check shapes - num_nodes = node_attrs.shape[0] - - assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 - assert positions.shape == (num_nodes, 3) - assert shifts.shape[1] == 3 - assert unit_shifts.shape[1] == 3 - assert len(node_attrs.shape) == 2 - assert weight is None or len(weight.shape) == 0 - assert head is None or len(head.shape) == 0 - assert energy_weight is None or len(energy_weight.shape) 
== 0 - assert forces_weight is None or len(forces_weight.shape) == 0 - assert stress_weight is None or len(stress_weight.shape) == 0 - assert virials_weight is None or len(virials_weight.shape) == 0 - assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight - assert charges_weight is None or len(charges_weight.shape) == 0 - assert cell is None or cell.shape == (3, 3) - assert forces is None or forces.shape == (num_nodes, 3) - assert energy is None or len(energy.shape) == 0 - assert stress is None or stress.shape == (1, 3, 3) - assert virials is None or virials.shape == (1, 3, 3) - assert dipole is None or dipole.shape[-1] == 3 - assert charges is None or charges.shape == (num_nodes,) - # Aggregate data - data = { - "num_nodes": num_nodes, - "edge_index": edge_index, - "positions": positions, - "shifts": shifts, - "unit_shifts": unit_shifts, - "cell": cell, - "node_attrs": node_attrs, - "weight": weight, - "head": head, - "energy_weight": energy_weight, - "forces_weight": forces_weight, - "stress_weight": stress_weight, - "virials_weight": virials_weight, - "dipole_weight": dipole_weight, - "charges_weight": charges_weight, - "forces": forces, - "energy": energy, - "stress": stress, - "virials": virials, - "dipole": dipole, - "charges": charges, - } - super().__init__(**data) - - @classmethod - def from_config( - cls, - config: Configuration, - z_table: AtomicNumberTable, - cutoff: float, - heads: Optional[list] = None, - **kwargs, # pylint: disable=unused-argument - ) -> "AtomicData": - if heads is None: - heads = ["Default"] - edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, - cutoff=cutoff, - pbc=deepcopy(config.pbc), - cell=deepcopy(config.cell), - ) - indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) - one_hot = to_one_hot( - torch.tensor(indices, dtype=torch.long).unsqueeze(-1), - num_classes=len(z_table), - ) - try: - head = torch.tensor(heads.index(config.head), dtype=torch.long) - except ValueError: - head = torch.tensor(len(heads) - 1, dtype=torch.long) - - cell = ( - torch.tensor(cell, dtype=torch.get_default_dtype()) - if cell is not None - else torch.tensor( - 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() - ).view(3, 3) - ) - - num_atoms = len(config.atomic_numbers) - - weight = ( - torch.tensor(config.weight, dtype=torch.get_default_dtype()) - if config.weight is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - energy_weight = ( - torch.tensor( - config.property_weights.get("energy"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("energy") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - forces_weight = ( - torch.tensor( - config.property_weights.get("forces"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("forces") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - stress_weight = ( - torch.tensor( - config.property_weights.get("stress"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("stress") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - virials_weight = ( - torch.tensor( - config.property_weights.get("virials"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("virials") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - dipole_weight = ( - torch.tensor( - config.property_weights.get("dipole"), dtype=torch.get_default_dtype() - ) - if 
config.property_weights.get("dipole") is not None - else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) - ) - if len(dipole_weight.shape) == 0: - dipole_weight = dipole_weight * torch.tensor( - [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() - ) - elif len(dipole_weight.shape) == 1: - dipole_weight = dipole_weight.unsqueeze(0) - - charges_weight = ( - torch.tensor( - config.property_weights.get("charges"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("charges") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - forces = ( - torch.tensor( - config.properties.get("forces"), dtype=torch.get_default_dtype() - ) - if config.properties.get("forces") is not None - else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) - ) - energy = ( - torch.tensor( - config.properties.get("energy"), dtype=torch.get_default_dtype() - ) - if config.properties.get("energy") is not None - else torch.tensor(0.0, dtype=torch.get_default_dtype()) - ) - stress = ( - voigt_to_matrix( - torch.tensor( - config.properties.get("stress"), dtype=torch.get_default_dtype() - ) - ).unsqueeze(0) - if config.properties.get("stress") is not None - else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) - ) - virials = ( - voigt_to_matrix( - torch.tensor( - config.properties.get("virials"), dtype=torch.get_default_dtype() - ) - ).unsqueeze(0) - if config.properties.get("virials") is not None - else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) - ) - dipole = ( - torch.tensor( - config.properties.get("dipole"), dtype=torch.get_default_dtype() - ).unsqueeze(0) - if config.properties.get("dipole") is not None - else torch.zeros(1, 3, dtype=torch.get_default_dtype()) - ) - charges = ( - torch.tensor( - config.properties.get("charges"), dtype=torch.get_default_dtype() - ) - if config.properties.get("charges") is not None - else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) - ) - - return cls( - edge_index=torch.tensor(edge_index, dtype=torch.long), - positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), - shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), - unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), - cell=cell, - node_attrs=one_hot, - weight=weight, - head=head, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, - dipole_weight=dipole_weight, - charges_weight=charges_weight, - forces=forces, - energy=energy, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, - ) - - -def get_data_loader( - dataset: Sequence[AtomicData], - batch_size: int, - shuffle=True, - drop_last=False, -) -> torch.utils.data.DataLoader: - return torch_geometric.dataloader.DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - ) +########################################################################################### +# Atomic Data Class for handling molecules as graphs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from copy import deepcopy +from typing import Optional, Sequence + +import torch.utils.data + +from mace.tools import ( + AtomicNumberTable, + atomic_numbers_to_indices, + to_one_hot, + torch_geometric, + voigt_to_matrix, +) + +from .neighborhood import get_neighborhood +from .utils import Configuration 
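`Configuration` plus `AtomicData.from_config` form the entry point of this data pipeline: a configuration carries raw numpy arrays and per-property weights, and `from_config` builds the graph (edge list, periodic shifts, one-hot node attributes) for a given cutoff, substituting zero tensors and default weights for missing properties. A usage sketch based on the signatures in this file; the toy molecule and energies are made up, and the `AtomicNumberTable([...])` constructor is assumed:

```python
# Illustrative only: one AtomicData graph from a toy water-like configuration.
import numpy as np

from mace.data import AtomicData, Configuration
from mace.tools import AtomicNumberTable

z_table = AtomicNumberTable([1, 8])  # H and O (assumed constructor)

config = Configuration(
    atomic_numbers=np.array([8, 1, 1]),
    positions=np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]),
    properties={"energy": -14.2, "forces": np.zeros((3, 3))},  # toy values
    property_weights={"energy": 1.0, "forces": 1.0},
    pbc=(False, False, False),
    cell=None,  # non-periodic: get_neighborhood inflates the cell internally
)

data = AtomicData.from_config(config, z_table=z_table, cutoff=5.0)
print(data.num_nodes, data.edge_index.shape)
```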
+ + +class AtomicData(torch_geometric.data.Data): + num_graphs: torch.Tensor + batch: torch.Tensor + edge_index: torch.Tensor + node_attrs: torch.Tensor + edge_vectors: torch.Tensor + edge_lengths: torch.Tensor + positions: torch.Tensor + shifts: torch.Tensor + unit_shifts: torch.Tensor + cell: torch.Tensor + forces: torch.Tensor + energy: torch.Tensor + stress: torch.Tensor + virials: torch.Tensor + dipole: torch.Tensor + charges: torch.Tensor + weight: torch.Tensor + energy_weight: torch.Tensor + forces_weight: torch.Tensor + stress_weight: torch.Tensor + virials_weight: torch.Tensor + dipole_weight: torch.Tensor + charges_weight: torch.Tensor + + def __init__( + self, + edge_index: torch.Tensor, # [2, n_edges] + node_attrs: torch.Tensor, # [n_nodes, n_node_feats] + positions: torch.Tensor, # [n_nodes, 3] + shifts: torch.Tensor, # [n_edges, 3], + unit_shifts: torch.Tensor, # [n_edges, 3] + cell: Optional[torch.Tensor], # [3,3] + weight: Optional[torch.Tensor], # [,] + head: Optional[torch.Tensor], # [,] + energy_weight: Optional[torch.Tensor], # [,] + forces_weight: Optional[torch.Tensor], # [,] + stress_weight: Optional[torch.Tensor], # [,] + virials_weight: Optional[torch.Tensor], # [,] + dipole_weight: Optional[torch.Tensor], # [,] + charges_weight: Optional[torch.Tensor], # [,] + forces: Optional[torch.Tensor], # [n_nodes, 3] + energy: Optional[torch.Tensor], # [, ] + stress: Optional[torch.Tensor], # [1,3,3] + virials: Optional[torch.Tensor], # [1,3,3] + dipole: Optional[torch.Tensor], # [, 3] + charges: Optional[torch.Tensor], # [n_nodes, ] + ): + # Check shapes + num_nodes = node_attrs.shape[0] + + assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 + assert positions.shape == (num_nodes, 3) + assert shifts.shape[1] == 3 + assert unit_shifts.shape[1] == 3 + assert len(node_attrs.shape) == 2 + assert weight is None or len(weight.shape) == 0 + assert head is None or len(head.shape) == 0 + assert energy_weight is None or len(energy_weight.shape) == 0 + assert forces_weight is None or len(forces_weight.shape) == 0 + assert stress_weight is None or len(stress_weight.shape) == 0 + assert virials_weight is None or len(virials_weight.shape) == 0 + assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight + assert charges_weight is None or len(charges_weight.shape) == 0 + assert cell is None or cell.shape == (3, 3) + assert forces is None or forces.shape == (num_nodes, 3) + assert energy is None or len(energy.shape) == 0 + assert stress is None or stress.shape == (1, 3, 3) + assert virials is None or virials.shape == (1, 3, 3) + assert dipole is None or dipole.shape[-1] == 3 + assert charges is None or charges.shape == (num_nodes,) + # Aggregate data + data = { + "num_nodes": num_nodes, + "edge_index": edge_index, + "positions": positions, + "shifts": shifts, + "unit_shifts": unit_shifts, + "cell": cell, + "node_attrs": node_attrs, + "weight": weight, + "head": head, + "energy_weight": energy_weight, + "forces_weight": forces_weight, + "stress_weight": stress_weight, + "virials_weight": virials_weight, + "dipole_weight": dipole_weight, + "charges_weight": charges_weight, + "forces": forces, + "energy": energy, + "stress": stress, + "virials": virials, + "dipole": dipole, + "charges": charges, + } + super().__init__(**data) + + @classmethod + def from_config( + cls, + config: Configuration, + z_table: AtomicNumberTable, + cutoff: float, + heads: Optional[list] = None, + **kwargs, # pylint: disable=unused-argument + ) -> "AtomicData": + if heads is None: + heads 
= ["Default"] + edge_index, shifts, unit_shifts, cell = get_neighborhood( + positions=config.positions, + cutoff=cutoff, + pbc=deepcopy(config.pbc), + cell=deepcopy(config.cell), + ) + indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + num_classes=len(z_table), + ) + try: + head = torch.tensor(heads.index(config.head), dtype=torch.long) + except ValueError: + head = torch.tensor(len(heads) - 1, dtype=torch.long) + + cell = ( + torch.tensor(cell, dtype=torch.get_default_dtype()) + if cell is not None + else torch.tensor( + 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() + ).view(3, 3) + ) + + num_atoms = len(config.atomic_numbers) + + weight = ( + torch.tensor(config.weight, dtype=torch.get_default_dtype()) + if config.weight is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + energy_weight = ( + torch.tensor( + config.property_weights.get("energy"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("energy") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + forces_weight = ( + torch.tensor( + config.property_weights.get("forces"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("forces") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + stress_weight = ( + torch.tensor( + config.property_weights.get("stress"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("stress") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + virials_weight = ( + torch.tensor( + config.property_weights.get("virials"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("virials") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + dipole_weight = ( + torch.tensor( + config.property_weights.get("dipole"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("dipole") is not None + else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + ) + if len(dipole_weight.shape) == 0: + dipole_weight = dipole_weight * torch.tensor( + [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() + ) + elif len(dipole_weight.shape) == 1: + dipole_weight = dipole_weight.unsqueeze(0) + + charges_weight = ( + torch.tensor( + config.property_weights.get("charges"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("charges") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + forces = ( + torch.tensor( + config.properties.get("forces"), dtype=torch.get_default_dtype() + ) + if config.properties.get("forces") is not None + else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) + ) + energy = ( + torch.tensor( + config.properties.get("energy"), dtype=torch.get_default_dtype() + ) + if config.properties.get("energy") is not None + else torch.tensor(0.0, dtype=torch.get_default_dtype()) + ) + stress = ( + voigt_to_matrix( + torch.tensor( + config.properties.get("stress"), dtype=torch.get_default_dtype() + ) + ).unsqueeze(0) + if config.properties.get("stress") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) + ) + virials = ( + voigt_to_matrix( + torch.tensor( + config.properties.get("virials"), dtype=torch.get_default_dtype() + ) + ).unsqueeze(0) + if config.properties.get("virials") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) + ) + dipole = ( + torch.tensor( + 
config.properties.get("dipole"), dtype=torch.get_default_dtype() + ).unsqueeze(0) + if config.properties.get("dipole") is not None + else torch.zeros(1, 3, dtype=torch.get_default_dtype()) + ) + charges = ( + torch.tensor( + config.properties.get("charges"), dtype=torch.get_default_dtype() + ) + if config.properties.get("charges") is not None + else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) + ) + + return cls( + edge_index=torch.tensor(edge_index, dtype=torch.long), + positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), + shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), + unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), + cell=cell, + node_attrs=one_hot, + weight=weight, + head=head, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + virials_weight=virials_weight, + dipole_weight=dipole_weight, + charges_weight=charges_weight, + forces=forces, + energy=energy, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + ) + + +def get_data_loader( + dataset: Sequence[AtomicData], + batch_size: int, + shuffle=True, + drop_last=False, +) -> torch.utils.data.DataLoader: + return torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) diff --git a/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py b/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py index b374885ec013882d50c1c9a4a4ec1d5192d51b74..ab6aa7c7095b756ff0e020c0821bd6efafda9aed 100644 --- a/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py +++ b/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py @@ -1,97 +1,97 @@ -from glob import glob -from typing import List - -import h5py -from torch.utils.data import ConcatDataset, Dataset - -from mace.data.atomic_data import AtomicData -from mace.data.utils import Configuration -from mace.tools.utils import AtomicNumberTable - - -class HDF5Dataset(Dataset): - def __init__( - self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs - ): - super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments - self.file_path = file_path - self._file = None - batch_key = list(self.file.keys())[0] - self.batch_size = len(self.file[batch_key].keys()) - self.length = len(self.file.keys()) * self.batch_size - self.r_max = r_max - self.z_table = z_table - self.atomic_dataclass = atomic_dataclass - try: - self.drop_last = bool(self.file.attrs["drop_last"]) - except KeyError: - self.drop_last = False - self.kwargs = kwargs - - @property - def file(self): - if self._file is None: - # If a file has not already been opened, open one here - self._file = h5py.File(self.file_path, "r") - return self._file - - def __getstate__(self): - _d = dict(self.__dict__) - - # An opened h5py.File cannot be pickled, so we must exclude it from the state - _d["_file"] = None - return _d - - def __len__(self): - return self.length - - def __getitem__(self, index): - # compute the index of the batch - batch_index = index // self.batch_size - config_index = index % self.batch_size - grp = self.file["config_batch_" + str(batch_index)] - subgrp = grp["config_" + str(config_index)] - - properties = {} - property_weights = {} - for key in subgrp["properties"]: - properties[key] = unpack_value(subgrp["properties"][key][()]) - for key in subgrp["property_weights"]: - property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) - - config = Configuration( - atomic_numbers=subgrp["atomic_numbers"][()], - 
positions=subgrp["positions"][()], - properties=properties, - weight=unpack_value(subgrp["weight"][()]), - property_weights=property_weights, - config_type=unpack_value(subgrp["config_type"][()]), - pbc=unpack_value(subgrp["pbc"][()]), - cell=unpack_value(subgrp["cell"][()]), - ) - if config.head is None: - config.head = self.kwargs.get("head") - atomic_data = self.atomic_dataclass.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.kwargs.get("heads", ["Default"]), - **{k: v for k, v in self.kwargs.items() if k != "heads"}, - ) - return atomic_data - - -def dataset_from_sharded_hdf5( - files: List, z_table: AtomicNumberTable, r_max: float, **kwargs -): - files = glob(files + "/*") - datasets = [] - for file in files: - datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) - full_dataset = ConcatDataset(datasets) - return full_dataset - - -def unpack_value(value): - value = value.decode("utf-8") if isinstance(value, bytes) else value - return None if str(value) == "None" else value +from glob import glob +from typing import List + +import h5py +from torch.utils.data import ConcatDataset, Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import Configuration +from mace.tools.utils import AtomicNumberTable + + +class HDF5Dataset(Dataset): + def __init__( + self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs + ): + super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments + self.file_path = file_path + self._file = None + batch_key = list(self.file.keys())[0] + self.batch_size = len(self.file[batch_key].keys()) + self.length = len(self.file.keys()) * self.batch_size + self.r_max = r_max + self.z_table = z_table + self.atomic_dataclass = atomic_dataclass + try: + self.drop_last = bool(self.file.attrs["drop_last"]) + except KeyError: + self.drop_last = False + self.kwargs = kwargs + + @property + def file(self): + if self._file is None: + # If a file has not already been opened, open one here + self._file = h5py.File(self.file_path, "r") + return self._file + + def __getstate__(self): + _d = dict(self.__dict__) + + # An opened h5py.File cannot be pickled, so we must exclude it from the state + _d["_file"] = None + return _d + + def __len__(self): + return self.length + + def __getitem__(self, index): + # compute the index of the batch + batch_index = index // self.batch_size + config_index = index % self.batch_size + grp = self.file["config_batch_" + str(batch_index)] + subgrp = grp["config_" + str(config_index)] + + properties = {} + property_weights = {} + for key in subgrp["properties"]: + properties[key] = unpack_value(subgrp["properties"][key][()]) + for key in subgrp["property_weights"]: + property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) + + config = Configuration( + atomic_numbers=subgrp["atomic_numbers"][()], + positions=subgrp["positions"][()], + properties=properties, + weight=unpack_value(subgrp["weight"][()]), + property_weights=property_weights, + config_type=unpack_value(subgrp["config_type"][()]), + pbc=unpack_value(subgrp["pbc"][()]), + cell=unpack_value(subgrp["cell"][()]), + ) + if config.head is None: + config.head = self.kwargs.get("head") + atomic_data = self.atomic_dataclass.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), + **{k: v for k, v in self.kwargs.items() if k != "heads"}, + ) + return atomic_data + + +def dataset_from_sharded_hdf5( + files: List, z_table: 
AtomicNumberTable, r_max: float, **kwargs +): + files = glob(files + "/*") + datasets = [] + for file in files: + datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) + full_dataset = ConcatDataset(datasets) + return full_dataset + + +def unpack_value(value): + value = value.decode("utf-8") if isinstance(value, bytes) else value + return None if str(value) == "None" else value diff --git a/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py b/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py index 342b17979fe7e6303d8b896e94c059a3d0de85d6..e1ebbdaf54f3faee0fd5453cd1c509510d14c9d5 100644 --- a/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py +++ b/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py @@ -1,69 +1,69 @@ -import os - -import numpy as np -from torch.utils.data import Dataset - -from mace.data.atomic_data import AtomicData -from mace.data.utils import KeySpecification, config_from_atoms -from mace.tools.default_keys import DefaultKeys -from mace.tools.fairchem_dataset import AseDBDataset - - -class LMDBDataset(Dataset): - def __init__(self, file_path, r_max, z_table, **kwargs): - dataset_paths = file_path.split(":") # using : split multiple paths - # make sure each of the path exist - for path in dataset_paths: - assert os.path.exists(path) - config_kwargs = {} - super(LMDBDataset, self).__init__() # pylint: disable=super-with-arguments - self.AseDB = AseDBDataset(config=dict(src=dataset_paths, **config_kwargs)) - self.r_max = r_max - self.z_table = z_table - - self.kwargs = kwargs - self.transform = kwargs["transform"] if "transform" in kwargs else None - - def __len__(self): - return len(self.AseDB) - - def __getitem__(self, index): - try: - atoms = self.AseDB.get_atoms(self.AseDB.ids[index]) - except Exception as e: # pylint: disable=broad-except - print(f"Error in index {index}") - print(e) - return None - assert np.sum(atoms.get_cell() == atoms.cell) == 9 - - if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"): - if "energy" in atoms.calc.results: - atoms.info[DefaultKeys.ENERGY.value] = atoms.calc.results["energy"] - if "forces" in atoms.calc.results: - atoms.arrays[DefaultKeys.FORCES.value] = atoms.calc.results["forces"] - if "stress" in atoms.calc.results: - atoms.info[DefaultKeys.STRESS.value] = atoms.calc.results["stress"] - - config = config_from_atoms( - atoms, - key_specification=KeySpecification.from_defaults(), - ) - - # Set head if not already set - if config.head == "Default": - config.head = self.kwargs.get("head", "Default") - - try: - atomic_data = AtomicData.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.kwargs.get("heads", ["Default"]), - ) - except Exception as e: # pylint: disable=broad-except - print(f"Error in index {index}") - print(e) - - if self.transform: - atomic_data = self.transform(atomic_data) - return atomic_data +import os + +import numpy as np +from torch.utils.data import Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import KeySpecification, config_from_atoms +from mace.tools.default_keys import DefaultKeys +from mace.tools.fairchem_dataset import AseDBDataset + + +class LMDBDataset(Dataset): + def __init__(self, file_path, r_max, z_table, **kwargs): + dataset_paths = file_path.split(":") # using : split multiple paths + # make sure each of the path exist + for path in dataset_paths: + assert os.path.exists(path) + config_kwargs = {} + super(LMDBDataset, self).__init__() # pylint: disable=super-with-arguments + self.AseDB = 
AseDBDataset(config=dict(src=dataset_paths, **config_kwargs)) + self.r_max = r_max + self.z_table = z_table + + self.kwargs = kwargs + self.transform = kwargs["transform"] if "transform" in kwargs else None + + def __len__(self): + return len(self.AseDB) + + def __getitem__(self, index): + try: + atoms = self.AseDB.get_atoms(self.AseDB.ids[index]) + except Exception as e: # pylint: disable=broad-except + print(f"Error in index {index}") + print(e) + return None + assert np.sum(atoms.get_cell() == atoms.cell) == 9 + + if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"): + if "energy" in atoms.calc.results: + atoms.info[DefaultKeys.ENERGY.value] = atoms.calc.results["energy"] + if "forces" in atoms.calc.results: + atoms.arrays[DefaultKeys.FORCES.value] = atoms.calc.results["forces"] + if "stress" in atoms.calc.results: + atoms.info[DefaultKeys.STRESS.value] = atoms.calc.results["stress"] + + config = config_from_atoms( + atoms, + key_specification=KeySpecification.from_defaults(), + ) + + # Set head if not already set + if config.head == "Default": + config.head = self.kwargs.get("head", "Default") + + try: + atomic_data = AtomicData.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), + ) + except Exception as e: # pylint: disable=broad-except + print(f"Error in index {index}") + print(e) + + if self.transform: + atomic_data = self.transform(atomic_data) + return atomic_data diff --git a/mace-bench/3rdparty/mace/mace/data/neighborhood.py b/mace-bench/3rdparty/mace/mace/data/neighborhood.py index cd463524d34842f20a8c73d4c634433465278ec7..03728969df6af8111b9bab8aa1602c8fa73e8082 100644 --- a/mace-bench/3rdparty/mace/mace/data/neighborhood.py +++ b/mace-bench/3rdparty/mace/mace/data/neighborhood.py @@ -1,66 +1,66 @@ -from typing import Optional, Tuple - -import numpy as np -from matscipy.neighbours import neighbour_list - - -def get_neighborhood( - positions: np.ndarray, # [num_positions, 3] - cutoff: float, - pbc: Optional[Tuple[bool, bool, bool]] = None, - cell: Optional[np.ndarray] = None, # [3, 3] - true_self_interaction=False, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - if pbc is None: - pbc = (False, False, False) - - if cell is None or cell.any() == np.zeros((3, 3)).any(): - cell = np.identity(3, dtype=float) - - assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) - assert cell.shape == (3, 3) - - pbc_x = pbc[0] - pbc_y = pbc[1] - pbc_z = pbc[2] - identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(positions)) + 1 - # Extend cell in non-periodic directions - # For models with more than 5 layers, the multiplicative constant needs to be increased. 
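Both dataset classes above are built with `DataLoader` workers in mind: `HDF5Dataset` drops its open `h5py.File` handle in `__getstate__` so instances stay picklable and each worker reopens the file lazily, while `LMDBDataset` accepts several ASE-DB paths joined by `:` (note that its `__getitem__` leaves `atomic_data` unbound if `AtomicData.from_config` raises, so such failures surface as a `NameError`). A usage sketch; the file paths and cutoff below are placeholders:

```python
# Illustrative only: paths are placeholders for preprocessed datasets.
from mace.data import HDF5Dataset, dataset_from_sharded_hdf5
from mace.tools import AtomicNumberTable

z_table = AtomicNumberTable([1, 6, 8])  # assumed constructor

# A single preprocessed HDF5 file ...
ds = HDF5Dataset("train_0.h5", r_max=5.0, z_table=z_table)
print(len(ds))

# ... or a directory of shards, concatenated into one dataset.
full = dataset_from_sharded_hdf5("train_shards", z_table=z_table, r_max=5.0)
print(len(full))
```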
- # temp_cell = np.copy(cell) - if not pbc_x: - cell[0, :] = max_positions * 5 * cutoff * identity[0, :] - if not pbc_y: - cell[1, :] = max_positions * 5 * cutoff * identity[1, :] - if not pbc_z: - cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - - sender, receiver, unit_shifts = neighbour_list( - quantities="ijS", - pbc=pbc, - cell=cell, - positions=positions, - cutoff=cutoff, - # self_interaction=True, # we want edges from atom to itself in different periodic images - # use_scaled_positions=False, # positions are not scaled positions - ) - - if not true_self_interaction: - # Eliminate self-edges that don't cross periodic boundaries - true_self_edge = sender == receiver - true_self_edge &= np.all(unit_shifts == 0, axis=1) - keep_edge = ~true_self_edge - - # Note: after eliminating self-edges, it can be that no edges remain in this system - sender = sender[keep_edge] - receiver = receiver[keep_edge] - unit_shifts = unit_shifts[keep_edge] - - # Build output - edge_index = np.stack((sender, receiver)) # [2, n_edges] - - # From the docs: With the shift vector S, the distances D between atoms can be computed from - # D = positions[j]-positions[i]+S.dot(cell) - shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - - return edge_index, shifts, unit_shifts, cell +from typing import Optional, Tuple + +import numpy as np +from matscipy.neighbours import neighbour_list + + +def get_neighborhood( + positions: np.ndarray, # [num_positions, 3] + cutoff: float, + pbc: Optional[Tuple[bool, bool, bool]] = None, + cell: Optional[np.ndarray] = None, # [3, 3] + true_self_interaction=False, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + if pbc is None: + pbc = (False, False, False) + + if cell is None or cell.any() == np.zeros((3, 3)).any(): + cell = np.identity(3, dtype=float) + + assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) + assert cell.shape == (3, 3) + + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(positions)) + 1 + # Extend cell in non-periodic directions + # For models with more than 5 layers, the multiplicative constant needs to be increased. 
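`get_neighborhood` wraps `matscipy.neighbours.neighbour_list`: the condition `cell.any() == np.zeros((3, 3)).any()` reduces to `cell.any() == False`, so a missing or all-zero cell is replaced by the identity before the non-periodic directions are inflated to `max_positions * 5 * cutoff`, and `shifts = unit_shifts @ cell` recovers Cartesian offsets. A small sketch on a two-atom, non-periodic system:

```python
# Illustrative only: two atoms, no periodic boundaries.
import numpy as np

from mace.data import get_neighborhood

positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
edge_index, shifts, unit_shifts, cell = get_neighborhood(
    positions, cutoff=2.0, pbc=(False, False, False), cell=None
)
print(edge_index)         # [[0 1] [1 0]] (both directions), up to ordering
print(unit_shifts.any())  # False: no edge crosses a periodic boundary
```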
+ # temp_cell = np.copy(cell) + if not pbc_x: + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + + sender, receiver, unit_shifts = neighbour_list( + quantities="ijS", + pbc=pbc, + cell=cell, + positions=positions, + cutoff=cutoff, + # self_interaction=True, # we want edges from atom to itself in different periodic images + # use_scaled_positions=False, # positions are not scaled positions + ) + + if not true_self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = sender == receiver + true_self_edge &= np.all(unit_shifts == 0, axis=1) + keep_edge = ~true_self_edge + + # Note: after eliminating self-edges, it can be that no edges remain in this system + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + + # Build output + edge_index = np.stack((sender, receiver)) # [2, n_edges] + + # From the docs: With the shift vector S, the distances D between atoms can be computed from + # D = positions[j]-positions[i]+S.dot(cell) + shifts = np.dot(unit_shifts, cell) # [n_edges, 3] + + return edge_index, shifts, unit_shifts, cell diff --git a/mace-bench/3rdparty/mace/mace/data/utils.py b/mace-bench/3rdparty/mace/mace/data/utils.py index 6afa3bb9618dff643c41c77a2a12cb2fcea3b057..947ee60f4f1941f390a0f8ba52b821a7ed45e90b 100644 --- a/mace-bench/3rdparty/mace/mace/data/utils.py +++ b/mace-bench/3rdparty/mace/mace/data/utils.py @@ -1,368 +1,368 @@ -########################################################################################### -# Data parsing utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import ase.data -import ase.io -import h5py -import numpy as np - -from mace.tools import AtomicNumberTable, DefaultKeys - -Positions = np.ndarray # [..., 3] -Cell = np.ndarray # [3,3] -Pbc = tuple # (3,) - -DEFAULT_CONFIG_TYPE = "Default" -DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} - - -@dataclass -class KeySpecification: - info_keys: Dict[str, str] = field(default_factory=dict) - arrays_keys: Dict[str, str] = field(default_factory=dict) - - def update( - self, - info_keys: Optional[Dict[str, str]] = None, - arrays_keys: Optional[Dict[str, str]] = None, - ): - if info_keys is not None: - self.info_keys.update(info_keys) - if arrays_keys is not None: - self.arrays_keys.update(arrays_keys) - return self - - @classmethod - def from_defaults(cls): - instance = cls() - return update_keyspec_from_kwargs(instance, DefaultKeys.keydict()) - - -def update_keyspec_from_kwargs( - keyspec: KeySpecification, keydict: Dict[str, str] -) -> KeySpecification: - # convert command line style property_key arguments into a keyspec - infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] - arrays = ["forces_key", "charges_key"] - info_keys = {} - arrays_keys = {} - for key in infos: - if key in keydict: - info_keys[key[:-4]] = keydict[key] - for key in arrays: - if key in keydict: - arrays_keys[key[:-4]] = keydict[key] - keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) - return keyspec - - -@dataclass -class Configuration: - atomic_numbers: 
np.ndarray - positions: Positions # Angstrom - properties: Dict[str, Any] - property_weights: Dict[str, float] - cell: Optional[Cell] = None - pbc: Optional[Pbc] = None - - weight: float = 1.0 # weight of config in loss - config_type: str = DEFAULT_CONFIG_TYPE # config_type of config - head: str = "Default" # head used to compute the config - - -Configurations = List[Configuration] - - -def random_train_valid_split( - items: Sequence, valid_fraction: float, seed: int, work_dir: str -) -> Tuple[List, List]: - assert 0.0 < valid_fraction < 1.0 - - size = len(items) - train_size = size - int(valid_fraction * size) - - indices = list(range(size)) - rng = np.random.default_rng(seed) - rng.shuffle(indices) - if len(indices[train_size:]) < 10: - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" - ) - else: - # Save indices to file - with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: - for index in indices[train_size:]: - f.write(f"{index}\n") - - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" - ) - - return ( - [items[i] for i in indices[:train_size]], - [items[i] for i in indices[train_size:]], - ) - - -def config_from_atoms_list( - atoms_list: List[ase.Atoms], - key_specification: KeySpecification, - config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default", -) -> Configurations: - """Convert list of ase.Atoms into Configurations""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - all_configs = [] - for atoms in atoms_list: - all_configs.append( - config_from_atoms( - atoms, - key_specification=key_specification, - config_type_weights=config_type_weights, - head_name=head_name, - ) - ) - return all_configs - - -def config_from_atoms( - atoms: ase.Atoms, - key_specification: KeySpecification = KeySpecification(), - config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default", -) -> Configuration: - """Convert ase.Atoms to Configuration""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - atomic_numbers = np.array( - [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] - ) - pbc = tuple(atoms.get_pbc()) - cell = np.array(atoms.get_cell()) - config_type = atoms.info.get("config_type", "Default") - weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( - config_type, 1.0 - ) - - properties = {} - property_weights = {} - for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): - property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) - - for name, atoms_key in key_specification.info_keys.items(): - properties[name] = atoms.info.get(atoms_key, None) - if not atoms_key in atoms.info: - property_weights[name] = 0.0 - - for name, atoms_key in key_specification.arrays_keys.items(): - properties[name] = atoms.arrays.get(atoms_key, None) - if not atoms_key in atoms.arrays: - property_weights[name] = 0.0 - - return Configuration( - atomic_numbers=atomic_numbers, - positions=atoms.get_positions(), - properties=properties, - weight=weight, - property_weights=property_weights, - head=head_name, - config_type=config_type, - pbc=pbc, - cell=cell, - ) - - -def test_config_types( - test_configs: Configurations, -) -> List[Tuple[str, List[Configuration]]]: - """Split test set based on config_type-s""" 
- test_by_ct = [] - all_cts = [] - for conf in test_configs: - config_type_name = conf.config_type + "_" + conf.head - if config_type_name not in all_cts: - all_cts.append(config_type_name) - test_by_ct.append((config_type_name, [conf])) - else: - ind = all_cts.index(config_type_name) - test_by_ct[ind][1].append(conf) - return test_by_ct - - -def load_from_xyz( - file_path: str, - key_specification: KeySpecification, - head_name: str = "Default", - config_type_weights: Optional[Dict] = None, - extract_atomic_energies: bool = False, - keep_isolated_atoms: bool = False, -) -> Tuple[Dict[int, float], Configurations]: - atoms_list = ase.io.read(file_path, index=":") - energy_key = key_specification.info_keys["energy"] - forces_key = key_specification.arrays_keys["forces"] - stress_key = key_specification.info_keys["stress"] - head_key = key_specification.info_keys["head"] - if energy_key == "energy": - logging.warning( - "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." - ) - key_specification.info_keys["energy"] = "REF_energy" - for atoms in atoms_list: - try: - atoms.info["REF_energy"] = atoms.get_potential_energy() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract energy: {e}") - atoms.info["REF_energy"] = None - if forces_key == "forces": - logging.warning( - "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." - ) - key_specification.arrays_keys["forces"] = "REF_forces" - for atoms in atoms_list: - try: - atoms.arrays["REF_forces"] = atoms.get_forces() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract forces: {e}") - atoms.arrays["REF_forces"] = None - if stress_key == "stress": - logging.warning( - "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." - ) - key_specification.info_keys["stress"] = "REF_stress" - for atoms in atoms_list: - try: - atoms.info["REF_stress"] = atoms.get_stress() - except Exception as e: # pylint: disable=W0703 - atoms.info["REF_stress"] = None - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - atomic_energies_dict = {} - if extract_atomic_energies: - atoms_without_iso_atoms = [] - - for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name - isolated_atom_config = ( - len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" - ) - if isolated_atom_config: - atomic_number = int(atoms.get_atomic_numbers()[0]) - if energy_key in atoms.info.keys(): - atomic_energies_dict[atomic_number] = float(atoms.info[energy_key]) - else: - logging.warning( - f"Configuration '{idx}' is marked as 'IsolatedAtom' " - "but does not contain an energy. Zero energy will be used." 
- ) - atomic_energies_dict[atomic_number] = 0.0 - else: - atoms_without_iso_atoms.append(atoms) - - if len(atomic_energies_dict) > 0: - logging.info("Using isolated atom energies from training file") - if not keep_isolated_atoms: - atoms_list = atoms_without_iso_atoms - - for atoms in atoms_list: - atoms.info[head_key] = head_name - - configs = config_from_atoms_list( - atoms_list, - config_type_weights=config_type_weights, - key_specification=key_specification, - head_name=head_name, - ) - return atomic_energies_dict, configs - - -def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable -) -> Dict[int, float]: - """ - Function to compute the average interaction energy of each chemical element - returns dictionary of E0s - """ - len_train = len(collections_train) - len_zs = len(z_table) - A = np.zeros((len_train, len_zs)) - B = np.zeros(len_train) - for i in range(len_train): - B[i] = collections_train[i].properties["energy"] - for j, z in enumerate(z_table.zs): - A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) - try: - E0s = np.linalg.lstsq(A, B, rcond=None)[0] - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = E0s[i] - except np.linalg.LinAlgError: - logging.error( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = 0.0 - return atomic_energies_dict - - -def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: - with h5py.File(out_name, "w") as f: - for i, data in enumerate(dataset): - save_AtomicData_to_HDF5(data, i, f) - - -def save_AtomicData_to_HDF5(data, i, h5_file) -> None: - grp = h5_file.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - grp["head"] = data.head - - -def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: - grp = h5_file.create_group("config_batch_0") - for j, config in enumerate(configurations): - subgroup_name = f"config_{j}" - subgroup = grp.create_group(subgroup_name) - subgroup["atomic_numbers"] = write_value(config.atomic_numbers) - subgroup["positions"] = write_value(config.positions) - properties_subgrp = subgroup.create_group("properties") - for key, value in config.properties.items(): - properties_subgrp[key] = write_value(value) - subgroup["cell"] = write_value(config.cell) - subgroup["pbc"] = write_value(config.pbc) - subgroup["weight"] = write_value(config.weight) - weights_subgrp = subgroup.create_group("property_weights") - for key, value in config.property_weights.items(): - weights_subgrp[key] = write_value(value) - subgroup["config_type"] = write_value(config.config_type) - - -def write_value(value): - return value if value is not None else "None" +########################################################################################### +# Data parsing utilities +# Authors: Ilyes 
Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import ase.data +import ase.io +import h5py +import numpy as np + +from mace.tools import AtomicNumberTable, DefaultKeys + +Positions = np.ndarray # [..., 3] +Cell = np.ndarray # [3,3] +Pbc = tuple # (3,) + +DEFAULT_CONFIG_TYPE = "Default" +DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} + + +@dataclass +class KeySpecification: + info_keys: Dict[str, str] = field(default_factory=dict) + arrays_keys: Dict[str, str] = field(default_factory=dict) + + def update( + self, + info_keys: Optional[Dict[str, str]] = None, + arrays_keys: Optional[Dict[str, str]] = None, + ): + if info_keys is not None: + self.info_keys.update(info_keys) + if arrays_keys is not None: + self.arrays_keys.update(arrays_keys) + return self + + @classmethod + def from_defaults(cls): + instance = cls() + return update_keyspec_from_kwargs(instance, DefaultKeys.keydict()) + + +def update_keyspec_from_kwargs( + keyspec: KeySpecification, keydict: Dict[str, str] +) -> KeySpecification: + # convert command line style property_key arguments into a keyspec + infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] + arrays = ["forces_key", "charges_key"] + info_keys = {} + arrays_keys = {} + for key in infos: + if key in keydict: + info_keys[key[:-4]] = keydict[key] + for key in arrays: + if key in keydict: + arrays_keys[key[:-4]] = keydict[key] + keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) + return keyspec + + +@dataclass +class Configuration: + atomic_numbers: np.ndarray + positions: Positions # Angstrom + properties: Dict[str, Any] + property_weights: Dict[str, float] + cell: Optional[Cell] = None + pbc: Optional[Pbc] = None + + weight: float = 1.0 # weight of config in loss + config_type: str = DEFAULT_CONFIG_TYPE # config_type of config + head: str = "Default" # head used to compute the config + + +Configurations = List[Configuration] + + +def random_train_valid_split( + items: Sequence, valid_fraction: float, seed: int, work_dir: str +) -> Tuple[List, List]: + assert 0.0 < valid_fraction < 1.0 + + size = len(items) + train_size = size - int(valid_fraction * size) + + indices = list(range(size)) + rng = np.random.default_rng(seed) + rng.shuffle(indices) + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) + + return ( + [items[i] for i in indices[:train_size]], + [items[i] for i in indices[train_size:]], + ) + + +def config_from_atoms_list( + atoms_list: List[ase.Atoms], + key_specification: KeySpecification, + config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", +) -> Configurations: + """Convert list of ase.Atoms into Configurations""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + all_configs = [] + for atoms in atoms_list: + 
all_configs.append( + config_from_atoms( + atoms, + key_specification=key_specification, + config_type_weights=config_type_weights, + head_name=head_name, + ) + ) + return all_configs + + +def config_from_atoms( + atoms: ase.Atoms, + key_specification: KeySpecification = KeySpecification(), + config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", +) -> Configuration: + """Convert ase.Atoms to Configuration""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + atomic_numbers = np.array( + [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] + ) + pbc = tuple(atoms.get_pbc()) + cell = np.array(atoms.get_cell()) + config_type = atoms.info.get("config_type", "Default") + weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( + config_type, 1.0 + ) + + properties = {} + property_weights = {} + for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): + property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) + + for name, atoms_key in key_specification.info_keys.items(): + properties[name] = atoms.info.get(atoms_key, None) + if not atoms_key in atoms.info: + property_weights[name] = 0.0 + + for name, atoms_key in key_specification.arrays_keys.items(): + properties[name] = atoms.arrays.get(atoms_key, None) + if not atoms_key in atoms.arrays: + property_weights[name] = 0.0 + + return Configuration( + atomic_numbers=atomic_numbers, + positions=atoms.get_positions(), + properties=properties, + weight=weight, + property_weights=property_weights, + head=head_name, + config_type=config_type, + pbc=pbc, + cell=cell, + ) + + +def test_config_types( + test_configs: Configurations, +) -> List[Tuple[str, List[Configuration]]]: + """Split test set based on config_type-s""" + test_by_ct = [] + all_cts = [] + for conf in test_configs: + config_type_name = conf.config_type + "_" + conf.head + if config_type_name not in all_cts: + all_cts.append(config_type_name) + test_by_ct.append((config_type_name, [conf])) + else: + ind = all_cts.index(config_type_name) + test_by_ct[ind][1].append(conf) + return test_by_ct + + +def load_from_xyz( + file_path: str, + key_specification: KeySpecification, + head_name: str = "Default", + config_type_weights: Optional[Dict] = None, + extract_atomic_energies: bool = False, + keep_isolated_atoms: bool = False, +) -> Tuple[Dict[int, float], Configurations]: + atoms_list = ase.io.read(file_path, index=":") + energy_key = key_specification.info_keys["energy"] + forces_key = key_specification.arrays_keys["forces"] + stress_key = key_specification.info_keys["stress"] + head_key = key_specification.info_keys["head"] + if energy_key == "energy": + logging.warning( + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." + ) + key_specification.info_keys["energy"] = "REF_energy" + for atoms in atoms_list: + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None + if forces_key == "forces": + logging.warning( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. 
You need to use --forces_key='REF_forces' to specify the chosen key name." + ) + key_specification.arrays_keys["forces"] = "REF_forces" + for atoms in atoms_list: + try: + atoms.arrays["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract forces: {e}") + atoms.arrays["REF_forces"] = None + if stress_key == "stress": + logging.warning( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." + ) + key_specification.info_keys["stress"] = "REF_stress" + for atoms in atoms_list: + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract stress: {e}") + atoms.info["REF_stress"] = None + + atomic_energies_dict = {} + if extract_atomic_energies: + atoms_without_iso_atoms = [] + + for idx, atoms in enumerate(atoms_list): + atoms.info[head_key] = head_name + isolated_atom_config = ( + len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" + ) + if isolated_atom_config: + atomic_number = int(atoms.get_atomic_numbers()[0]) + if energy_key in atoms.info: + atomic_energies_dict[atomic_number] = float(atoms.info[energy_key]) + else: + logging.warning( + f"Configuration '{idx}' is marked as 'IsolatedAtom' " + "but does not contain an energy. Zero energy will be used." + ) + atomic_energies_dict[atomic_number] = 0.0 + else: + atoms_without_iso_atoms.append(atoms) + + if len(atomic_energies_dict) > 0: + logging.info("Using isolated atom energies from training file") + if not keep_isolated_atoms: + atoms_list = atoms_without_iso_atoms + + for atoms in atoms_list: + atoms.info[head_key] = head_name + + configs = config_from_atoms_list( + atoms_list, + config_type_weights=config_type_weights, + key_specification=key_specification, + head_name=head_name, + ) + return atomic_energies_dict, configs + + +def compute_average_E0s( + collections_train: Configurations, z_table: AtomicNumberTable +) -> Dict[int, float]: + """ + Compute the average interaction energy (E0) of each chemical element + by least-squares regression and return them as a dictionary. + """ + len_train = len(collections_train) + len_zs = len(z_table) + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + for i in range(len_train): + B[i] = collections_train[i].properties["energy"] + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = E0s[i] + except np.linalg.LinAlgError: + logging.error( + "Failed to compute E0s using least squares regression; falling back to 0.0 for every element" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = 0.0 + return atomic_energies_dict + + +def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + for i, data in enumerate(dataset): + save_AtomicData_to_HDF5(data, i, f) + + +def save_AtomicData_to_HDF5(data, i, h5_file) -> None: + grp = h5_file.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell +
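# Remaining per-configuration fields: one-hot node attributes, loss weights, + # reference targets (energy/forces/stress/virials/dipole/charges) and the head label. +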
grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + grp["head"] = data.head + + +def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: + grp = h5_file.create_group("config_batch_0") + for j, config in enumerate(configurations): + subgroup_name = f"config_{j}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = write_value(config.atomic_numbers) + subgroup["positions"] = write_value(config.positions) + properties_subgrp = subgroup.create_group("properties") + for key, value in config.properties.items(): + properties_subgrp[key] = write_value(value) + subgroup["cell"] = write_value(config.cell) + subgroup["pbc"] = write_value(config.pbc) + subgroup["weight"] = write_value(config.weight) + weights_subgrp = subgroup.create_group("property_weights") + for key, value in config.property_weights.items(): + weights_subgrp[key] = write_value(value) + subgroup["config_type"] = write_value(config.config_type) + + +def write_value(value): + return value if value is not None else "None" diff --git a/mace-bench/3rdparty/mace/mace/modules/__init__.py b/mace-bench/3rdparty/mace/mace/modules/__init__.py index d816220c9961a7ef47fb9d416ef08976ace6c3a9..40a29d334a0a112ec6c4209711bfdac6f550bc6d 100644 --- a/mace-bench/3rdparty/mace/mace/modules/__init__.py +++ b/mace-bench/3rdparty/mace/mace/modules/__init__.py @@ -1,100 +1,100 @@ -from typing import Callable, Dict, Optional, Type - -import torch - -from .blocks import ( - AtomicEnergiesBlock, - EquivariantProductBasisBlock, - InteractionBlock, - LinearDipoleReadoutBlock, - LinearNodeEmbeddingBlock, - LinearReadoutBlock, - NonLinearDipoleReadoutBlock, - NonLinearReadoutBlock, - RadialEmbeddingBlock, - RealAgnosticAttResidualInteractionBlock, - RealAgnosticDensityInteractionBlock, - RealAgnosticDensityResidualInteractionBlock, - RealAgnosticInteractionBlock, - RealAgnosticResidualInteractionBlock, - ScaleShiftBlock, -) -from .loss import ( - DipoleSingleLoss, - UniversalLoss, - WeightedEnergyForcesDipoleLoss, - WeightedEnergyForcesL1L2Loss, - WeightedEnergyForcesLoss, - WeightedEnergyForcesStressLoss, - WeightedEnergyForcesVirialsLoss, - WeightedForcesLoss, - WeightedHuberEnergyForcesStressLoss, -) -from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE -from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis -from .symmetric_contraction import SymmetricContraction -from .utils import ( - compute_avg_num_neighbors, - compute_fixed_charge_dipole, - compute_mean_rms_energy_forces, - compute_mean_std_atomic_inter_energy, - compute_rms_dipoles, - compute_statistics, -) - -interaction_classes: Dict[str, Type[InteractionBlock]] = { - "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, - "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, - "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, - "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, - "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, -} - -scaling_classes: Dict[str, Callable] = { - "std_scaling": 
compute_mean_std_atomic_inter_energy, - "rms_forces_scaling": compute_mean_rms_energy_forces, - "rms_dipoles_scaling": compute_rms_dipoles, -} - -gate_dict: Dict[str, Optional[Callable]] = { - "abs": torch.abs, - "tanh": torch.tanh, - "silu": torch.nn.functional.silu, - "None": None, -} - -__all__ = [ - "AtomicEnergiesBlock", - "RadialEmbeddingBlock", - "ZBLBasis", - "LinearNodeEmbeddingBlock", - "LinearReadoutBlock", - "EquivariantProductBasisBlock", - "ScaleShiftBlock", - "LinearDipoleReadoutBlock", - "NonLinearDipoleReadoutBlock", - "InteractionBlock", - "NonLinearReadoutBlock", - "PolynomialCutoff", - "BesselBasis", - "GaussianBasis", - "MACE", - "ScaleShiftMACE", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - "WeightedEnergyForcesLoss", - "WeightedForcesLoss", - "WeightedEnergyForcesVirialsLoss", - "WeightedEnergyForcesStressLoss", - "DipoleSingleLoss", - "WeightedEnergyForcesDipoleLoss", - "WeightedHuberEnergyForcesStressLoss", - "UniversalLoss", - "WeightedEnergyForcesL1L2Loss", - "SymmetricContraction", - "interaction_classes", - "compute_mean_std_atomic_inter_energy", - "compute_avg_num_neighbors", - "compute_statistics", - "compute_fixed_charge_dipole", -] +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticInteractionBlock, + RealAgnosticResidualInteractionBlock, + ScaleShiftBlock, +) +from .loss import ( + DipoleSingleLoss, + UniversalLoss, + WeightedEnergyForcesDipoleLoss, + WeightedEnergyForcesL1L2Loss, + WeightedEnergyForcesLoss, + WeightedEnergyForcesStressLoss, + WeightedEnergyForcesVirialsLoss, + WeightedForcesLoss, + WeightedHuberEnergyForcesStressLoss, +) +from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis +from .symmetric_contraction import SymmetricContraction +from .utils import ( + compute_avg_num_neighbors, + compute_fixed_charge_dipole, + compute_mean_rms_energy_forces, + compute_mean_std_atomic_inter_energy, + compute_rms_dipoles, + compute_statistics, +) + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, + "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, +} + +scaling_classes: Dict[str, Callable] = { + "std_scaling": compute_mean_std_atomic_inter_energy, + "rms_forces_scaling": compute_mean_rms_energy_forces, + "rms_dipoles_scaling": compute_rms_dipoles, +} + +gate_dict: Dict[str, Optional[Callable]] = { + "abs": torch.abs, + "tanh": torch.tanh, + "silu": torch.nn.functional.silu, + "None": None, +} + +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "ZBLBasis", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "LinearDipoleReadoutBlock", + "NonLinearDipoleReadoutBlock", + "InteractionBlock", + "NonLinearReadoutBlock", + 
"PolynomialCutoff", + "BesselBasis", + "GaussianBasis", + "MACE", + "ScaleShiftMACE", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + "WeightedEnergyForcesLoss", + "WeightedForcesLoss", + "WeightedEnergyForcesVirialsLoss", + "WeightedEnergyForcesStressLoss", + "DipoleSingleLoss", + "WeightedEnergyForcesDipoleLoss", + "WeightedHuberEnergyForcesStressLoss", + "UniversalLoss", + "WeightedEnergyForcesL1L2Loss", + "SymmetricContraction", + "interaction_classes", + "compute_mean_std_atomic_inter_energy", + "compute_avg_num_neighbors", + "compute_statistics", + "compute_fixed_charge_dipole", +] diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index fb761d632b28e4aa34bd81e2821cb4a84baca5d4..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 275df7cd255ab020edf677878b781720b033afac..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc deleted file mode 100644 index 7da39a02fd504793d0c0ed9557483b8cf3022a7b..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc deleted file mode 100644 index 953da4c8f0fe8e515db82c7cb2040303eb9ca22a..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc deleted file mode 100644 index f7d4c4f4b908409ea738ec04bb4e3fa22abae401..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc deleted file mode 100644 index d9952a6f58a04d837ee593b5827f6072247dcabd..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc deleted file mode 100644 index 1bd846dfc236402135178704736eabf7f7c1c20d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-313.pyc deleted file mode 100644 index 
fb236e8b186f41a01cc2e12397ee3a022b9e6538..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc deleted file mode 100644 index d411204f33eaa55f26939aa388421530caee96e3..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-313.pyc deleted file mode 100644 index 27a8347805ba50ac1bfdcb25e948963e48b9bc45..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc deleted file mode 100644 index efa36c8671fee63eb7ed4b48dad94a62dd663259..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-313.pyc deleted file mode 100644 index 005cbe922ae2bbcf045a19e103c18f963be45260..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-310.pyc deleted file mode 100644 index d3ecd794a2fb7b99386cd843f8f8829c7063840c..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc deleted file mode 100644 index 16ef822000ec5469fc71ac8ec2675ecc160a82cb..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 1c234334b2d5a3f1b0f67cfef85e44ae66aa0256..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index ed25f42b8ce17f63a58c526fa4c0e018d604a3e2..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-310.pyc deleted file mode 100644 index 
e1a6ca179dee658675ccea09259a395a7382b397..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc deleted file mode 100644 index 246f6bf3a7d1e6a8d3bb5a57898959c04d5d5947..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/modules/blocks.py b/mace-bench/3rdparty/mace/mace/modules/blocks.py index 4c64e6a08e1ea50fabec9cb83f129ef094138080..bf1e2dad84939ce2efb9606697a1c7bc8c8564e5 100644 --- a/mace-bench/3rdparty/mace/mace/modules/blocks.py +++ b/mace-bench/3rdparty/mace/mace/modules/blocks.py @@ -1,922 +1,922 @@ -########################################################################################### -# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from abc import abstractmethod -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np -import torch.nn.functional -from e3nn import nn, o3 -from e3nn.util.jit import compile_mode - -from mace.modules.wrapper_ops import ( - CuEquivarianceConfig, - FullyConnectedTensorProduct, - Linear, - SymmetricContractionWrapper, - TensorProduct, -) -from mace.tools.compile import simplify_if_compile -from mace.tools.scatter import scatter_sum -from mace.tools.utils import LAMMPS_MP - -from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions -from .radial import ( - AgnesiTransform, - BesselBasis, - ChebychevBasis, - GaussianBasis, - PolynomialCutoff, - SoftTransform, -) - - -@compile_mode("script") -class LinearNodeEmbeddingBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.linear = Linear( - irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config - ) - - def forward( - self, - node_attrs: torch.Tensor, - ) -> torch.Tensor: # [n_nodes, irreps] - return self.linear(node_attrs) - - -@compile_mode("script") -class LinearReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irrep_out: o3.Irreps = o3.Irreps("0e"), - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.linear = Linear( - irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config - ) - - def forward( - self, - x: torch.Tensor, - heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - return self.linear(x) # [n_nodes, 1] - - -@simplify_if_compile -@compile_mode("script") -class NonLinearReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - MLP_irreps: o3.Irreps, - gate: Optional[Callable], - irrep_out: o3.Irreps = o3.Irreps("0e"), - num_heads: int = 1, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.hidden_irreps = MLP_irreps - self.num_heads = num_heads - self.linear_1 = Linear( - irreps_in=irreps_in, irreps_out=self.hidden_irreps, 
cueq_config=cueq_config - ) - self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = Linear( - irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config - ) - - def forward( - self, x: torch.Tensor, heads: Optional[torch.Tensor] = None - ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.non_linearity(self.linear_1(x)) - if hasattr(self, "num_heads"): - if self.num_heads > 1 and heads is not None: - x = mask_head(x, heads, self.num_heads) - return self.linear_2(x) # [n_nodes, len(heads)] - - -@compile_mode("script") -class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - dipole_only: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = Linear( - irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - return self.linear(x) # [n_nodes, 1] - - -@compile_mode("script") -class NonLinearDipoleReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - MLP_irreps: o3.Irreps, - gate: Callable, - dipole_only: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.hidden_irreps = MLP_irreps - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - irreps_scalars = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] - ) - irreps_gated = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] - ) - irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) - self.equivariant_nonlin = nn.Gate( - irreps_scalars=irreps_scalars, - act_scalars=[gate for _, ir in irreps_scalars], - irreps_gates=irreps_gates, - act_gates=[gate] * len(irreps_gates), - irreps_gated=irreps_gated, - ) - self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = Linear( - irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config - ) - self.linear_2 = Linear( - irreps_in=self.hidden_irreps, - irreps_out=self.irreps_out, - cueq_config=cueq_config, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.equivariant_nonlin(self.linear_1(x)) - return self.linear_2(x) # [n_nodes, 1] - - -@compile_mode("script") -class AtomicEnergiesBlock(torch.nn.Module): - atomic_energies: torch.Tensor - - def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): - super().__init__() - # assert len(atomic_energies.shape) == 1 - - self.register_buffer( - "atomic_energies", - torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), - ) # [n_elements, n_heads] - - def forward( - self, x: torch.Tensor # one-hot of elements [..., n_elements] - ) -> torch.Tensor: # [..., ] - return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) - - def __repr__(self): - formatted_energies = ", ".join( - [ - "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" - for group in torch.atleast_2d(self.atomic_energies) - ] - ) - return f"{self.__class__.__name__}(energies=[{formatted_energies}])" - - -@compile_mode("script") -class RadialEmbeddingBlock(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - radial_type: str = "bessel", - 
distance_transform: str = "None", - ): - super().__init__() - if radial_type == "bessel": - self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) - elif radial_type == "gaussian": - self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) - elif radial_type == "chebyshev": - self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) - if distance_transform == "Agnesi": - self.distance_transform = AgnesiTransform() - elif distance_transform == "Soft": - self.distance_transform = SoftTransform() - self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) - self.out_dim = num_bessel - - def forward( - self, - edge_lengths: torch.Tensor, # [n_edges, 1] - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ): - cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] - if hasattr(self, "distance_transform"): - edge_lengths = self.distance_transform( - edge_lengths, node_attrs, edge_index, atomic_numbers - ) - radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] - return radial * cutoff # [n_edges, n_basis] - - -@compile_mode("script") -class EquivariantProductBasisBlock(torch.nn.Module): - def __init__( - self, - node_feats_irreps: o3.Irreps, - target_irreps: o3.Irreps, - correlation: int, - use_sc: bool = True, - num_elements: Optional[int] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ) -> None: - super().__init__() - - self.use_sc = use_sc - self.symmetric_contractions = SymmetricContractionWrapper( - irreps_in=node_feats_irreps, - irreps_out=target_irreps, - correlation=correlation, - num_elements=num_elements, - cueq_config=cueq_config, - ) - # Update linear - self.linear = Linear( - target_irreps, - target_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=cueq_config, - ) - self.cueq_config = cueq_config - - def forward( - self, - node_feats: torch.Tensor, - sc: Optional[torch.Tensor], - node_attrs: torch.Tensor, - ) -> torch.Tensor: - use_cueq = False - use_cueq_mul_ir = False - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: - if self.cueq_config.enabled and ( - self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric - ): - use_cueq = True - if self.cueq_config.layout_str == "mul_ir": - use_cueq_mul_ir = True - if use_cueq: - if use_cueq_mul_ir: - node_feats = torch.transpose(node_feats, 1, 2) - index_attrs = torch.nonzero(node_attrs)[:, 1].int() - node_feats = self.symmetric_contractions( - node_feats.flatten(1), - index_attrs, - ) - else: - node_feats = self.symmetric_contractions(node_feats, node_attrs) - if self.use_sc and sc is not None: - return self.linear(node_feats) + sc - return self.linear(node_feats) - - -@compile_mode("script") -class InteractionBlock(torch.nn.Module): - def __init__( - self, - node_attrs_irreps: o3.Irreps, - node_feats_irreps: o3.Irreps, - edge_attrs_irreps: o3.Irreps, - edge_feats_irreps: o3.Irreps, - target_irreps: o3.Irreps, - hidden_irreps: o3.Irreps, - avg_num_neighbors: float, - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ) -> None: - super().__init__() - self.node_attrs_irreps = node_attrs_irreps - self.node_feats_irreps = node_feats_irreps - self.edge_attrs_irreps = edge_attrs_irreps - self.edge_feats_irreps = edge_feats_irreps - self.target_irreps = target_irreps - self.hidden_irreps = hidden_irreps - self.avg_num_neighbors = avg_num_neighbors - if radial_MLP is None: - radial_MLP = [64, 64, 64] - self.radial_MLP = radial_MLP - self.cueq_config = 
cueq_config - self._setup() - - @abstractmethod - def _setup(self) -> None: - raise NotImplementedError - - def handle_lammps( - self, - node_feats: torch.Tensor, - lammps_class: Optional[Any], - lammps_natoms: Tuple[int, int], - first_layer: bool, - ) -> torch.Tensor: # noqa: D401 – internal helper - if lammps_class is None or first_layer or torch.jit.is_scripting(): - return node_feats - _, n_total = lammps_natoms - pad = torch.zeros( - (n_total, node_feats.shape[1]), - dtype=node_feats.dtype, - device=node_feats.device, - ) - node_feats = torch.cat((node_feats, pad), dim=0) - node_feats = LAMMPS_MP.apply(node_feats, lammps_class) - return node_feats - - def truncate_ghosts( - self, tensor: torch.Tensor, n_real: Optional[int] = None - ) -> torch.Tensor: - """Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations.""" - return tensor[:n_real] if n_real is not None else tensor - - @abstractmethod - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError - - -nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} - - -@compile_mode("script") -class RealAgnosticInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, - self.node_attrs_irreps, - self.irreps_out, - cueq_config=self.cueq_config, - ) - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_natoms: Tuple[int, int] = (0, 0), - lammps_class: Optional[Any] = None, - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = 
self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - message = self.linear(message) / self.avg_num_neighbors - message = self.skip_tp(message, node_attrs) - return ( - self.reshape(message), - None, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, # gate - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, - self.node_attrs_irreps, - self.hidden_irreps, - cueq_config=self.cueq_config, - ) - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - sc = self.truncate_ghosts(sc, n_real) - message = self.linear(message) / self.avg_num_neighbors - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticDensityInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - 
instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, - self.node_attrs_irreps, - self.irreps_out, - cueq_config=self.cueq_config, - ) - - # Density normalization - self.density_fn = nn.FullyConnectedNet( - [input_dim] - + [ - 1, - ], - torch.nn.functional.silu, - ) - # Reshape - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - density = scatter_sum( - src=edge_density, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, 1] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - density = self.truncate_ghosts(density, n_real) - message = self.linear(message) / (density + 1) - message = self.skip_tp(message, node_attrs) - return ( - self.reshape(message), - None, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, # gate - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - 
self.node_feats_irreps, - self.node_attrs_irreps, - self.hidden_irreps, - cueq_config=self.cueq_config, - ) - - # Density normalization - self.density_fn = nn.FullyConnectedNet( - [input_dim] - + [ - 1, - ], - torch.nn.functional.silu, - ) - - # Reshape - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - density = scatter_sum( - src=edge_density, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, 1] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - density = self.truncate_ghosts(density, n_real) - sc = self.truncate_ghosts(sc, n_real) - message = self.linear(message) / (density + 1) - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticAttResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - self.node_feats_down_irreps = o3.Irreps("64x0e") - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - self.linear_down = Linear( - self.node_feats_irreps, - self.node_feats_down_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - input_dim = ( - self.edge_feats_irreps.num_irreps - + 2 * self.node_feats_down_irreps.num_irreps - ) - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - # Skip connection. 
- self.skip_linear = Linear( - self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config - ) - - # pylint: disable=unused-argument - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - sc = self.skip_linear(node_feats) - node_feats_up = self.linear_up(node_feats) - node_feats_down = self.linear_down(node_feats) - augmented_edge_feats = torch.cat( - [ - edge_feats, - node_feats_down[sender], - node_feats_down[receiver], - ], - dim=-1, - ) - tp_weights = self.conv_tp_weights(augmented_edge_feats) - mji = self.conv_tp( - node_feats_up[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class ScaleShiftBlock(torch.nn.Module): - def __init__(self, scale: float, shift: float): - super().__init__() - self.register_buffer( - "scale", - torch.tensor(scale, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "shift", - torch.tensor(shift, dtype=torch.get_default_dtype()), - ) - - def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: - return ( - torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] - ) - - def __repr__(self): - formatted_scale = ( - ", ".join([f"{x:.4f}" for x in self.scale]) - if self.scale.numel() > 1 - else f"{self.scale.item():.4f}" - ) - formatted_shift = ( - ", ".join([f"{x:.4f}" for x in self.shift]) - if self.shift.numel() > 1 - else f"{self.shift.item():.4f}" - ) - return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" +########################################################################################### +# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from abc import abstractmethod +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.nn.functional +from e3nn import nn, o3 +from e3nn.util.jit import compile_mode + +from mace.modules.wrapper_ops import ( + CuEquivarianceConfig, + FullyConnectedTensorProduct, + Linear, + SymmetricContractionWrapper, + TensorProduct, +) +from mace.tools.compile import simplify_if_compile +from mace.tools.scatter import scatter_sum +from mace.tools.utils import LAMMPS_MP + +from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions +from .radial import ( + AgnesiTransform, + BesselBasis, + ChebychevBasis, + GaussianBasis, + PolynomialCutoff, + SoftTransform, +) + + +@compile_mode("script") +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config + ) + + def forward( + self, + node_attrs: 
torch.Tensor, + ) -> torch.Tensor: # [n_nodes, irreps] + return self.linear(node_attrs) + + +@compile_mode("script") +class LinearReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps = o3.Irreps("0e"), + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config + ) + + def forward( + self, + x: torch.Tensor, + heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@simplify_if_compile +@compile_mode("script") +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps = o3.Irreps("0e"), + num_heads: int = 1, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.num_heads = num_heads + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config + ) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config + ) + + def forward( + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + if hasattr(self, "num_heads"): + if self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) + return self.linear_2(x) # [n_nodes, len(heads)] + + +@compile_mode("script") +class LinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class NonLinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + irreps_scalars = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] + ) + irreps_gated = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] + ) + irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) + self.equivariant_nonlin = nn.Gate( + irreps_scalars=irreps_scalars, + act_scalars=[gate for _, ir in irreps_scalars], + irreps_gates=irreps_gates, + act_gates=[gate] * len(irreps_gates), + irreps_gated=irreps_gated, + ) + self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config + ) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, + irreps_out=self.irreps_out, + cueq_config=cueq_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # 
[n_nodes, irreps] # [..., ] + x = self.equivariant_nonlin(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + # assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, n_heads] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) + + def __repr__(self): + formatted_energies = ", ".join( + [ + "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" + for group in torch.atleast_2d(self.atomic_energies) + ] + ) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +@compile_mode("script") +class RadialEmbeddingBlock(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + radial_type: str = "bessel", + distance_transform: str = "None", + ): + super().__init__() + if radial_type == "bessel": + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "gaussian": + self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "chebyshev": + self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) + if distance_transform == "Agnesi": + self.distance_transform = AgnesiTransform() + elif distance_transform == "Soft": + self.distance_transform = SoftTransform() + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, + edge_lengths: torch.Tensor, # [n_edges, 1] + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ): + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + if hasattr(self, "distance_transform"): + edge_lengths = self.distance_transform( + edge_lengths, node_attrs, edge_index, atomic_numbers + ) + radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + return radial * cutoff # [n_edges, n_basis] + + +@compile_mode("script") +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + correlation: int, + use_sc: bool = True, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContractionWrapper( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + num_elements=num_elements, + cueq_config=cueq_config, + ) + # Update linear + self.linear = Linear( + target_irreps, + target_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=cueq_config, + ) + self.cueq_config = cueq_config + + def forward( + self, + node_feats: torch.Tensor, + sc: Optional[torch.Tensor], + node_attrs: torch.Tensor, + ) -> torch.Tensor: + use_cueq = False + use_cueq_mul_ir = False + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.enabled and ( + self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric + ): + use_cueq = True + if self.cueq_config.layout_str == "mul_ir": + use_cueq_mul_ir = True + if use_cueq: + if use_cueq_mul_ir: + node_feats = torch.transpose(node_feats, 1, 2) + index_attrs = 
torch.nonzero(node_attrs)[:, 1].int() + node_feats = self.symmetric_contractions( + node_feats.flatten(1), + index_attrs, + ) + else: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + if self.use_sc and sc is not None: + return self.linear(node_feats) + sc + return self.linear(node_feats) + + +@compile_mode("script") +class InteractionBlock(torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + if radial_MLP is None: + radial_MLP = [64, 64, 64] + self.radial_MLP = radial_MLP + self.cueq_config = cueq_config + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + def handle_lammps( + self, + node_feats: torch.Tensor, + lammps_class: Optional[Any], + lammps_natoms: Tuple[int, int], + first_layer: bool, + ) -> torch.Tensor: # noqa: D401 – internal helper + if lammps_class is None or first_layer or torch.jit.is_scripting(): + return node_feats + _, n_total = lammps_natoms + pad = torch.zeros( + (n_total, node_feats.shape[1]), + dtype=node_feats.dtype, + device=node_feats.device, + ) + node_feats = torch.cat((node_feats, pad), dim=0) + node_feats = LAMMPS_MP.apply(node_feats, lammps_class) + return node_feats + + def truncate_ghosts( + self, tensor: torch.Tensor, n_real: Optional[int] = None + ) -> torch.Tensor: + """Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations.""" + return tensor[:n_real] if n_real is not None else tensor + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} + + +@compile_mode("script") +class RealAgnosticInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector 
TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, + ) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_natoms: Tuple[int, int] = (0, 0), + lammps_class: Optional[Any] = None, + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, + ) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = 
self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + sc = self.truncate_ghosts(sc, n_real) + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, + ) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + density = self.truncate_ghosts(density, n_real) + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def 
_setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, + ) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + # Reshape + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + density = self.truncate_ghosts(density, n_real) + sc = self.truncate_ghosts(sc, n_real) + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticAttResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + self.node_feats_down_irreps = o3.Irreps("64x0e") + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + 
self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + self.linear_down = Linear( + self.node_feats_irreps, + self.node_feats_down_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + input_dim = ( + self.edge_feats_irreps.num_irreps + + 2 * self.node_feats_down_irreps.num_irreps + ) + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + # Skip connection. + self.skip_linear = Linear( + self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config + ) + + # pylint: disable=unused-argument + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_linear(node_feats) + node_feats_up = self.linear_up(node_feats) + node_feats_down = self.linear_down(node_feats) + augmented_edge_feats = torch.cat( + [ + edge_feats, + node_feats_down[sender], + node_feats_down[receiver], + ], + dim=-1, + ) + tp_weights = self.conv_tp_weights(augmented_edge_feats) + mji = self.conv_tp( + node_feats_up[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", + torch.tensor(scale, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "shift", + torch.tensor(shift, dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: + return ( + torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] + ) + + def __repr__(self): + formatted_scale = ( + ", ".join([f"{x:.4f}" for x in self.scale]) + if self.scale.numel() > 1 + else f"{self.scale.item():.4f}" + ) + formatted_shift = ( + ", ".join([f"{x:.4f}" for x in self.shift]) + if self.shift.numel() > 1 + else f"{self.shift.item():.4f}" + ) + return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" diff --git a/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py b/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py index c3388018ceed362e2381f7d596b883ad184bc908..6677b1bef2be949300a26e5e852ef80fac8c0724 100644 --- a/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py +++ b/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py @@ -1,116 +1,116 @@ -########################################################################################### -# Elementary tools for handling irreducible representations -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT 
License (see MIT.md) -########################################################################################### - -from typing import List, Optional, Tuple - -import torch -from e3nn import o3 -from e3nn.util.jit import compile_mode - -from mace.modules.wrapper_ops import CuEquivarianceConfig - - -# Based on mir-group/nequip -def tp_out_irreps_with_instructions( - irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps -) -> Tuple[o3.Irreps, List]: - trainable = True - - # Collect possible irreps and their instructions - irreps_out_list: List[Tuple[int, o3.Irreps]] = [] - instructions = [] - for i, (mul, ir_in) in enumerate(irreps1): - for j, (_, ir_edge) in enumerate(irreps2): - for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 - if ir_out in target_irreps: - k = len(irreps_out_list) # instruction index - irreps_out_list.append((mul, ir_out)) - instructions.append((i, j, k, "uvu", trainable)) - - # We sort the output irreps of the tensor product so that we can simplify them - # when they are provided to the second o3.Linear - irreps_out = o3.Irreps(irreps_out_list) - irreps_out, permut, _ = irreps_out.sort() - - # Permute the output indexes of the instructions to match the sorted irreps: - instructions = [ - (i_in1, i_in2, permut[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - instructions = sorted(instructions, key=lambda x: x[2]) - - return irreps_out, instructions - - -def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: - # Assuming simplified irreps - irreps_mid = [] - for _, ir_in in irreps: - found = False - - for mul, ir_out in target_irreps: - if ir_in == ir_out: - irreps_mid.append((mul, ir_out)) - found = True - break - - if not found: - raise RuntimeError(f"{ir_in} not in {target_irreps}") - - return o3.Irreps(irreps_mid) - - -@compile_mode("script") -class reshape_irreps(torch.nn.Module): - def __init__( - self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None - ) -> None: - super().__init__() - self.irreps = o3.Irreps(irreps) - self.cueq_config = cueq_config - self.dims = [] - self.muls = [] - for mul, ir in self.irreps: - d = ir.dim - self.dims.append(d) - self.muls.append(mul) - - def forward(self, tensor: torch.Tensor) -> torch.Tensor: - ix = 0 - out = [] - batch, _ = tensor.shape - for mul, d in zip(self.muls, self.dims): - field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] - ix += mul * d - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: - if self.cueq_config.layout_str == "mul_ir": - field = field.reshape(batch, mul, d) - else: - field = field.reshape(batch, d, mul) - else: - field = field.reshape(batch, mul, d) - else: - field = field.reshape(batch, mul, d) - out.append(field) - - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: # pylint: disable=no-else-return - if self.cueq_config.layout_str == "mul_ir": - return torch.cat(out, dim=-1) - return torch.cat(out, dim=-2) - else: - return torch.cat(out, dim=-1) - return torch.cat(out, dim=-1) - - -def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: - mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) - idx = torch.arange(mask.shape[0], device=x.device) - mask[idx, :, head] = 1 - mask = mask.permute(0, 2, 1).reshape(x.shape) - return x * mask +########################################################################################### +# Elementary tools for handling irreducible representations +# 
Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Optional, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules.wrapper_ops import CuEquivarianceConfig + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__( + self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None + ) -> None: + super().__init__() + self.irreps = o3.Irreps(irreps) + self.cueq_config = cueq_config + self.dims = [] + self.muls = [] + for mul, ir in self.irreps: + d = ir.dim + self.dims.append(d) + self.muls.append(mul) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] + batch, _ = tensor.shape + for mul, d in zip(self.muls, self.dims): + field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) + else: + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, mul, d) + out.append(field) + + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: # pylint: disable=no-else-return + if self.cueq_config.layout_str == "mul_ir": + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-2) + else: + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-1) + + +def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: + mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) + idx = torch.arange(mask.shape[0], device=x.device) + mask[idx, :, head] = 1 + mask = mask.permute(0, 2, 1).reshape(x.shape) + return x * mask diff --git a/mace-bench/3rdparty/mace/mace/modules/loss.py 
b/mace-bench/3rdparty/mace/mace/modules/loss.py index 19ad76ae0ba13b8f9110986b73d3a152bd7ba59c..ff567e39444ecac83d10661054c02fd323bd5d2f 100644 --- a/mace-bench/3rdparty/mace/mace/modules/loss.py +++ b/mace-bench/3rdparty/mace/mace/modules/loss.py @@ -1,566 +1,566 @@ -########################################################################################### -# Implementation of different loss functions -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Optional - -import torch -import torch.distributed as dist - -from mace.tools import TensorDict -from mace.tools.torch_geometric import Batch - - -# ------------------------------------------------------------------------------ -# Helper function for loss reduction that handles DDP correction -# ------------------------------------------------------------------------------ -def is_ddp_enabled(): - return dist.is_initialized() and dist.get_world_size() > 1 - - -def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor: - """ - Reduces an element-wise loss tensor. - - If ddp is True and distributed is initialized, the function computes: - - loss = (local_sum * world_size) / global_num_elements - - Otherwise, it returns the regular mean. - """ - ddp = is_ddp_enabled() if ddp is None else ddp - if ddp and dist.is_initialized(): - world_size = dist.get_world_size() - n_local = raw_loss.numel() - loss_sum = raw_loss.sum() - total_samples = torch.tensor( - n_local, device=raw_loss.device, dtype=raw_loss.dtype - ) - dist.all_reduce(total_samples, op=dist.ReduceOp.SUM) - return loss_sum * world_size / total_samples - return raw_loss.mean() - - -# ------------------------------------------------------------------------------ -# Energy Loss Functions -# ------------------------------------------------------------------------------ - - -def mean_squared_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - raw_loss = torch.square(ref["energy"] - pred["energy"]) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_squared_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - # Calculate per-graph number of atoms. 
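# `ref.ptr` is the CSR-style pointer produced by torch_geometric batching:
# ptr = [0, n_0, n_0 + n_1, ...], so consecutive differences recover the
# per-graph atom counts, e.g. ptr = [0, 3, 7] -> num_atoms = [3, 4].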
- num_atoms = ref.ptr[1:] - ref.ptr[:-1] # shape: [n_graphs] - raw_loss = ( - ref.weight - * ref.energy_weight - * torch.square((ref["energy"] - pred["energy"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_absolute_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - raw_loss = ( - ref.weight - * ref.energy_weight - * torch.abs((ref["energy"] - pred["energy"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Stress and Virials Loss Functions -# ------------------------------------------------------------------------------ - - -def weighted_mean_squared_stress( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = ref.weight.view(-1, 1, 1) - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) - raw_loss = ( - configs_weight - * configs_stress_weight - * torch.square(ref["stress"] - pred["stress"]) - ) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_squared_virials( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = ref.weight.view(-1, 1, 1) - configs_virials_weight = ref.virials_weight.view(-1, 1, 1) - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) - raw_loss = ( - configs_weight - * configs_virials_weight - * torch.square((ref["virials"] - pred["virials"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Forces Loss Functions -# ------------------------------------------------------------------------------ - - -def mean_squared_error_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - # Repeat per-graph weights to per-atom level. 
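# `repeat_interleave` expands the per-graph weights to one entry per atom,
# e.g. weight = [a, b] with atom counts [3, 2] -> [a, a, a, b, b], so each
# atom's squared force error is scaled by its own configuration's weight.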
- configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - raw_loss = ( - configs_weight - * configs_forces_weight - * torch.square(ref["forces"] - pred["forces"]) - ) - return reduce_loss(raw_loss, ddp) - - -def mean_normed_error_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - raw_loss = torch.linalg.vector_norm(ref["forces"] - pred["forces"], ord=2, dim=-1) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Dipole Loss Function -# ------------------------------------------------------------------------------ - - -def weighted_mean_squared_error_dipole( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) - raw_loss = torch.square((ref["dipole"] - pred["dipole"]) / num_atoms) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Conditional Losses for Forces -# ------------------------------------------------------------------------------ - - -def conditional_mse_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - # Define multiplication factors for different regimes. - factors = torch.tensor( - [1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype - ) - err = ref["forces"] - pred["forces"] - se = torch.zeros_like(err) - norm_forces = torch.norm(ref["forces"], dim=-1) - c1 = norm_forces < 100 - c2 = (norm_forces >= 100) & (norm_forces < 200) - c3 = (norm_forces >= 200) & (norm_forces < 300) - se[c1] = torch.square(err[c1]) * factors[0] - se[c2] = torch.square(err[c2]) * factors[1] - se[c3] = torch.square(err[c3]) * factors[2] - se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] - raw_loss = configs_weight * configs_forces_weight * se - return reduce_loss(raw_loss, ddp) - - -def conditional_huber_forces( - ref_forces: torch.Tensor, - pred_forces: torch.Tensor, - huber_delta: float, - ddp: Optional[bool] = None, -) -> torch.Tensor: - factors = huber_delta * torch.tensor( - [1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype - ) - norm_forces = torch.norm(ref_forces, dim=-1) - c1 = norm_forces < 100 - c2 = (norm_forces >= 100) & (norm_forces < 200) - c3 = (norm_forces >= 200) & (norm_forces < 300) - c4 = ~(c1 | c2 | c3) - se = torch.zeros_like(pred_forces) - se[c1] = torch.nn.functional.huber_loss( - ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] - ) - se[c2] = torch.nn.functional.huber_loss( - ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] - ) - se[c3] = torch.nn.functional.huber_loss( - ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] - ) - se[c4] = torch.nn.functional.huber_loss( - ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] - ) - return reduce_loss(se, ddp) - - -# ------------------------------------------------------------------------------ -# Loss Modules Combining Multiple Quantities -# ------------------------------------------------------------------------------ - 
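# The modules below all funnel their element-wise losses through
# `reduce_loss` above. A minimal single-process sketch of why the DDP
# branch multiplies by the world size: DDP averages gradients over ranks
# (i.e. divides by W), so each rank returns local_sum * W / global_N, and
# the average over ranks is global_sum / global_N, the true global mean.
# The two-rank simulation here is purely illustrative and not part of mace;
# the real code obtains global_N via an all_reduce.
import torch

losses = [torch.tensor([1.0, 3.0]), torch.tensor([2.0])]  # per-rank losses
world_size = len(losses)
global_n = sum(l.numel() for l in losses)
per_rank = [l.sum() * world_size / global_n for l in losses]
ddp_mean = sum(per_rank) / world_size  # what DDP's gradient averaging yields
assert torch.isclose(ddp_mean, torch.cat(losses).mean())  # == global mean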
- -class WeightedEnergyForcesLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - return self.energy_weight * loss_energy + self.forces_weight * loss_forces - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) - - -class WeightedForcesLoss(torch.nn.Module): - def __init__(self, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_forces = mean_squared_error_forces(ref, pred, ddp) - return self.forces_weight * loss_forces - - def __repr__(self): - return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})" - - -class WeightedEnergyForcesStressLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_stress = weighted_mean_squared_stress(ref, pred, ddp) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - # We store the huber_delta rather than a loss with fixed reduction. 
- self.huber_delta = huber_delta - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - if ddp: - loss_energy = torch.nn.functional.huber_loss( - ref["energy"] / num_atoms, - pred["energy"] / num_atoms, - reduction="none", - delta=self.huber_delta, - ) - loss_energy = reduce_loss(loss_energy, ddp) - loss_forces = torch.nn.functional.huber_loss( - ref["forces"], pred["forces"], reduction="none", delta=self.huber_delta - ) - loss_forces = reduce_loss(loss_forces, ddp) - loss_stress = torch.nn.functional.huber_loss( - ref["stress"], pred["stress"], reduction="none", delta=self.huber_delta - ) - loss_stress = reduce_loss(loss_stress, ddp) - else: - loss_energy = torch.nn.functional.huber_loss( - ref["energy"] / num_atoms, - pred["energy"] / num_atoms, - reduction="mean", - delta=self.huber_delta, - ) - loss_forces = torch.nn.functional.huber_loss( - ref["forces"], pred["forces"], reduction="mean", delta=self.huber_delta - ) - loss_stress = torch.nn.functional.huber_loss( - ref["stress"], pred["stress"], reduction="mean", delta=self.huber_delta - ) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class UniversalLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - self.huber_delta = huber_delta - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) - configs_energy_weight = ref.energy_weight - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - if ddp: - loss_energy = torch.nn.functional.huber_loss( - configs_energy_weight * ref["energy"] / num_atoms, - configs_energy_weight * pred["energy"] / num_atoms, - reduction="none", - delta=self.huber_delta, - ) - loss_energy = reduce_loss(loss_energy, ddp) - loss_forces = conditional_huber_forces( - configs_forces_weight * ref["forces"], - configs_forces_weight * pred["forces"], - huber_delta=self.huber_delta, - ddp=ddp, - ) - loss_stress = torch.nn.functional.huber_loss( - configs_stress_weight * ref["stress"], - configs_stress_weight * pred["stress"], - reduction="none", - delta=self.huber_delta, - ) - loss_stress = reduce_loss(loss_stress, ddp) - else: - loss_energy = torch.nn.functional.huber_loss( - configs_energy_weight * ref["energy"] / num_atoms, - configs_energy_weight * pred["energy"] / num_atoms, - reduction="mean", - 
delta=self.huber_delta, - ) - loss_forces = conditional_huber_forces( - configs_forces_weight * ref["forces"], - configs_forces_weight * pred["forces"], - huber_delta=self.huber_delta, - ddp=ddp, - ) - loss_stress = torch.nn.functional.huber_loss( - configs_stress_weight * ref["stress"], - configs_stress_weight * pred["stress"], - reduction="mean", - delta=self.huber_delta, - ) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedEnergyForcesVirialsLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 - ) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "virials_weight", - torch.tensor(virials_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_virials = weighted_mean_squared_virials(ref, pred, ddp) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.virials_weight * loss_virials - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" - ) - - -class DipoleSingleLoss(torch.nn.Module): - def __init__(self, dipole_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "dipole_weight", - torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss = ( - weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 - ) # scale adjustment - return self.dipole_weight * loss - - def __repr__(self): - return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})" - - -class WeightedEnergyForcesDipoleLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "dipole_weight", - torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_dipole = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.dipole_weight * loss_dipole - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" - ) - - -class 
WeightedEnergyForcesL1L2Loss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_absolute_error_energy(ref, pred, ddp) - loss_forces = mean_normed_error_forces(ref, pred, ddp) - return self.energy_weight * loss_energy + self.forces_weight * loss_forces - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) +########################################################################################### +# Implementation of different loss functions +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Optional + +import torch +import torch.distributed as dist + +from mace.tools import TensorDict +from mace.tools.torch_geometric import Batch + + +# ------------------------------------------------------------------------------ +# Helper function for loss reduction that handles DDP correction +# ------------------------------------------------------------------------------ +def is_ddp_enabled(): + return dist.is_initialized() and dist.get_world_size() > 1 + + +def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor: + """ + Reduces an element-wise loss tensor. + + If ddp is True and distributed is initialized, the function computes: + + loss = (local_sum * world_size) / global_num_elements + + Otherwise, it returns the regular mean. + """ + ddp = is_ddp_enabled() if ddp is None else ddp + if ddp and dist.is_initialized(): + world_size = dist.get_world_size() + n_local = raw_loss.numel() + loss_sum = raw_loss.sum() + total_samples = torch.tensor( + n_local, device=raw_loss.device, dtype=raw_loss.dtype + ) + dist.all_reduce(total_samples, op=dist.ReduceOp.SUM) + return loss_sum * world_size / total_samples + return raw_loss.mean() + + +# ------------------------------------------------------------------------------ +# Energy Loss Functions +# ------------------------------------------------------------------------------ + + +def mean_squared_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + raw_loss = torch.square(ref["energy"] - pred["energy"]) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_squared_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + # Calculate per-graph number of atoms. 
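# `ref.weight` (the per-configuration weight) and `ref.energy_weight` both
# have shape [n_graphs], matching the per-graph energy error below, so the
# product stays element-wise per graph until `reduce_loss` averages it.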
+ num_atoms = ref.ptr[1:] - ref.ptr[:-1] # shape: [n_graphs] + raw_loss = ( + ref.weight + * ref.energy_weight + * torch.square((ref["energy"] - pred["energy"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_absolute_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + raw_loss = ( + ref.weight + * ref.energy_weight + * torch.abs((ref["energy"] - pred["energy"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Stress and Virials Loss Functions +# ------------------------------------------------------------------------------ + + +def weighted_mean_squared_stress( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = ref.weight.view(-1, 1, 1) + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) + raw_loss = ( + configs_weight + * configs_stress_weight + * torch.square(ref["stress"] - pred["stress"]) + ) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_squared_virials( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = ref.weight.view(-1, 1, 1) + configs_virials_weight = ref.virials_weight.view(-1, 1, 1) + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) + raw_loss = ( + configs_weight + * configs_virials_weight + * torch.square((ref["virials"] - pred["virials"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Forces Loss Functions +# ------------------------------------------------------------------------------ + + +def mean_squared_error_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + # Repeat per-graph weights to per-atom level. 
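# The trailing `.unsqueeze(-1)` turns the expanded weights into shape
# [n_atoms, 1] so they broadcast cleanly against the [n_atoms, 3] force
# residuals in the product below.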
+ configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + raw_loss = ( + configs_weight + * configs_forces_weight + * torch.square(ref["forces"] - pred["forces"]) + ) + return reduce_loss(raw_loss, ddp) + + +def mean_normed_error_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + raw_loss = torch.linalg.vector_norm(ref["forces"] - pred["forces"], ord=2, dim=-1) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Dipole Loss Function +# ------------------------------------------------------------------------------ + + +def weighted_mean_squared_error_dipole( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) + raw_loss = torch.square((ref["dipole"] - pred["dipole"]) / num_atoms) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Conditional Losses for Forces +# ------------------------------------------------------------------------------ + + +def conditional_mse_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + # Define multiplication factors for different regimes. + factors = torch.tensor( + [1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype + ) + err = ref["forces"] - pred["forces"] + se = torch.zeros_like(err) + norm_forces = torch.norm(ref["forces"], dim=-1) + c1 = norm_forces < 100 + c2 = (norm_forces >= 100) & (norm_forces < 200) + c3 = (norm_forces >= 200) & (norm_forces < 300) + se[c1] = torch.square(err[c1]) * factors[0] + se[c2] = torch.square(err[c2]) * factors[1] + se[c3] = torch.square(err[c3]) * factors[2] + se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] + raw_loss = configs_weight * configs_forces_weight * se + return reduce_loss(raw_loss, ddp) + + +def conditional_huber_forces( + ref_forces: torch.Tensor, + pred_forces: torch.Tensor, + huber_delta: float, + ddp: Optional[bool] = None, +) -> torch.Tensor: + factors = huber_delta * torch.tensor( + [1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype + ) + norm_forces = torch.norm(ref_forces, dim=-1) + c1 = norm_forces < 100 + c2 = (norm_forces >= 100) & (norm_forces < 200) + c3 = (norm_forces >= 200) & (norm_forces < 300) + c4 = ~(c1 | c2 | c3) + se = torch.zeros_like(pred_forces) + se[c1] = torch.nn.functional.huber_loss( + ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] + ) + se[c2] = torch.nn.functional.huber_loss( + ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] + ) + se[c3] = torch.nn.functional.huber_loss( + ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] + ) + se[c4] = torch.nn.functional.huber_loss( + ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] + ) + return reduce_loss(se, ddp) + + +# ------------------------------------------------------------------------------ +# Loss Modules Combining Multiple Quantities +# ------------------------------------------------------------------------------ + 
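# A minimal usage sketch for the modules below (it assumes the vendored
# mace package is importable; the weight values are illustrative, not
# project defaults, and `batch`/`pred` are placeholders for a mace
# dataloader Batch and model output dict):
#
#   from mace.modules.loss import WeightedEnergyForcesLoss
#   loss_fn = WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=100.0)
#   total = loss_fn(batch, pred)
#   # total = energy_weight * L_energy + forces_weight * L_forces
#
# Each module registers its weights as buffers, so they follow .to(device)
# and are stored in checkpoints alongside the model parameters.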
+ +class WeightedEnergyForcesLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + return self.energy_weight * loss_energy + self.forces_weight * loss_forces + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) + + +class WeightedForcesLoss(torch.nn.Module): + def __init__(self, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_forces = mean_squared_error_forces(ref, pred, ddp) + return self.forces_weight * loss_forces + + def __repr__(self): + return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})" + + +class WeightedEnergyForcesStressLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_stress = weighted_mean_squared_stress(ref, pred, ddp) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + # We store the huber_delta rather than a loss with fixed reduction. 
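# Keeping the raw delta (rather than a torch.nn.HuberLoss with a fixed
# reduction) lets forward() pick reduction="none" on the DDP path, where
# reduce_loss must see the unreduced elements to apply the world-size
# correction, while the single-process path can use reduction="mean".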
+ self.huber_delta = huber_delta + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + if ddp: + loss_energy = torch.nn.functional.huber_loss( + ref["energy"] / num_atoms, + pred["energy"] / num_atoms, + reduction="none", + delta=self.huber_delta, + ) + loss_energy = reduce_loss(loss_energy, ddp) + loss_forces = torch.nn.functional.huber_loss( + ref["forces"], pred["forces"], reduction="none", delta=self.huber_delta + ) + loss_forces = reduce_loss(loss_forces, ddp) + loss_stress = torch.nn.functional.huber_loss( + ref["stress"], pred["stress"], reduction="none", delta=self.huber_delta + ) + loss_stress = reduce_loss(loss_stress, ddp) + else: + loss_energy = torch.nn.functional.huber_loss( + ref["energy"] / num_atoms, + pred["energy"] / num_atoms, + reduction="mean", + delta=self.huber_delta, + ) + loss_forces = torch.nn.functional.huber_loss( + ref["forces"], pred["forces"], reduction="mean", delta=self.huber_delta + ) + loss_stress = torch.nn.functional.huber_loss( + ref["stress"], pred["stress"], reduction="mean", delta=self.huber_delta + ) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class UniversalLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_delta = huber_delta + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) + configs_energy_weight = ref.energy_weight + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + if ddp: + loss_energy = torch.nn.functional.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + reduction="none", + delta=self.huber_delta, + ) + loss_energy = reduce_loss(loss_energy, ddp) + loss_forces = conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ddp=ddp, + ) + loss_stress = torch.nn.functional.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + reduction="none", + delta=self.huber_delta, + ) + loss_stress = reduce_loss(loss_stress, ddp) + else: + loss_energy = torch.nn.functional.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + reduction="mean", + 
delta=self.huber_delta, + ) + loss_forces = conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ddp=ddp, + ) + loss_stress = torch.nn.functional.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + reduction="mean", + delta=self.huber_delta, + ) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedEnergyForcesVirialsLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 + ) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "virials_weight", + torch.tensor(virials_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_virials = weighted_mean_squared_virials(ref, pred, ddp) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.virials_weight * loss_virials + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" + ) + + +class DipoleSingleLoss(torch.nn.Module): + def __init__(self, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss = ( + weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 + ) # scale adjustment + return self.dipole_weight * loss + + def __repr__(self): + return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})" + + +class WeightedEnergyForcesDipoleLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_dipole = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.dipole_weight * loss_dipole + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" + ) + + +class 
WeightedEnergyForcesL1L2Loss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_absolute_error_energy(ref, pred, ddp) + loss_forces = mean_normed_error_forces(ref, pred, ddp) + return self.energy_weight * loss_energy + self.forces_weight * loss_forces + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) diff --git a/mace-bench/3rdparty/mace/mace/modules/models.py b/mace-bench/3rdparty/mace/mace/modules/models.py index c6ba5bc96ff66869b07d532989ea26daed809e54..b551f8bf33c11844589c58896a6aea61f7a2df04 100644 --- a/mace-bench/3rdparty/mace/mace/modules/models.py +++ b/mace-bench/3rdparty/mace/mace/modules/models.py @@ -1,947 +1,947 @@ -########################################################################################### -# Implementation of MACE models and other models based E(3)-Equivariant MPNNs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import numpy as np -import torch -from e3nn import o3 -from e3nn.util.jit import compile_mode - -from mace.modules.radial import ZBLBasis -from mace.tools.scatter import scatter_sum - -from .blocks import ( - AtomicEnergiesBlock, - EquivariantProductBasisBlock, - InteractionBlock, - LinearDipoleReadoutBlock, - LinearNodeEmbeddingBlock, - LinearReadoutBlock, - NonLinearDipoleReadoutBlock, - NonLinearReadoutBlock, - RadialEmbeddingBlock, - ScaleShiftBlock, -) -from .utils import ( - compute_fixed_charge_dipole, - get_atomic_virials_stresses, - get_edge_vectors_and_lengths, - get_outputs, - get_symmetric_displacement, - prepare_graph, -) - -# pylint: disable=C0302 - - -@compile_mode("script") -class MACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: Union[int, List[int]], - gate: Optional[Callable], - pair_repulsion: bool = False, - distance_transform: str = "None", - radial_MLP: Optional[List[int]] = None, - radial_type: Optional[str] = "bessel", - heads: Optional[List[str]] = None, - cueq_config: Optional[Dict[str, Any]] = None, - lammps_mliap: Optional[bool] = False, - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - if heads is None: - heads = ["Default"] - self.heads = heads - if isinstance(correlation, int): - correlation = [correlation] * num_interactions - self.lammps_mliap = lammps_mliap - # 
Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, - irreps_out=node_feats_irreps, - cueq_config=cueq_config, - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - distance_transform=distance_transform, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) - self.pair_repulsion = True - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - cueq_config=cueq_config, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation[0], - num_elements=num_elements, - use_sc=use_sc_first, - cueq_config=cueq_config, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append( - LinearReadoutBlock( - hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config - ) - ) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - cueq_config=cueq_config, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation[i + 1], - num_elements=num_elements, - use_sc=True, - cueq_config=cueq_config, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock( - hidden_irreps_out, - (len(heads) * MLP_irreps).simplify(), - gate, - o3.Irreps(f"{len(heads)}x0e"), - len(heads), - cueq_config, - ) - ) - else: - self.readouts.append( - LinearReadoutBlock( - hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config - ) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - 
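
The constructor above builds three parallel `ModuleList`s (interactions, products, readouts), and the forward pass sums one energy contribution per layer on top of `e0` and the optional pair repulsion. A toy sketch of that layer-wise readout pattern; per-graph pooling is shown with `index_add_`, whereas the hunk uses the repo's `scatter_sum` helper with the same semantics:

```python
import torch

torch.manual_seed(0)
num_nodes, num_graphs, feat = 6, 2, 8
batch = torch.tensor([0, 0, 0, 1, 1, 1])   # graph index of each node
node_feats = torch.randn(num_nodes, feat)

# One readout per message-passing layer, mirroring self.readouts above.
readouts = torch.nn.ModuleList([torch.nn.Linear(feat, 1) for _ in range(3)])

energies = []
for readout in readouts:
    node_es = readout(node_feats).squeeze(-1)                        # [n_nodes]
    graph_e = torch.zeros(num_graphs).index_add_(0, batch, node_es)  # scatter-sum
    energies.append(graph_e)

total_energy = torch.stack(energies, dim=-1).sum(dim=-1)  # [n_graphs]
print(total_energy)
```
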
compute_hessian: bool = False, - compute_edge_forces: bool = False, - compute_atomic_stresses: bool = False, - lammps_mliap: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - ctx = prepare_graph( - data, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_displacement=compute_displacement, - lammps_mliap=lammps_mliap, - ) - is_lammps = ctx.is_lammps - num_atoms_arange = ctx.num_atoms_arange - num_graphs = ctx.num_graphs - displacement = ctx.displacement - positions = ctx.positions - vectors = ctx.vectors - lengths = ctx.lengths - cell = ctx.cell - node_heads = ctx.node_heads - interaction_kwargs = ctx.interaction_kwargs - lammps_natoms = interaction_kwargs.lammps_natoms - lammps_class = interaction_kwargs.lammps_class - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_heads - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, n_heads] - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if is_lammps: - pair_node_energy = pair_node_energy[: lammps_natoms[0]] - pair_energy = scatter_sum( - src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - else: - pair_node_energy = torch.zeros_like(node_e0) - pair_energy = torch.zeros_like(e0) - - # Interactions - energies = [e0, pair_energy] - node_energies_list = [node_e0, pair_node_energy] - node_feats_concat: List[torch.Tensor] = [] - - for i, (interaction, product, readout) in enumerate( - zip(self.interactions, self.products, self.readouts) - ): - node_attrs_slice = data["node_attrs"] - if is_lammps and i > 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats, sc = interaction( - node_attrs=node_attrs_slice, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - first_layer=(i == 0), - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - ) - if is_lammps and i == 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice - ) - node_feats_concat.append(node_feats) - node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads] - energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs) - energies.append(energy) - node_energies_list.append(node_es) - - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) - node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1) - node_feats_out = torch.cat(node_feats_concat, dim=-1) - node_energy = node_e0.double() + pair_node_energy.double() - - forces, virials, stress, hessian, edge_forces = get_outputs( - energy=total_energy, - positions=positions, - displacement=displacement, - vectors=vectors, - cell=cell, - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - compute_edge_forces=compute_edge_forces, - ) - - atomic_virials: Optional[torch.Tensor] = None - atomic_stresses: Optional[torch.Tensor] = None - if compute_atomic_stresses and edge_forces is 
not None: - atomic_virials, atomic_stresses = get_atomic_virials_stresses( - edge_forces=edge_forces, - edge_index=data["edge_index"], - vectors=vectors, - num_atoms=positions.shape[0], - batch=data["batch"], - cell=cell, - ) - return { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "edge_forces": edge_forces, - "virials": virials, - "stress": stress, - "atomic_virials": atomic_virials, - "atomic_stresses": atomic_stresses, - "displacement": displacement, - "hessian": hessian, - "node_feats": node_feats_out, - } - - -@compile_mode("script") -class ScaleShiftMACE(MACE): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - compute_edge_forces: bool = False, - compute_atomic_stresses: bool = False, - lammps_mliap: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - ctx = prepare_graph( - data, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_displacement=compute_displacement, - lammps_mliap=lammps_mliap, - ) - - is_lammps = ctx.is_lammps - num_atoms_arange = ctx.num_atoms_arange - num_graphs = ctx.num_graphs - displacement = ctx.displacement - positions = ctx.positions - vectors = ctx.vectors - lengths = ctx.lengths - cell = ctx.cell - node_heads = ctx.node_heads - interaction_kwargs = ctx.interaction_kwargs - lammps_natoms = interaction_kwargs.lammps_natoms - lammps_class = interaction_kwargs.lammps_class - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_heads - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, num_heads] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if is_lammps: - pair_node_energy = pair_node_energy[: lammps_natoms[0]] - else: - pair_node_energy = torch.zeros_like(node_e0) - - # Interactions - node_es_list = [pair_node_energy] - node_feats_list: List[torch.Tensor] = [] - - for i, (interaction, product, readout) in enumerate( - zip(self.interactions, self.products, self.readouts) - ): - node_attrs_slice = data["node_attrs"] - if is_lammps and i > 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats, sc = interaction( - node_attrs=node_attrs_slice, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - first_layer=(i == 0), - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - ) - if is_lammps and i == 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice - ) - node_feats_list.append(node_feats) - node_es_list.append( - readout(node_feats, node_heads)[num_atoms_arange, node_heads] - ) - - node_feats_out = torch.cat(node_feats_list, 
dim=-1) - node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0) - node_inter_es = self.scale_shift(node_inter_es, node_heads) - inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs) - - total_energy = e0 + inter_e - node_energy = node_e0.clone().double() + node_inter_es.clone().double() - - forces, virials, stress, hessian, edge_forces = get_outputs( - energy=inter_e, - positions=positions, - displacement=displacement, - vectors=vectors, - cell=cell, - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - compute_edge_forces=compute_edge_forces or compute_atomic_stresses, - ) - - atomic_virials: Optional[torch.Tensor] = None - atomic_stresses: Optional[torch.Tensor] = None - if compute_atomic_stresses and edge_forces is not None: - atomic_virials, atomic_stresses = get_atomic_virials_stresses( - edge_forces=edge_forces, - edge_index=data["edge_index"], - vectors=vectors, - num_atoms=positions.shape[0], - batch=data["batch"], - cell=cell, - ) - return { - "energy": total_energy, - "node_energy": node_energy, - "interaction_energy": inter_e, - "forces": forces, - "edge_forces": edge_forces, - "virials": virials, - "stress": stress, - "atomic_virials": atomic_virials, - "atomic_stresses": atomic_stresses, - "hessian": hessian, - "displacement": displacement, - "node_feats": node_feats_out, - } - - -@compile_mode("script") -class AtomicDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[ - None - ], # Just here to make it compatible with energy models, MUST be None - radial_type: Optional[str] = "bessel", - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - assert atomic_energies is None - - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - - # Interactions and readouts - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - 
edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[1] - ) # Select only l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=True - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, # pylint: disable=W0613 - compute_force: bool = False, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_edge_forces: bool = False, # pylint: disable=W0613 - compute_atomic_stresses: bool = False, # pylint: disable=W0613 - ) -> Dict[str, Optional[torch.Tensor]]: - assert compute_force is False - assert compute_virials is False - assert compute_stress is False - assert compute_displacement is False - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - dipoles.append(node_dipoles) - - # Compute the dipoles - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # 
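
Below, the learned per-layer dipole contributions are summed per node, pooled per graph, and shifted by `compute_fixed_charge_dipole`. That helper is not shown in this diff; assuming it is the classical point-charge baseline mu_g = sum over atoms i in graph g of q_i * r_i (consistent with its name and arguments), a standalone sketch:

```python
import torch

positions = torch.randn(5, 3)                          # per-atom positions
charges = torch.tensor([0.4, -0.2, -0.2, 0.1, -0.1])   # fixed atomic charges
batch = torch.tensor([0, 0, 0, 1, 1])                  # two graphs
num_graphs = 2

# Assumed baseline: classical point-charge dipole per graph.
atomic = charges.unsqueeze(-1) * positions             # [n_nodes, 3]
total_dipole = torch.zeros(num_graphs, 3).index_add_(0, batch, atomic)
print(total_dipole)                                    # [n_graphs, 3]
```
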
[n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"], - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - output = { - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output - - -@compile_mode("script") -class EnergyDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[np.ndarray], - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[:2] - ) # Select scalars and 
l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=False - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_edge_forces: bool = False, # pylint: disable=W0613 - compute_atomic_stresses: bool = False, # pylint: disable=W0613 - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - num_atoms_arange = torch.arange(data["positions"].shape[0]) - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["head"][data["batch"]] - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - energies = [e0] - node_energies_list = [node_e0] - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] - # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - node_energies = node_out[:, 0] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_dipoles = node_out[:, 1:] - dipoles.append(node_dipoles) - - # Compute the energies and dipoles - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, 
dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - forces, virials, stress, _, _ = get_outputs( - energy=total_energy, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - ) - - output = { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output +########################################################################################### +# Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules.radial import ZBLBasis +from mace.tools.scatter import scatter_sum + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + ScaleShiftBlock, +) +from .utils import ( + compute_fixed_charge_dipole, + get_atomic_virials_stresses, + get_edge_vectors_and_lengths, + get_outputs, + get_symmetric_displacement, + prepare_graph, +) + +# pylint: disable=C0302 + + +@compile_mode("script") +class MACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: Union[int, List[int]], + gate: Optional[Callable], + pair_repulsion: bool = False, + distance_transform: str = "None", + radial_MLP: Optional[List[int]] = None, + radial_type: Optional[str] = "bessel", + heads: Optional[List[str]] = None, + cueq_config: Optional[Dict[str, Any]] = None, + lammps_mliap: Optional[bool] = False, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if heads is None: + heads = ["Default"] + self.heads = heads + if isinstance(correlation, int): + correlation = [correlation] * num_interactions + self.lammps_mliap = lammps_mliap + # 
Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, + irreps_out=node_feats_irreps, + cueq_config=cueq_config, + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + if pair_repulsion: + self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) + self.pair_repulsion = True + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readout + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + cueq_config=cueq_config, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + cueq_config=cueq_config, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append( + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) + ) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + cueq_config=cueq_config, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + cueq_config=cueq_config, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock( + hidden_irreps_out, + (len(heads) * MLP_irreps).simplify(), + gate, + o3.Irreps(f"{len(heads)}x0e"), + len(heads), + cueq_config, + ) + ) + else: + self.readouts.append( + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + 
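
Each readout here emits one scalar channel per head (`o3.Irreps(f"{len(heads)}x0e")`); in the forward pass that follows, the value for each node's own head is selected with the advanced indexing `[num_atoms_arange, node_heads]`. A minimal illustration of that selection:

```python
import torch

num_nodes, num_heads = 5, 3
node_out = torch.randn(num_nodes, num_heads)   # one readout value per head
node_heads = torch.tensor([0, 2, 1, 0, 2])     # head assigned to each node

node_es = node_out[torch.arange(num_nodes), node_heads]  # [n_nodes]
print(node_es)
```
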
compute_hessian: bool = False, + compute_edge_forces: bool = False, + compute_atomic_stresses: bool = False, + lammps_mliap: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + ctx = prepare_graph( + data, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_displacement=compute_displacement, + lammps_mliap=lammps_mliap, + ) + is_lammps = ctx.is_lammps + num_atoms_arange = ctx.num_atoms_arange + num_graphs = ctx.num_graphs + displacement = ctx.displacement + positions = ctx.positions + vectors = ctx.vectors + lengths = ctx.lengths + cell = ctx.cell + node_heads = ctx.node_heads + interaction_kwargs = ctx.interaction_kwargs + lammps_natoms = interaction_kwargs.lammps_natoms + lammps_class = interaction_kwargs.lammps_class + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, n_heads] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if is_lammps: + pair_node_energy = pair_node_energy[: lammps_natoms[0]] + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_concat: List[torch.Tensor] = [] + + for i, (interaction, product, readout) in enumerate( + zip(self.interactions, self.products, self.readouts) + ): + node_attrs_slice = data["node_attrs"] + if is_lammps and i > 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats, sc = interaction( + node_attrs=node_attrs_slice, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + first_layer=(i == 0), + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + ) + if is_lammps and i == 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice + ) + node_feats_concat.append(node_feats) + node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads] + energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs) + energies.append(energy) + node_energies_list.append(node_es) + + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) + node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1) + node_feats_out = torch.cat(node_feats_concat, dim=-1) + node_energy = node_e0.double() + pair_node_energy.double() + + forces, virials, stress, hessian, edge_forces = get_outputs( + energy=total_energy, + positions=positions, + displacement=displacement, + vectors=vectors, + cell=cell, + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + compute_edge_forces=compute_edge_forces, + ) + + atomic_virials: Optional[torch.Tensor] = None + atomic_stresses: Optional[torch.Tensor] = None + if compute_atomic_stresses and edge_forces is 
not None: + atomic_virials, atomic_stresses = get_atomic_virials_stresses( + edge_forces=edge_forces, + edge_index=data["edge_index"], + vectors=vectors, + num_atoms=positions.shape[0], + batch=data["batch"], + cell=cell, + ) + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "edge_forces": edge_forces, + "virials": virials, + "stress": stress, + "atomic_virials": atomic_virials, + "atomic_stresses": atomic_stresses, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class ScaleShiftMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + compute_edge_forces: bool = False, + compute_atomic_stresses: bool = False, + lammps_mliap: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + ctx = prepare_graph( + data, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_displacement=compute_displacement, + lammps_mliap=lammps_mliap, + ) + + is_lammps = ctx.is_lammps + num_atoms_arange = ctx.num_atoms_arange + num_graphs = ctx.num_graphs + displacement = ctx.displacement + positions = ctx.positions + vectors = ctx.vectors + lengths = ctx.lengths + cell = ctx.cell + node_heads = ctx.node_heads + interaction_kwargs = ctx.interaction_kwargs + lammps_natoms = interaction_kwargs.lammps_natoms + lammps_class = interaction_kwargs.lammps_class + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, num_heads] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if is_lammps: + pair_node_energy = pair_node_energy[: lammps_natoms[0]] + else: + pair_node_energy = torch.zeros_like(node_e0) + + # Interactions + node_es_list = [pair_node_energy] + node_feats_list: List[torch.Tensor] = [] + + for i, (interaction, product, readout) in enumerate( + zip(self.interactions, self.products, self.readouts) + ): + node_attrs_slice = data["node_attrs"] + if is_lammps and i > 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats, sc = interaction( + node_attrs=node_attrs_slice, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + first_layer=(i == 0), + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + ) + if is_lammps and i == 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice + ) + node_feats_list.append(node_feats) + node_es_list.append( + readout(node_feats, node_heads)[num_atoms_arange, node_heads] + ) + + node_feats_out = torch.cat(node_feats_list, 
dim=-1) + node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0) + node_inter_es = self.scale_shift(node_inter_es, node_heads) + inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs) + + total_energy = e0 + inter_e + node_energy = node_e0.clone().double() + node_inter_es.clone().double() + + forces, virials, stress, hessian, edge_forces = get_outputs( + energy=inter_e, + positions=positions, + displacement=displacement, + vectors=vectors, + cell=cell, + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + compute_edge_forces=compute_edge_forces or compute_atomic_stresses, + ) + + atomic_virials: Optional[torch.Tensor] = None + atomic_stresses: Optional[torch.Tensor] = None + if compute_atomic_stresses and edge_forces is not None: + atomic_virials, atomic_stresses = get_atomic_virials_stresses( + edge_forces=edge_forces, + edge_index=data["edge_index"], + vectors=vectors, + num_atoms=positions.shape[0], + batch=data["batch"], + cell=cell, + ) + return { + "energy": total_energy, + "node_energy": node_energy, + "interaction_energy": inter_e, + "forces": forces, + "edge_forces": edge_forces, + "virials": virials, + "stress": stress, + "atomic_virials": atomic_virials, + "atomic_stresses": atomic_stresses, + "hessian": hessian, + "displacement": displacement, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class AtomicDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[ + None + ], # Just here to make it compatible with energy models, MUST be None + radial_type: Optional[str] = "bessel", + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + assert atomic_energies is None + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readouts + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + 
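
Note the energy assembly just above: only the interaction part passes through `self.scale_shift`, the isolated-atom reference `e0` is added unscaled, and `get_outputs` differentiates `inter_e` alone (valid because `e0` does not depend on positions). A toy version, assuming `ScaleShiftBlock` applies `scale * x + shift` per atom for a single head:

```python
import torch

scale, shift = 0.8, -1.5                       # illustrative values
node_inter_es = torch.randn(6)                 # raw per-atom interaction energies
batch = torch.tensor([0, 0, 0, 1, 1, 1])
e0 = torch.tensor([-10.0, -12.0])              # per-graph sum of atomic E0s

node_inter_es = scale * node_inter_es + shift  # assumed ScaleShiftBlock behavior
inter_e = torch.zeros(2).index_add_(0, batch, node_inter_es)
print(e0 + inter_e)                            # total_energy, [n_graphs]
```
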
edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[1] + ) # Select only l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=True + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, # pylint: disable=W0613 + compute_force: bool = False, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_edge_forces: bool = False, # pylint: disable=W0613 + compute_atomic_stresses: bool = False, # pylint: disable=W0613 + ) -> Dict[str, Optional[torch.Tensor]]: + assert compute_force is False + assert compute_virials is False + assert compute_stress is False + assert compute_displacement is False + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] + dipoles.append(node_dipoles) + + # Compute the dipoles + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # 
[n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"], + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + output = { + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output + + +@compile_mode("script") +class EnergyDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[np.ndarray], + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[:2] + ) # Select scalars and 
l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=False + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_edge_forces: bool = False, # pylint: disable=W0613 + compute_atomic_stresses: bool = False, # pylint: disable=W0613 + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + node_energies_list = [node_e0] + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] + # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + node_energies = node_out[:, 0] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_dipoles = node_out[:, 1:] + dipoles.append(node_dipoles) + + # Compute the energies and dipoles + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, 
dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + forces, virials, stress, _, _ = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + ) + + output = { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output diff --git a/mace-bench/3rdparty/mace/mace/modules/radial.py b/mace-bench/3rdparty/mace/mace/modules/radial.py index b78dd4eb69207b9f61645f8c3b6244ebcd9d87cb..ff69b43e5884137ab6a37861eab763714a058155 100644 --- a/mace-bench/3rdparty/mace/mace/modules/radial.py +++ b/mace-bench/3rdparty/mace/mace/modules/radial.py @@ -1,358 +1,358 @@ -########################################################################################### -# Radial basis and cutoff -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging - -import ase -import numpy as np -import torch -from e3nn.util.jit import compile_mode - -from mace.tools.scatter import scatter_sum - - -@compile_mode("script") -class BesselBasis(torch.nn.Module): - """ - Equation (7) - """ - - def __init__(self, r_max: float, num_basis=8, trainable=False): - super().__init__() - - bessel_weights = ( - np.pi - / r_max - * torch.linspace( - start=1.0, - end=num_basis, - steps=num_basis, - dtype=torch.get_default_dtype(), - ) - ) - if trainable: - self.bessel_weights = torch.nn.Parameter(bessel_weights) - else: - self.register_buffer("bessel_weights", bessel_weights) - - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "prefactor", - torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] - return self.prefactor * (numerator / x) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " - f"trainable={self.bessel_weights.requires_grad})" - ) - - -@compile_mode("script") -class ChebychevBasis(torch.nn.Module): - """ - Equation (7) - """ - - def __init__(self, r_max: float, num_basis=8): - super().__init__() - self.register_buffer( - "n", - torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( - 0 - ), - ) - self.num_basis = num_basis - self.r_max = r_max - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - x = x.repeat(1, self.num_basis) - n = self.n.repeat(len(x), 1) - return torch.special.chebyshev_polynomial_t(x, n) - - def __repr__(self): - 
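
`BesselBasis` above evaluates e_n(r) = sqrt(2/r_max) * sin(n*pi*r/r_max) / r for n = 1..num_basis (edge lengths are strictly positive, so the 1/r factor is safe). A standalone numerical check of the same expression:

```python
import torch

r_max, num_basis = 5.0, 8
n = torch.arange(1, num_basis + 1, dtype=torch.float64)
bessel_weights = torch.pi * n / r_max          # n*pi/r_max, as in the buffer above
prefactor = (2.0 / r_max) ** 0.5               # sqrt(2/r_max)

r = torch.linspace(0.5, 4.5, 5, dtype=torch.float64).unsqueeze(-1)  # [..., 1]
basis = prefactor * torch.sin(bessel_weights * r) / r               # [..., num_basis]
print(basis.shape)                             # torch.Size([5, 8])
```
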
return ( - f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," - ) - - -@compile_mode("script") -class GaussianBasis(torch.nn.Module): - """ - Gaussian basis functions - """ - - def __init__(self, r_max: float, num_basis=128, trainable=False): - super().__init__() - gaussian_weights = torch.linspace( - start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() - ) - if trainable: - self.gaussian_weights = torch.nn.Parameter( - gaussian_weights, requires_grad=True - ) - else: - self.register_buffer("gaussian_weights", gaussian_weights) - self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - x = x - self.gaussian_weights - return torch.exp(self.coeff * torch.pow(x, 2)) - - -@compile_mode("script") -class PolynomialCutoff(torch.nn.Module): - """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max. - Equation (8) -- TODO: from where? - """ - - p: torch.Tensor - r_max: torch.Tensor - - def __init__(self, r_max: float, p=6): - super().__init__() - self.register_buffer("p", torch.tensor(p, dtype=torch.int)) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) - - @staticmethod - def calculate_envelope( - x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor - ) -> torch.Tensor: - r_over_r_max = x / r_max - envelope = ( - 1.0 - - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) - + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) - - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) - ) - return envelope * (x < r_max) - - def __repr__(self): - return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" - - -@compile_mode("script") -class ZBLBasis(torch.nn.Module): - """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential - with a polynomial cutoff envelope. - """ - - p: torch.Tensor - - def __init__(self, p=6, trainable=False, **kwargs): - super().__init__() - if "r_max" in kwargs: - logging.warning( - "r_max is deprecated. r_max is determined from the covalent radii." 
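
`PolynomialCutoff.calculate_envelope` above implements, with u = r/r_max, f(u) = 1 - ((p+1)(p+2)/2) u^p + p(p+2) u^(p+1) - (p(p+1)/2) u^(p+2), zeroed for r >= r_max; for p = 6 it decays smoothly from 1 at r = 0 to 0 at the cutoff. A quick endpoint check:

```python
import torch

def envelope(x, r_max=5.0, p=6.0):
    u = x / r_max
    f = (1.0
         - ((p + 1.0) * (p + 2.0) / 2.0) * u**p
         + p * (p + 2.0) * u**(p + 1.0)
         - (p * (p + 1.0) / 2.0) * u**(p + 2.0))
    return f * (x < r_max)

x = torch.tensor([0.0, 2.5, 4.999, 5.0])
print(envelope(x))   # ~[1.000, 0.855, ~0.000, 0.000]
```
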
- ) - - # Pre-calculate the p coefficients for the ZBL potential - self.register_buffer( - "c", - torch.tensor( - [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() - ), - ) - self.register_buffer("p", torch.tensor(p, dtype=torch.int)) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - if trainable: - self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) - self.a_prefactor = torch.nn.Parameter( - torch.tensor(0.4543, requires_grad=True) - ) - else: - self.register_buffer("a_exp", torch.tensor(0.300)) - self.register_buffer("a_prefactor", torch.tensor(0.4543)) - - def forward( - self, - x: torch.Tensor, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( - -1 - ) - Z_u = node_atomic_numbers[sender] - Z_v = node_atomic_numbers[receiver] - a = ( - self.a_prefactor - * 0.529 - / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) - ) - r_over_a = x / a - phi = ( - self.c[0] * torch.exp(-3.2 * r_over_a) - + self.c[1] * torch.exp(-0.9423 * r_over_a) - + self.c[2] * torch.exp(-0.4028 * r_over_a) - + self.c[3] * torch.exp(-0.2016 * r_over_a) - ) - v_edges = (14.3996 * Z_u * Z_v) / x * phi - r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) - v_edges = 0.5 * v_edges * envelope - V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) - return V_ZBL.squeeze(-1) - - def __repr__(self): - return f"{self.__class__.__name__}(c={self.c})" - - -@compile_mode("script") -class AgnesiTransform(torch.nn.Module): - """Agnesi transform - see section on Radial transformations in - ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). 
- """ - - def __init__( - self, - q: float = 0.9183, - p: float = 4.5791, - a: float = 1.0805, - trainable=False, - ): - super().__init__() - self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) - self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - if trainable: - self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) - self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) - self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) - - def forward( - self, - x: torch.Tensor, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( - -1 - ) - Z_u = node_atomic_numbers[sender] - Z_v = node_atomic_numbers[receiver] - r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) - r_over_r_0 = x / r_0 - return ( - 1 - + ( - self.a - * torch.pow(r_over_r_0, self.q) - / (1 + torch.pow(r_over_r_0, self.q - self.p)) - ) - ).reciprocal_() - - def __repr__(self): - return ( - f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" - ) - - -@compile_mode("script") -class SoftTransform(torch.nn.Module): - """ - Tanh-based smooth transformation: - T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))], - which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0. - """ - - def __init__(self, alpha: float = 4.0, trainable=False): - """ - Args: - p1 (float): Lower "clamp" point. - alpha (float): Steepness; if None, defaults to ~6/(r0-p1). - trainable (bool): Whether to make parameters trainable. - """ - super().__init__() - # Initialize parameters - self.register_buffer( - "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype()) - ) - if trainable: - self.alpha = torch.nn.Parameter(self.alpha.clone()) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - - def compute_r_0( - self, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - """ - Compute r_0 based on atomic information. - - Args: - node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers). - edge_index (torch.Tensor): Edge index indicating connections. - atomic_numbers (torch.Tensor): Atomic numbers. - - Returns: - torch.Tensor: r_0 values for each edge. 
-        """
-        sender = edge_index[0]
-        receiver = edge_index[1]
-        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-            -1
-        )
-        Z_u = node_atomic_numbers[sender]
-        Z_v = node_atomic_numbers[receiver]
-        r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
-        return r_0
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        node_attrs: torch.Tensor,
-        edge_index: torch.Tensor,
-        atomic_numbers: torch.Tensor,
-    ) -> torch.Tensor:
-
-        r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
-        p_0 = (3 / 4) * r_0
-        p_1 = (4 / 3) * r_0
-        m = 0.5 * (p_0 + p_1)
-        alpha = self.alpha / (p_1 - p_0)
-        s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
-        return p_0 + (x - p_0) * s_x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
+###########################################################################################
+# Radial basis and cutoff
+# Authors: Ilyes Batatia, Gregor Simm
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import logging
+
+import ase
+import numpy as np
+import torch
+from e3nn.util.jit import compile_mode
+
+from mace.tools.scatter import scatter_sum
+
+
+@compile_mode("script")
+class BesselBasis(torch.nn.Module):
+    """
+    Equation (7)
+    """
+
+    def __init__(self, r_max: float, num_basis=8, trainable=False):
+        super().__init__()
+
+        bessel_weights = (
+            np.pi
+            / r_max
+            * torch.linspace(
+                start=1.0,
+                end=num_basis,
+                steps=num_basis,
+                dtype=torch.get_default_dtype(),
+            )
+        )
+        if trainable:
+            self.bessel_weights = torch.nn.Parameter(bessel_weights)
+        else:
+            self.register_buffer("bessel_weights", bessel_weights)
+
+        self.register_buffer(
+            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
+        )
+        self.register_buffer(
+            "prefactor",
+            torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
+        numerator = torch.sin(self.bessel_weights * x)  # [..., num_basis]
+        return self.prefactor * (numerator / x)
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
+            f"trainable={self.bessel_weights.requires_grad})"
+        )
+
+
+@compile_mode("script")
+class ChebychevBasis(torch.nn.Module):
+    """
+    Chebyshev polynomial basis (first kind)
+    """
+
+    def __init__(self, r_max: float, num_basis=8):
+        super().__init__()
+        self.register_buffer(
+            "n",
+            torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze(
+                0
+            ),
+        )
+        self.num_basis = num_basis
+        self.r_max = r_max
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
+        x = x.repeat(1, self.num_basis)
+        n = self.n.repeat(len(x), 1)
+        return torch.special.chebyshev_polynomial_t(x, n)
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis})"
+        )
+
+
+@compile_mode("script")
+class GaussianBasis(torch.nn.Module):
+    """
+    Gaussian basis functions
+    """
+
+    def __init__(self, r_max: float, num_basis=128, trainable=False):
+        super().__init__()
+        gaussian_weights = torch.linspace(
+            start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
+        )
+        if trainable:
+            self.gaussian_weights = torch.nn.Parameter(
+                gaussian_weights, requires_grad=True
+            )
+        else:
+            self.register_buffer("gaussian_weights", gaussian_weights)
+        self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor: 
# [..., 1] + x = x - self.gaussian_weights + return torch.exp(self.coeff * torch.pow(x, 2)) + + +@compile_mode("script") +class PolynomialCutoff(torch.nn.Module): + """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max. + Equation (8) -- TODO: from where? + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) + + @staticmethod + def calculate_envelope( + x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor + ) -> torch.Tensor: + r_over_r_max = x / r_max + envelope = ( + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) + + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) + ) + return envelope * (x < r_max) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" + + +@compile_mode("script") +class ZBLBasis(torch.nn.Module): + """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + with a polynomial cutoff envelope. + """ + + p: torch.Tensor + + def __init__(self, p=6, trainable=False, **kwargs): + super().__init__() + if "r_max" in kwargs: + logging.warning( + "r_max is deprecated. r_max is determined from the covalent radii." + ) + + # Pre-calculate the p coefficients for the ZBL potential + self.register_buffer( + "c", + torch.tensor( + [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() + ), + ) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) + self.a_prefactor = torch.nn.Parameter( + torch.tensor(0.4543, requires_grad=True) + ) + else: + self.register_buffer("a_exp", torch.tensor(0.300)) + self.register_buffer("a_prefactor", torch.tensor(0.4543)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + a = ( + self.a_prefactor + * 0.529 + / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) + ) + r_over_a = x / a + phi = ( + self.c[0] * torch.exp(-3.2 * r_over_a) + + self.c[1] * torch.exp(-0.9423 * r_over_a) + + self.c[2] * torch.exp(-0.4028 * r_over_a) + + self.c[3] * torch.exp(-0.2016 * r_over_a) + ) + v_edges = (14.3996 * Z_u * Z_v) / x * phi + r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) + v_edges = 0.5 * v_edges * envelope + V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) + return V_ZBL.squeeze(-1) + + def __repr__(self): + return f"{self.__class__.__name__}(c={self.c})" + + +@compile_mode("script") +class AgnesiTransform(torch.nn.Module): + """Agnesi transform - see section on Radial transformations in + ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). 
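+    Maps each edge length x to 1 / (1 + a * (x / r_0)**q / (1 + (x / r_0)**(q - p))),
+    where r_0 is half the sum of the covalent radii of the two atoms on the edge.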
+    """
+
+    def __init__(
+        self,
+        q: float = 0.9183,
+        p: float = 4.5791,
+        a: float = 1.0805,
+        trainable=False,
+    ):
+        super().__init__()
+        self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype()))
+        self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
+        self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype()))
+        self.register_buffer(
+            "covalent_radii",
+            torch.tensor(
+                ase.data.covalent_radii,
+                dtype=torch.get_default_dtype(),
+            ),
+        )
+        if trainable:
+            self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True))
+            self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True))
+            self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        node_attrs: torch.Tensor,
+        edge_index: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+    ) -> torch.Tensor:
+        sender = edge_index[0]
+        receiver = edge_index[1]
+        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
+            -1
+        )
+        Z_u = node_atomic_numbers[sender]
+        Z_v = node_atomic_numbers[receiver]
+        r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
+        r_over_r_0 = x / r_0
+        return (
+            1
+            + (
+                self.a
+                * torch.pow(r_over_r_0, self.q)
+                / (1 + torch.pow(r_over_r_0, self.q - self.p))
+            )
+        ).reciprocal_()
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
+        )
+
+
+@compile_mode("script")
+class SoftTransform(torch.nn.Module):
+    """
+    Tanh-based smooth transformation:
+    T(x) = p_0 + (x - p_0)*0.5*[1 + tanh(alpha*(x - m))],
+    with p_0 = (3/4)*r_0, p_1 = (4/3)*r_0 and m = (p_0 + p_1)/2, which smoothly
+    transitions from ~p_0 for x << p_0 to ~x for x >> p_1.
+    """
+
+    def __init__(self, alpha: float = 4.0, trainable=False):
+        """
+        Args:
+            alpha (float): Dimensionless steepness of the transition; rescaled
+                internally by 1/(p_1 - p_0) in the forward pass.
+            trainable (bool): Whether to make alpha trainable.
+        """
+        super().__init__()
+        # Initialize parameters
+        self.register_buffer(
+            "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
+        )
+        if trainable:
+            self.alpha = torch.nn.Parameter(self.alpha.clone())
+        self.register_buffer(
+            "covalent_radii",
+            torch.tensor(
+                ase.data.covalent_radii,
+                dtype=torch.get_default_dtype(),
+            ),
+        )
+
+    def compute_r_0(
+        self,
+        node_attrs: torch.Tensor,
+        edge_index: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Compute r_0 based on atomic information.
+
+        Args:
+            node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
+            edge_index (torch.Tensor): Edge index indicating connections.
+            atomic_numbers (torch.Tensor): Atomic numbers.
+
+        Returns:
+            torch.Tensor: r_0 values for each edge.
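+            Here r_0 is the sum of the covalent radii of the two atoms on the edge.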
+ """ + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + return r_0 + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + + r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers) + p_0 = (3 / 4) * r_0 + p_1 = (4 / 3) * r_0 + m = 0.5 * (p_0 + p_1) + alpha = self.alpha / (p_1 - p_0) + s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m))) + return p_0 + (x - p_0) * s_x + + def __repr__(self): + return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})" diff --git a/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py b/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py index 577713cab70a501169c03babcde792c3b043abd4..9db75da0255d4d44d0f4b05bd713139875147356 100644 --- a/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py +++ b/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py @@ -1,233 +1,233 @@ -########################################################################################### -# Implementation of the symmetric contraction algorithm presented in the MACE paper -# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) -# Authors: Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Dict, Optional, Union - -import opt_einsum_fx -import torch -import torch.fx -from e3nn import o3 -from e3nn.util.codegen import CodeGenMixin -from e3nn.util.jit import compile_mode - -from mace.tools.cg import U_matrix_real - -BATCH_EXAMPLE = 10 -ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] - - -@compile_mode("script") -class SymmetricContraction(CodeGenMixin, torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - correlation: Union[int, Dict[str, int]], - irrep_normalization: str = "component", - path_normalization: str = "element", - internal_weights: Optional[bool] = None, - shared_weights: Optional[bool] = None, - num_elements: Optional[int] = None, - ) -> None: - super().__init__() - - if irrep_normalization is None: - irrep_normalization = "component" - - if path_normalization is None: - path_normalization = "element" - - assert irrep_normalization in ["component", "norm", "none"] - assert path_normalization in ["element", "path", "none"] - - self.irreps_in = o3.Irreps(irreps_in) - self.irreps_out = o3.Irreps(irreps_out) - - del irreps_in, irreps_out - - if not isinstance(correlation, tuple): - corr = correlation - correlation = {} - for irrep_out in self.irreps_out: - correlation[irrep_out] = corr - - assert shared_weights or not internal_weights - - if internal_weights is None: - internal_weights = True - - self.internal_weights = internal_weights - self.shared_weights = shared_weights - - del internal_weights, shared_weights - - self.contractions = torch.nn.ModuleList() - for irrep_out in self.irreps_out: - self.contractions.append( - Contraction( - irreps_in=self.irreps_in, - irrep_out=o3.Irreps(str(irrep_out.ir)), - correlation=correlation[irrep_out], - internal_weights=self.internal_weights, - num_elements=num_elements, - 
weights=self.shared_weights, - ) - ) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - outs = [contraction(x, y) for contraction in self.contractions] - return torch.cat(outs, dim=-1) - - -@compile_mode("script") -class Contraction(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irrep_out: o3.Irreps, - correlation: int, - internal_weights: bool = True, - num_elements: Optional[int] = None, - weights: Optional[torch.Tensor] = None, - ) -> None: - super().__init__() - - self.num_features = irreps_in.count((0, 1)) - self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) - self.correlation = correlation - dtype = torch.get_default_dtype() - for nu in range(1, correlation + 1): - U_matrix = U_matrix_real( - irreps_in=self.coupling_irreps, - irreps_out=irrep_out, - correlation=nu, - dtype=dtype, - )[-1] - self.register_buffer(f"U_matrix_{nu}", U_matrix) - - # Tensor contraction equations - self.contractions_weighting = torch.nn.ModuleList() - self.contractions_features = torch.nn.ModuleList() - - # Create weight for product basis - self.weights = torch.nn.ParameterList([]) - - for i in range(correlation, 0, -1): - # Shapes definying - num_params = self.U_tensors(i).size()[-1] - num_equivariance = 2 * irrep_out.lmax + 1 - num_ell = self.U_tensors(i).size()[-2] - - if i == correlation: - parse_subscript_main = ( - [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] - + ["ik,ekc,bci,be -> bc"] - + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] - ) - graph_module_main = torch.fx.symbolic_trace( - lambda x, y, w, z: torch.einsum( - "".join(parse_subscript_main), x, y, w, z - ) - ) - - # Optimizing the contractions - self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( - model=graph_module_main, - example_inputs=( - torch.randn( - [num_equivariance] + [num_ell] * i + [num_params] - ).squeeze(0), - torch.randn((num_elements, num_params, self.num_features)), - torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), - torch.randn((BATCH_EXAMPLE, num_elements)), - ), - ) - # Parameters for the product basis - w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) - / num_params - ) - self.weights_max = w - else: - # Generate optimized contractions equations - parse_subscript_weighting = ( - [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] - + ["k,ekc,be->bc"] - + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] - ) - parse_subscript_features = ( - ["bc"] - + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] - + ["i,bci->bc"] - + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] - ) - - # Symbolic tracing of contractions - graph_module_weighting = torch.fx.symbolic_trace( - lambda x, y, z: torch.einsum( - "".join(parse_subscript_weighting), x, y, z - ) - ) - graph_module_features = torch.fx.symbolic_trace( - lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) - ) - - # Optimizing the contractions - graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( - model=graph_module_weighting, - example_inputs=( - torch.randn( - [num_equivariance] + [num_ell] * i + [num_params] - ).squeeze(0), - torch.randn((num_elements, num_params, self.num_features)), - torch.randn((BATCH_EXAMPLE, num_elements)), - ), - ) - graph_opt_features = opt_einsum_fx.optimize_einsums_full( - model=graph_module_features, - example_inputs=( - torch.randn( - [BATCH_EXAMPLE, self.num_features, num_equivariance] - + [num_ell] * i - ).squeeze(2), - torch.randn((BATCH_EXAMPLE, 
self.num_features, num_ell)), - ), - ) - self.contractions_weighting.append(graph_opt_weighting) - self.contractions_features.append(graph_opt_features) - # Parameters for the product basis - w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) - / num_params - ) - self.weights.append(w) - if not internal_weights: - self.weights = weights[:-1] - self.weights_max = weights[-1] - - def forward(self, x: torch.Tensor, y: torch.Tensor): - out = self.graph_opt_main( - self.U_tensors(self.correlation), - self.weights_max, - x, - y, - ) - for i, (weight, contract_weights, contract_features) in enumerate( - zip(self.weights, self.contractions_weighting, self.contractions_features) - ): - c_tensor = contract_weights( - self.U_tensors(self.correlation - i - 1), - weight, - y, - ) - c_tensor = c_tensor + out - out = contract_features(c_tensor, x) - - return out.view(out.shape[0], -1) - - def U_tensors(self, nu: int): - return dict(self.named_buffers())[f"U_matrix_{nu}"] +########################################################################################### +# Implementation of the symmetric contraction algorithm presented in the MACE paper +# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Dict, Optional, Union + +import opt_einsum_fx +import torch +import torch.fx +from e3nn import o3 +from e3nn.util.codegen import CodeGenMixin +from e3nn.util.jit import compile_mode + +from mace.tools.cg import U_matrix_real + +BATCH_EXAMPLE = 10 +ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] + + +@compile_mode("script") +class SymmetricContraction(CodeGenMixin, torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: Union[int, Dict[str, int]], + irrep_normalization: str = "component", + path_normalization: str = "element", + internal_weights: Optional[bool] = None, + shared_weights: Optional[bool] = None, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + if irrep_normalization is None: + irrep_normalization = "component" + + if path_normalization is None: + path_normalization = "element" + + assert irrep_normalization in ["component", "norm", "none"] + assert path_normalization in ["element", "path", "none"] + + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + del irreps_in, irreps_out + + if not isinstance(correlation, tuple): + corr = correlation + correlation = {} + for irrep_out in self.irreps_out: + correlation[irrep_out] = corr + + assert shared_weights or not internal_weights + + if internal_weights is None: + internal_weights = True + + self.internal_weights = internal_weights + self.shared_weights = shared_weights + + del internal_weights, shared_weights + + self.contractions = torch.nn.ModuleList() + for irrep_out in self.irreps_out: + self.contractions.append( + Contraction( + irreps_in=self.irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation[irrep_out], + internal_weights=self.internal_weights, + num_elements=num_elements, + weights=self.shared_weights, + ) + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + outs = [contraction(x, y) for contraction in self.contractions] + return torch.cat(outs, dim=-1) + + 
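+# Minimal usage sketch (illustrative only; the irreps, correlation order, and
+# tensor shapes below are assumptions, chosen to mirror the example inputs used
+# for the einsum optimization in Contraction):
+#
+#   sc = SymmetricContraction(
+#       irreps_in=o3.Irreps("32x0e + 32x1o"),
+#       irreps_out=o3.Irreps("32x0e"),
+#       correlation=3,
+#       num_elements=2,
+#   )
+#   x = torch.randn(10, 32, 4)  # [n_nodes, num_features, num_ell]
+#   y = torch.nn.functional.one_hot(torch.zeros(10, dtype=torch.long), 2).to(x.dtype)
+#   out = sc(x, y)  # [n_nodes, num_features] for scalar (0e) output irreps
+
+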
+@compile_mode("script")
+class Contraction(torch.nn.Module):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        irrep_out: o3.Irreps,
+        correlation: int,
+        internal_weights: bool = True,
+        num_elements: Optional[int] = None,
+        weights: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__()
+
+        self.num_features = irreps_in.count((0, 1))
+        self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
+        self.correlation = correlation
+        dtype = torch.get_default_dtype()
+        for nu in range(1, correlation + 1):
+            U_matrix = U_matrix_real(
+                irreps_in=self.coupling_irreps,
+                irreps_out=irrep_out,
+                correlation=nu,
+                dtype=dtype,
+            )[-1]
+            self.register_buffer(f"U_matrix_{nu}", U_matrix)
+
+        # Tensor contraction equations
+        self.contractions_weighting = torch.nn.ModuleList()
+        self.contractions_features = torch.nn.ModuleList()
+
+        # Create weight for product basis
+        self.weights = torch.nn.ParameterList([])
+
+        for i in range(correlation, 0, -1):
+            # Shape definitions
+            num_params = self.U_tensors(i).size()[-1]
+            num_equivariance = 2 * irrep_out.lmax + 1
+            num_ell = self.U_tensors(i).size()[-2]
+
+            if i == correlation:
+                parse_subscript_main = (
+                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+                    + ["ik,ekc,bci,be -> bc"]
+                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+                )
+                graph_module_main = torch.fx.symbolic_trace(
+                    lambda x, y, w, z: torch.einsum(
+                        "".join(parse_subscript_main), x, y, w, z
+                    )
+                )
+
+                # Optimizing the contractions
+                self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
+                    model=graph_module_main,
+                    example_inputs=(
+                        torch.randn(
+                            [num_equivariance] + [num_ell] * i + [num_params]
+                        ).squeeze(0),
+                        torch.randn((num_elements, num_params, self.num_features)),
+                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
+                        torch.randn((BATCH_EXAMPLE, num_elements)),
+                    ),
+                )
+                # Parameters for the product basis
+                w = torch.nn.Parameter(
+                    torch.randn((num_elements, num_params, self.num_features))
+                    / num_params
+                )
+                self.weights_max = w
+            else:
+                # Generate optimized contraction equations
+                parse_subscript_weighting = (
+                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+                    + ["k,ekc,be->bc"]
+                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+                )
+                parse_subscript_features = (
+                    ["bc"]
+                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
+                    + ["i,bci->bc"]
+                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
+                )
+
+                # Symbolic tracing of contractions
+                graph_module_weighting = torch.fx.symbolic_trace(
+                    lambda x, y, z: torch.einsum(
+                        "".join(parse_subscript_weighting), x, y, z
+                    )
+                )
+                graph_module_features = torch.fx.symbolic_trace(
+                    lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
+                )
+
+                # Optimizing the contractions
+                graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
+                    model=graph_module_weighting,
+                    example_inputs=(
+                        torch.randn(
+                            [num_equivariance] + [num_ell] * i + [num_params]
+                        ).squeeze(0),
+                        torch.randn((num_elements, num_params, self.num_features)),
+                        torch.randn((BATCH_EXAMPLE, num_elements)),
+                    ),
+                )
+                graph_opt_features = opt_einsum_fx.optimize_einsums_full(
+                    model=graph_module_features,
+                    example_inputs=(
+                        torch.randn(
+                            [BATCH_EXAMPLE, self.num_features, num_equivariance]
+                            + [num_ell] * i
+                        ).squeeze(2),
+                        torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
+                    ),
+                )
+                self.contractions_weighting.append(graph_opt_weighting)
+                self.contractions_features.append(graph_opt_features)
+                # Parameters for the product basis
+                w = 
torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights.append(w) + if not internal_weights: + self.weights = weights[:-1] + self.weights_max = weights[-1] + + def forward(self, x: torch.Tensor, y: torch.Tensor): + out = self.graph_opt_main( + self.U_tensors(self.correlation), + self.weights_max, + x, + y, + ) + for i, (weight, contract_weights, contract_features) in enumerate( + zip(self.weights, self.contractions_weighting, self.contractions_features) + ): + c_tensor = contract_weights( + self.U_tensors(self.correlation - i - 1), + weight, + y, + ) + c_tensor = c_tensor + out + out = contract_features(c_tensor, x) + + return out.view(out.shape[0], -1) + + def U_tensors(self, nu: int): + return dict(self.named_buffers())[f"U_matrix_{nu}"] diff --git a/mace-bench/3rdparty/mace/mace/modules/utils.py b/mace-bench/3rdparty/mace/mace/modules/utils.py index 59da11882aabc7fb51927cb8e2f47afd9626f568..6a5a8e04fbde9180f442d39221f8cf63a8c398b6 100644 --- a/mace-bench/3rdparty/mace/mace/modules/utils.py +++ b/mace-bench/3rdparty/mace/mace/modules/utils.py @@ -1,582 +1,582 @@ -########################################################################################### -# Utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from typing import Dict, List, NamedTuple, Optional, Tuple - -import numpy as np -import torch -import torch.utils.data -from scipy.constants import c, e - -from mace.tools import to_numpy -from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum -from mace.tools.torch_geometric.batch import Batch - -from .blocks import AtomicEnergiesBlock - - -def compute_forces( - energy: torch.Tensor, positions: torch.Tensor, training: bool = True -) -> torch.Tensor: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - gradient = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, # For complete dissociation turn to true - )[ - 0 - ] # [n_nodes, 3] - if gradient is None: - return torch.zeros_like(positions) - return -1 * gradient - - -def compute_forces_virials( - energy: torch.Tensor, - positions: torch.Tensor, - displacement: torch.Tensor, - cell: torch.Tensor, - training: bool = True, - compute_stress: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - forces, virials = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions, displacement], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, - ) - stress = torch.zeros_like(displacement) - if compute_stress and virials is not None: - cell = cell.view(-1, 3, 3) - volume = torch.linalg.det(cell).abs().unsqueeze(-1) - stress = virials / volume.view(-1, 1, 1) - stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) - if forces is None: - forces = torch.zeros_like(positions) - if virials is None: - virials = torch.zeros((1, 3, 3)) - - 
return -1 * forces, -1 * virials, stress - - -def get_symmetric_displacement( - positions: torch.Tensor, - unit_shifts: torch.Tensor, - cell: Optional[torch.Tensor], - edge_index: torch.Tensor, - num_graphs: int, - batch: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if cell is None: - cell = torch.zeros( - num_graphs * 3, - 3, - dtype=positions.dtype, - device=positions.device, - ) - sender = edge_index[0] - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=positions.dtype, - device=positions.device, - ) - displacement.requires_grad_(True) - symmetric_displacement = 0.5 * ( - displacement + displacement.transpose(-1, -2) - ) # From https://github.com/mir-group/nequip - positions = positions + torch.einsum( - "be,bec->bc", positions, symmetric_displacement[batch] - ) - cell = cell.view(-1, 3, 3) - cell = cell + torch.matmul(cell, symmetric_displacement) - shifts = torch.einsum( - "be,bec->bc", - unit_shifts, - cell[batch[sender]], - ) - return positions, shifts, displacement - - -@torch.jit.unused -def compute_hessians_vmap( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - forces_flatten = forces.view(-1) - num_elements = forces_flatten.shape[0] - - def get_vjp(v): - return torch.autograd.grad( - -1 * forces_flatten, - positions, - v, - retain_graph=True, - create_graph=False, - allow_unused=False, - ) - - I_N = torch.eye(num_elements).to(forces.device) - try: - chunk_size = 1 if num_elements < 64 else 16 - gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( - I_N - )[0] - except RuntimeError: - gradient = compute_hessians_loop(forces, positions) - if gradient is None: - return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) - return gradient - - -@torch.jit.unused -def compute_hessians_loop( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - hessian = [] - for grad_elem in forces.view(-1): - hess_row = torch.autograd.grad( - outputs=[-1 * grad_elem], - inputs=[positions], - grad_outputs=torch.ones_like(grad_elem), - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - hess_row = hess_row.detach() # this makes it very slow? 
but needs less memory - if hess_row is None: - hessian.append(torch.zeros_like(positions)) - else: - hessian.append(hess_row) - hessian = torch.stack(hessian) - return hessian - - -def get_outputs( - energy: torch.Tensor, - positions: torch.Tensor, - cell: torch.Tensor, - displacement: Optional[torch.Tensor], - vectors: Optional[torch.Tensor] = None, - training: bool = False, - compute_force: bool = True, - compute_virials: bool = True, - compute_stress: bool = True, - compute_hessian: bool = False, - compute_edge_forces: bool = False, -) -> Tuple[ - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: - if (compute_virials or compute_stress) and displacement is not None: - forces, virials, stress = compute_forces_virials( - energy=energy, - positions=positions, - displacement=displacement, - cell=cell, - compute_stress=compute_stress, - training=(training or compute_hessian or compute_edge_forces), - ) - elif compute_force: - forces, virials, stress = ( - compute_forces( - energy=energy, - positions=positions, - training=(training or compute_hessian or compute_edge_forces), - ), - None, - None, - ) - else: - forces, virials, stress = (None, None, None) - if compute_hessian: - assert forces is not None, "Forces must be computed to get the hessian" - hessian = compute_hessians_vmap(forces, positions) - else: - hessian = None - if compute_edge_forces and vectors is not None: - edge_forces = compute_forces( - energy=energy, - positions=vectors, - training=(training or compute_hessian), - ) - if edge_forces is not None: - edge_forces = -1 * edge_forces # Match LAMMPS sign convention - else: - edge_forces = None - return forces, virials, stress, hessian, edge_forces - - -def get_atomic_virials_stresses( - edge_forces: torch.Tensor, # [n_edges, 3] - edge_index: torch.Tensor, # [2, n_edges] - vectors: torch.Tensor, # [n_edges, 3] - num_atoms: int, - batch: torch.Tensor, - cell: torch.Tensor, # [n_graphs, 3, 3] -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Compute atomic virials and optionally atomic stresses from edge forces and vectors. - From pobo95 PR #528. 
- Returns: - Tuple of: - - Atomic virials [num_atoms, 3, 3] - - Atomic stresses [num_atoms, 3, 3] (None if not computed) - """ - edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors) - atom_virial_sender = scatter_sum( - src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms - ) - atom_virial_receiver = scatter_sum( - src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms - ) - atom_virial = (atom_virial_sender + atom_virial_receiver) / 2 - atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2 - atom_stress = None - cell = cell.view(-1, 3, 3) - volume = torch.linalg.det(cell).abs().unsqueeze(-1) - atom_volume = volume[batch].view(-1, 1, 1) - atom_stress = atom_virial / atom_volume - atom_stress = torch.where( - torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress) - ) - return -1 * atom_virial, atom_stress - - -def get_edge_vectors_and_lengths( - positions: torch.Tensor, # [n_nodes, 3] - edge_index: torch.Tensor, # [2, n_edges] - shifts: torch.Tensor, # [n_edges, 3] - normalize: bool = False, - eps: float = 1e-9, -) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] - lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] - if normalize: - vectors_normed = vectors / (lengths + eps) - return vectors_normed, lengths - - return vectors, lengths - - -def _check_non_zero(std): - if np.any(std == 0): - logging.warning( - "Standard deviation of the scaling is zero, Changing to no scaling" - ) - std[std == 0] = 1 - return std - - -def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): - out = [] - out.append(x[:, :num_features]) - for i in range(1, num_layers): - out.append( - x[ - :, - i - * (l_max + 1) ** 2 - * num_features : (i * (l_max + 1) ** 2 + 1) - * num_features, - ] - ) - return torch.cat(out, dim=-1) - - -def compute_mean_std_atomic_inter_energy( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - avg_atom_inter_es_list = [] - head_list = [] - - for batch in data_loader: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), batch.head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - avg_atom_inter_es_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - head_list.append(batch.head) - - avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - head = torch.cat(head_list, dim=0) # [total_n_graphs] - # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() - # std = to_numpy(torch.std(avg_atom_inter_es)).item() - mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) - std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) - std = _check_non_zero(std) - - return mean, std - - -def _compute_mean_std_atomic_inter_energy( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes - 
return atom_energies - - -def compute_mean_rms_energy_forces( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - head_list = [] - head_batch = [] - - for batch in data_loader: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - head_list.append(head) - head_batch.append(head[batch.batch]) - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - head = torch.cat(head_list, dim=0) # [total_n_graphs] - head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] - - # mean = to_numpy(torch.mean(atom_energies)).item() - # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) - rms = to_numpy( - torch.sqrt( - scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) - ) - ) - rms = _check_non_zero(rms) - - return mean, rms - - -def _compute_mean_rms_energy_forces( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } - forces = batch.forces # {[n_graphs*n_atoms,3], } - - return atom_energies, forces - - -def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: - num_neighbors = [] - for batch in data_loader: - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - return to_numpy(avg_num_neighbors).item() - - -def compute_statistics( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float, float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - num_neighbors = [] - head_list = [] - head_batch = [] - - for batch in data_loader: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - head_list.append(head) # {[n_graphs], } - head_batch.append(head[batch.batch]) - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - head = 
torch.cat(head_list, dim=0) # [total_n_graphs] - head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] - - # mean = to_numpy(torch.mean(atom_energies)).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) - rms = to_numpy( - torch.sqrt( - scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) - ) - ) - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - - return to_numpy(avg_num_neighbors).item(), mean, rms - - -def compute_rms_dipoles( - data_loader: torch.utils.data.DataLoader, -) -> Tuple[float, float]: - dipoles_list = [] - for batch in data_loader: - dipoles_list.append(batch.dipole) # {[n_graphs,3], } - - dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } - rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() - rms = _check_non_zero(rms) - return rms - - -def compute_fixed_charge_dipole( - charges: torch.Tensor, - positions: torch.Tensor, - batch: torch.Tensor, - num_graphs: int, -) -> torch.Tensor: - mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] - return scatter_sum( - src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs - ) # [N_graphs,3] - - -class InteractionKwargs(NamedTuple): - lammps_class: Optional[torch.Tensor] - lammps_natoms: Tuple[int, int] = (0, 0) - - -class GraphContext(NamedTuple): - is_lammps: bool - num_graphs: int - num_atoms_arange: torch.Tensor - displacement: Optional[torch.Tensor] - positions: torch.Tensor - vectors: torch.Tensor - lengths: torch.Tensor - cell: torch.Tensor - node_heads: torch.Tensor - interaction_kwargs: InteractionKwargs - - -def prepare_graph( - data: Dict[str, torch.Tensor], - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - lammps_mliap: bool = False, -) -> GraphContext: - if torch.jit.is_scripting(): - lammps_mliap = False - - node_heads = ( - data["head"][data["batch"]] - if "head" in data - else torch.zeros_like(data["batch"]) - ) - - if lammps_mliap: - n_real, n_total = data["natoms"][0], data["natoms"][1] - num_graphs = 2 - num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device) - displacement = None - positions = torch.zeros( - (int(n_real), 3), - dtype=data["vectors"].dtype, - device=data["vectors"].device, - ) - cell = torch.zeros( - (num_graphs, 3, 3), - dtype=data["vectors"].dtype, - device=data["vectors"].device, - ) - vectors = data["vectors"].requires_grad_(True) - lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True) - ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total)) - else: - data["positions"].requires_grad_(True) - positions = data["positions"] - cell = data["cell"] - num_atoms_arange = torch.arange(positions.shape[0], device=positions.device) - num_graphs = int(data["ptr"].numel() - 1) - displacement = torch.zeros( - (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device - ) - if compute_virials or compute_stress or compute_displacement: - p, s, displacement = get_symmetric_displacement( - positions=positions, - unit_shifts=data["unit_shifts"], - cell=cell, - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - data["positions"], data["shifts"] = p, s - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - ikw = InteractionKwargs(None, (0, 0)) - - return GraphContext( - is_lammps=lammps_mliap, - num_graphs=num_graphs, 
- num_atoms_arange=num_atoms_arange, - displacement=displacement, - positions=positions, - vectors=vectors, - lengths=lengths, - cell=cell, - node_heads=node_heads, - interaction_kwargs=ikw, - ) +########################################################################################### +# Utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from typing import Dict, List, NamedTuple, Optional, Tuple + +import numpy as np +import torch +import torch.utils.data +from scipy.constants import c, e + +from mace.tools import to_numpy +from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum +from mace.tools.torch_geometric.batch import Batch + +from .blocks import AtomicEnergiesBlock + + +def compute_forces( + energy: torch.Tensor, positions: torch.Tensor, training: bool = True +) -> torch.Tensor: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + gradient = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, # For complete dissociation turn to true + )[ + 0 + ] # [n_nodes, 3] + if gradient is None: + return torch.zeros_like(positions) + return -1 * gradient + + +def compute_forces_virials( + energy: torch.Tensor, + positions: torch.Tensor, + displacement: torch.Tensor, + cell: torch.Tensor, + training: bool = True, + compute_stress: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + forces, virials = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions, displacement], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, + ) + stress = torch.zeros_like(displacement) + if compute_stress and virials is not None: + cell = cell.view(-1, 3, 3) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + stress = virials / volume.view(-1, 1, 1) + stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) + if forces is None: + forces = torch.zeros_like(positions) + if virials is None: + virials = torch.zeros((1, 3, 3)) + + return -1 * forces, -1 * virials, stress + + +def get_symmetric_displacement( + positions: torch.Tensor, + unit_shifts: torch.Tensor, + cell: Optional[torch.Tensor], + edge_index: torch.Tensor, + num_graphs: int, + batch: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if cell is None: + cell = torch.zeros( + num_graphs * 3, + 3, + dtype=positions.dtype, + device=positions.device, + ) + sender = edge_index[0] + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=positions.dtype, + device=positions.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) # From https://github.com/mir-group/nequip + positions = positions + torch.einsum( + "be,bec->bc", positions, symmetric_displacement[batch] + ) + cell = cell.view(-1, 3, 3) + cell = cell + torch.matmul(cell, symmetric_displacement) + shifts = torch.einsum( + "be,bec->bc", + 
unit_shifts, + cell[batch[sender]], + ) + return positions, shifts, displacement + + +@torch.jit.unused +def compute_hessians_vmap( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + forces_flatten = forces.view(-1) + num_elements = forces_flatten.shape[0] + + def get_vjp(v): + return torch.autograd.grad( + -1 * forces_flatten, + positions, + v, + retain_graph=True, + create_graph=False, + allow_unused=False, + ) + + I_N = torch.eye(num_elements).to(forces.device) + try: + chunk_size = 1 if num_elements < 64 else 16 + gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( + I_N + )[0] + except RuntimeError: + gradient = compute_hessians_loop(forces, positions) + if gradient is None: + return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) + return gradient + + +@torch.jit.unused +def compute_hessians_loop( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + hessian = [] + for grad_elem in forces.view(-1): + hess_row = torch.autograd.grad( + outputs=[-1 * grad_elem], + inputs=[positions], + grad_outputs=torch.ones_like(grad_elem), + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + hess_row = hess_row.detach() # this makes it very slow? but needs less memory + if hess_row is None: + hessian.append(torch.zeros_like(positions)) + else: + hessian.append(hess_row) + hessian = torch.stack(hessian) + return hessian + + +def get_outputs( + energy: torch.Tensor, + positions: torch.Tensor, + cell: torch.Tensor, + displacement: Optional[torch.Tensor], + vectors: Optional[torch.Tensor] = None, + training: bool = False, + compute_force: bool = True, + compute_virials: bool = True, + compute_stress: bool = True, + compute_hessian: bool = False, + compute_edge_forces: bool = False, +) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + if (compute_virials or compute_stress) and displacement is not None: + forces, virials, stress = compute_forces_virials( + energy=energy, + positions=positions, + displacement=displacement, + cell=cell, + compute_stress=compute_stress, + training=(training or compute_hessian or compute_edge_forces), + ) + elif compute_force: + forces, virials, stress = ( + compute_forces( + energy=energy, + positions=positions, + training=(training or compute_hessian or compute_edge_forces), + ), + None, + None, + ) + else: + forces, virials, stress = (None, None, None) + if compute_hessian: + assert forces is not None, "Forces must be computed to get the hessian" + hessian = compute_hessians_vmap(forces, positions) + else: + hessian = None + if compute_edge_forces and vectors is not None: + edge_forces = compute_forces( + energy=energy, + positions=vectors, + training=(training or compute_hessian), + ) + if edge_forces is not None: + edge_forces = -1 * edge_forces # Match LAMMPS sign convention + else: + edge_forces = None + return forces, virials, stress, hessian, edge_forces + + +def get_atomic_virials_stresses( + edge_forces: torch.Tensor, # [n_edges, 3] + edge_index: torch.Tensor, # [2, n_edges] + vectors: torch.Tensor, # [n_edges, 3] + num_atoms: int, + batch: torch.Tensor, + cell: torch.Tensor, # [n_graphs, 3, 3] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute atomic virials and optionally atomic stresses from edge forces and vectors. + From pobo95 PR #528. 
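+    Each edge contributes the outer product of its edge force and edge vector;
+    the contribution is split evenly between sender and receiver atoms and
+    symmetrized before division by the per-atom cell volume.
+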
+ Returns: + Tuple of: + - Atomic virials [num_atoms, 3, 3] + - Atomic stresses [num_atoms, 3, 3] (None if not computed) + """ + edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors) + atom_virial_sender = scatter_sum( + src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms + ) + atom_virial_receiver = scatter_sum( + src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms + ) + atom_virial = (atom_virial_sender + atom_virial_receiver) / 2 + atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2 + atom_stress = None + cell = cell.view(-1, 3, 3) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + atom_volume = volume[batch].view(-1, 1, 1) + atom_stress = atom_virial / atom_volume + atom_stress = torch.where( + torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress) + ) + return -1 * atom_virial, atom_stress + + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths + + +def _check_non_zero(std): + if np.any(std == 0): + logging.warning( + "Standard deviation of the scaling is zero, Changing to no scaling" + ) + std[std == 0] = 1 + return std + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + out.append(x[:, :num_features]) + for i in range(1, num_layers): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + return torch.cat(out, dim=-1) + + +def compute_mean_std_atomic_inter_energy( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + avg_atom_inter_es_list = [] + head_list = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), batch.head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + avg_atom_inter_es_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + head_list.append(batch.head) + + avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] + # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + # std = to_numpy(torch.std(avg_atom_inter_es)).item() + mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = _check_non_zero(std) + + return mean, std + + +def _compute_mean_std_atomic_inter_energy( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes + 
return atom_energies + + +def compute_mean_rms_energy_forces( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) + head_batch.append(head[batch.batch]) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + rms = _check_non_zero(rms) + + return mean, rms + + +def _compute_mean_rms_energy_forces( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } + forces = batch.forces # {[n_graphs*n_atoms,3], } + + return atom_energies, forces + + +def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: + num_neighbors = [] + for batch in data_loader: + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + return to_numpy(avg_num_neighbors).item() + + +def compute_statistics( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float, float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + num_neighbors = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) # {[n_graphs], } + head_batch.append(head[batch.batch]) + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = 
torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + + return to_numpy(avg_num_neighbors).item(), mean, rms + + +def compute_rms_dipoles( + data_loader: torch.utils.data.DataLoader, +) -> Tuple[float, float]: + dipoles_list = [] + for batch in data_loader: + dipoles_list.append(batch.dipole) # {[n_graphs,3], } + + dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } + rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() + rms = _check_non_zero(rms) + return rms + + +def compute_fixed_charge_dipole( + charges: torch.Tensor, + positions: torch.Tensor, + batch: torch.Tensor, + num_graphs: int, +) -> torch.Tensor: + mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] + return scatter_sum( + src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs + ) # [N_graphs,3] + + +class InteractionKwargs(NamedTuple): + lammps_class: Optional[torch.Tensor] + lammps_natoms: Tuple[int, int] = (0, 0) + + +class GraphContext(NamedTuple): + is_lammps: bool + num_graphs: int + num_atoms_arange: torch.Tensor + displacement: Optional[torch.Tensor] + positions: torch.Tensor + vectors: torch.Tensor + lengths: torch.Tensor + cell: torch.Tensor + node_heads: torch.Tensor + interaction_kwargs: InteractionKwargs + + +def prepare_graph( + data: Dict[str, torch.Tensor], + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + lammps_mliap: bool = False, +) -> GraphContext: + if torch.jit.is_scripting(): + lammps_mliap = False + + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) + + if lammps_mliap: + n_real, n_total = data["natoms"][0], data["natoms"][1] + num_graphs = 2 + num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device) + displacement = None + positions = torch.zeros( + (int(n_real), 3), + dtype=data["vectors"].dtype, + device=data["vectors"].device, + ) + cell = torch.zeros( + (num_graphs, 3, 3), + dtype=data["vectors"].dtype, + device=data["vectors"].device, + ) + vectors = data["vectors"].requires_grad_(True) + lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True) + ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total)) + else: + data["positions"].requires_grad_(True) + positions = data["positions"] + cell = data["cell"] + num_atoms_arange = torch.arange(positions.shape[0], device=positions.device) + num_graphs = int(data["ptr"].numel() - 1) + displacement = torch.zeros( + (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device + ) + if compute_virials or compute_stress or compute_displacement: + p, s, displacement = get_symmetric_displacement( + positions=positions, + unit_shifts=data["unit_shifts"], + cell=cell, + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + data["positions"], data["shifts"] = p, s + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + ikw = InteractionKwargs(None, (0, 0)) + + return GraphContext( + is_lammps=lammps_mliap, + num_graphs=num_graphs, 
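+ # displacement is None on the LAMMPS path and a [num_graphs, 3, 3]
+ # tensor otherwise; interaction_kwargs carries the LAMMPS class and
+ # atom counts only when lammps_mliap is set.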
+ num_atoms_arange=num_atoms_arange, + displacement=displacement, + positions=positions, + vectors=vectors, + lengths=lengths, + cell=cell, + node_heads=node_heads, + interaction_kwargs=ikw, + ) diff --git a/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py b/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py index ee03ef73e8997cc6fb320a0c83923e2a36d3780d..ca05219c775c439b37193a998b48ba490dbd1baf 100644 --- a/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py +++ b/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py @@ -1,192 +1,192 @@ -""" -Wrapper class for o3.Linear that optionally uses cuet.Linear -""" - -import dataclasses -from typing import List, Optional - -import torch -from e3nn import o3 - -from mace.modules.symmetric_contraction import SymmetricContraction -from mace.tools.cg import O3_e3nn - -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - - -@dataclasses.dataclass -class CuEquivarianceConfig: - """Configuration for cuequivariance acceleration""" - - enabled: bool = False - layout: str = "mul_ir" # One of: mul_ir, ir_mul - layout_str: str = "mul_ir" - group: str = "O3" - optimize_all: bool = False # Set to True to enable all optimizations - optimize_linear: bool = False - optimize_channelwise: bool = False - optimize_symmetric: bool = False - optimize_fctp: bool = False - - def __post_init__(self): - if self.enabled and CUET_AVAILABLE: - self.layout_str = self.layout - self.layout = getattr(cue, self.layout) - self.group = ( - O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) - ) - if not CUET_AVAILABLE: - self.enabled = False - - -class Linear: - """Returns either a cuet.Linear or o3.Linear based on config""" - - def __new__( - cls, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - shared_weights: bool = True, - internal_weights: bool = True, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_linear) - ): - return cuet.Linear( - cue.Irreps(cueq_config.group, irreps_in), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - use_fallback=True, - ) - - return o3.Linear( - irreps_in, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class TensorProduct: - """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" - - def __new__( - cls, - irreps_in1: o3.Irreps, - irreps_in2: o3.Irreps, - irreps_out: o3.Irreps, - instructions: Optional[List] = None, - shared_weights: bool = False, - internal_weights: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_channelwise) - ): - return cuet.ChannelWiseTensorProduct( - cue.Irreps(cueq_config.group, irreps_in1), - cue.Irreps(cueq_config.group, irreps_in2), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - internal_weights=internal_weights, - dtype=torch.get_default_dtype(), - math_dtype=torch.get_default_dtype(), - ) - - return o3.TensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - instructions=instructions, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class FullyConnectedTensorProduct: - """Wrapper around 
o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" - - def __new__( - cls, - irreps_in1: o3.Irreps, - irreps_in2: o3.Irreps, - irreps_out: o3.Irreps, - shared_weights: bool = True, - internal_weights: bool = True, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_fctp) - ): - return cuet.FullyConnectedTensorProduct( - cue.Irreps(cueq_config.group, irreps_in1), - cue.Irreps(cueq_config.group, irreps_in2), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - internal_weights=internal_weights, - use_fallback=True, - ) - - return o3.FullyConnectedTensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class SymmetricContractionWrapper: - """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" - - def __new__( - cls, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - correlation: int, - num_elements: Optional[int] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_symmetric) - ): - return cuet.SymmetricContraction( - cue.Irreps(cueq_config.group, irreps_in), - cue.Irreps(cueq_config.group, irreps_out), - layout_in=cue.ir_mul, - layout_out=cueq_config.layout, - contraction_degree=correlation, - num_elements=num_elements, - original_mace=True, - dtype=torch.get_default_dtype(), - math_dtype=torch.get_default_dtype(), - ) - - return SymmetricContraction( - irreps_in=irreps_in, - irreps_out=irreps_out, - correlation=correlation, - num_elements=num_elements, - ) +""" +Wrapper class for o3.Linear that optionally uses cuet.Linear +""" + +import dataclasses +from typing import List, Optional + +import torch +from e3nn import o3 + +from mace.modules.symmetric_contraction import SymmetricContraction +from mace.tools.cg import O3_e3nn + +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + + +@dataclasses.dataclass +class CuEquivarianceConfig: + """Configuration for cuequivariance acceleration""" + + enabled: bool = False + layout: str = "mul_ir" # One of: mul_ir, ir_mul + layout_str: str = "mul_ir" + group: str = "O3" + optimize_all: bool = False # Set to True to enable all optimizations + optimize_linear: bool = False + optimize_channelwise: bool = False + optimize_symmetric: bool = False + optimize_fctp: bool = False + + def __post_init__(self): + if self.enabled and CUET_AVAILABLE: + self.layout_str = self.layout + self.layout = getattr(cue, self.layout) + self.group = ( + O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) + ) + if not CUET_AVAILABLE: + self.enabled = False + + +class Linear: + """Returns either a cuet.Linear or o3.Linear based on config""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear) + ): + return cuet.Linear( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + 
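# Only shared_weights is forwarded on this branch; internal_weights is
+ # left to cuet.Linear's default, unlike the o3.Linear fallback below,
+ # which receives both flags explicitly.
+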
shared_weights=shared_weights, + use_fallback=True, + ) + + return o3.Linear( + irreps_in, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class TensorProduct: + """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + instructions: Optional[List] = None, + shared_weights: bool = False, + internal_weights: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_channelwise) + ): + return cuet.ChannelWiseTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), + ) + + return o3.TensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + instructions=instructions, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class FullyConnectedTensorProduct: + """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_fctp) + ): + return cuet.FullyConnectedTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + use_fallback=True, + ) + + return o3.FullyConnectedTensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class SymmetricContractionWrapper: + """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_symmetric) + ): + return cuet.SymmetricContraction( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout_in=cue.ir_mul, + layout_out=cueq_config.layout, + contraction_degree=correlation, + num_elements=num_elements, + original_mace=True, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), + ) + + return SymmetricContraction( + irreps_in=irreps_in, + irreps_out=irreps_out, + correlation=correlation, + num_elements=num_elements, + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/__init__.py b/mace-bench/3rdparty/mace/mace/tools/__init__.py index 5dda7f32fcae41d59e2df62d28f9139dde8b6136..0fa6b0765befce9fab668c2483f7db946a2bfa9d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/__init__.py @@ -1,73 +1,73 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser -from .arg_parser_tools 
import check_args -from .cg import U_matrix_real -from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .default_keys import DefaultKeys -from .finetuning_utils import load_foundations, load_foundations_elements -from .torch_tools import ( - TensorDict, - cartesian_to_spherical, - count_parameters, - init_device, - init_wandb, - set_default_dtype, - set_seeds, - spherical_to_cartesian, - to_numpy, - to_one_hot, - voigt_to_matrix, -) -from .train import SWAContainer, evaluate, train -from .utils import ( - AtomicNumberTable, - MetricsLogger, - atomic_numbers_to_indices, - compute_c, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, - get_atomic_number_table_from_zs, - get_tag, - setup_logger, -) - -__all__ = [ - "TensorDict", - "AtomicNumberTable", - "atomic_numbers_to_indices", - "to_numpy", - "to_one_hot", - "build_default_arg_parser", - "check_args", - "DefaultKeys", - "set_seeds", - "init_device", - "setup_logger", - "get_tag", - "count_parameters", - "MetricsLogger", - "get_atomic_number_table_from_zs", - "train", - "evaluate", - "SWAContainer", - "CheckpointHandler", - "CheckpointIO", - "CheckpointState", - "set_default_dtype", - "compute_mae", - "compute_rel_mae", - "compute_rmse", - "compute_rel_rmse", - "compute_q95", - "compute_c", - "U_matrix_real", - "spherical_to_cartesian", - "cartesian_to_spherical", - "voigt_to_matrix", - "init_wandb", - "load_foundations", - "load_foundations_elements", - "build_preprocess_arg_parser", -] +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args +from .cg import U_matrix_real +from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .default_keys import DefaultKeys +from .finetuning_utils import load_foundations, load_foundations_elements +from .torch_tools import ( + TensorDict, + cartesian_to_spherical, + count_parameters, + init_device, + init_wandb, + set_default_dtype, + set_seeds, + spherical_to_cartesian, + to_numpy, + to_one_hot, + voigt_to_matrix, +) +from .train import SWAContainer, evaluate, train +from .utils import ( + AtomicNumberTable, + MetricsLogger, + atomic_numbers_to_indices, + compute_c, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, + get_atomic_number_table_from_zs, + get_tag, + setup_logger, +) + +__all__ = [ + "TensorDict", + "AtomicNumberTable", + "atomic_numbers_to_indices", + "to_numpy", + "to_one_hot", + "build_default_arg_parser", + "check_args", + "DefaultKeys", + "set_seeds", + "init_device", + "setup_logger", + "get_tag", + "count_parameters", + "MetricsLogger", + "get_atomic_number_table_from_zs", + "train", + "evaluate", + "SWAContainer", + "CheckpointHandler", + "CheckpointIO", + "CheckpointState", + "set_default_dtype", + "compute_mae", + "compute_rel_mae", + "compute_rmse", + "compute_rel_rmse", + "compute_q95", + "compute_c", + "U_matrix_real", + "spherical_to_cartesian", + "cartesian_to_spherical", + "voigt_to_matrix", + "init_wandb", + "load_foundations", + "load_foundations_elements", + "build_preprocess_arg_parser", +] diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 6c34fb5d37ef0a4f1790cd0a65bdb7075a3659c9..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git 
a/mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc deleted file mode 100644 index 3a2e0c207675385de131d252027055bfe1a80b11..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc deleted file mode 100644 index 8020bd538330553a86befe7e659f0c6835a65c58..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-313.pyc deleted file mode 100644 index 3b9401a0b99955e34e19c7c670e3ed94adf502b8..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-310.pyc deleted file mode 100644 index 3145cb7c5a67cfa16b7c0923563c1bba4e1b9a40..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-313.pyc deleted file mode 100644 index 49065c314431d44c137190081987cfa160bd2720..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 08bb2211ef145af2d5006262ce0565b88373c7ae..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index ec8c28f4740a70032a5e0cf7bf3661908309aa0f..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/arg_parser.py b/mace-bench/3rdparty/mace/mace/tools/arg_parser.py index 193b1c320672d65bd7feade8150a77be91f70bf4..c9d537e6865303f8bc8ddcadf47732dbc5ca3f83 100644 --- a/mace-bench/3rdparty/mace/mace/tools/arg_parser.py +++ b/mace-bench/3rdparty/mace/mace/tools/arg_parser.py @@ -1,971 +1,971 @@ -########################################################################################### -# Parsing functionalities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse -import os -from typing import Optional - -from .default_keys import DefaultKeys - - -def build_default_arg_parser() -> argparse.ArgumentParser: - try: - import configargparse - - parser 
= configargparse.ArgumentParser( - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - "--config", - type=str, - is_config_file=True, - help="config file to aggregate options", - ) - except ImportError: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Name and seed - parser.add_argument("--name", help="experiment name", required=True) - parser.add_argument("--seed", help="random seed", type=int, default=123) - - # Directories - parser.add_argument( - "--work_dir", - help="set directory for all files and folders", - type=str, - default=".", - ) - parser.add_argument( - "--log_dir", help="directory for log files", type=str, default=None - ) - parser.add_argument( - "--model_dir", help="directory for final model", type=str, default=None - ) - parser.add_argument( - "--checkpoints_dir", - help="directory for checkpoint files", - type=str, - default=None, - ) - parser.add_argument( - "--results_dir", help="directory for results", type=str, default=None - ) - parser.add_argument( - "--downloads_dir", help="directory for downloads", type=str, default=None - ) - - # Device and logging - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda", "mps", "xpu"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--distributed", - help="train in multi-GPU data parallel mode", - action="store_true", - default=False, - ) - parser.add_argument("--log_level", help="log level", type=str, default="INFO") - - parser.add_argument( - "--plot", - help="Plot results of training", - type=str2bool, - default=True, - ) - - parser.add_argument( - "--plot_frequency", - help="Set plotting frequency: '0' for only at the end or an integer N to plot every N epochs.", - type=int, - default="0", - ) - - parser.add_argument( - "--error_table", - help="Type of error table produced at the end of the training", - type=str, - choices=[ - "PerAtomRMSE", - "TotalRMSE", - "PerAtomRMSEstressvirials", - "PerAtomMAEstressvirials", - "PerAtomMAE", - "TotalMAE", - "DipoleRMSE", - "DipoleMAE", - "EnergyDipoleRMSE", - ], - default="PerAtomRMSE", - ) - - # Model - parser.add_argument( - "--model", - help="model type", - default="MACE", - choices=[ - "BOTNet", - "MACE", - "ScaleShiftMACE", - "ScaleShiftBOTNet", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - ], - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--radial_type", - help="type of radial basis functions", - type=str, - default="bessel", - choices=["bessel", "gaussian", "chebyshev"], - ) - parser.add_argument( - "--num_radial_basis", - help="number of radial basis functions", - type=int, - default=8, - ) - parser.add_argument( - "--num_cutoff_basis", - help="number of basis functions for smooth cutoff", - type=int, - default=5, - ) - parser.add_argument( - "--pair_repulsion", - help="use pair repulsion term with ZBL potential", - action="store_true", - default=False, - ) - parser.add_argument( - "--distance_transform", - help="use distance transform for radial basis functions", - default="None", - choices=["None", "Agnesi", "Soft"], - ) - parser.add_argument( - "--interaction", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - 
"RealAgnosticResidualInteractionBlock", - "RealAgnosticAttResidualInteractionBlock", - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ], - ) - parser.add_argument( - "--interaction_first", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ], - ) - parser.add_argument( - "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 - ) - parser.add_argument( - "--correlation", help="correlation order at each layer", type=int, default=3 - ) - parser.add_argument( - "--num_interactions", help="number of interactions", type=int, default=2 - ) - parser.add_argument( - "--MLP_irreps", - help="hidden irreps of the MLP in last readout", - type=str, - default="16x0e", - ) - parser.add_argument( - "--radial_MLP", - help="width of the radial MLP", - type=str, - default="[64, 64, 64]", - ) - parser.add_argument( - "--hidden_irreps", - help="irreps for hidden node states", - type=str, - default=None, - ) - # add option to specify irreps by channel number and max L - parser.add_argument( - "--num_channels", - help="number of embedding channels", - type=int, - default=None, - ) - parser.add_argument( - "--max_L", - help="max L equivariance of the message", - type=int, - default=None, - ) - parser.add_argument( - "--gate", - help="non linearity for last readout", - type=str, - default="silu", - choices=["silu", "tanh", "abs", "None"], - ) - parser.add_argument( - "--scaling", - help="type of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--avg_num_neighbors", - help="normalization factor for the message", - type=float, - default=1, - ) - parser.add_argument( - "--compute_avg_num_neighbors", - help="normalization factor for the message", - type=str2bool, - default=True, - ) - parser.add_argument( - "--compute_stress", - help="Select True to compute stress", - type=str2bool, - default=False, - ) - parser.add_argument( - "--compute_forces", - help="Select True to compute forces", - type=str2bool, - default=True, - ) - - # Dataset - parser.add_argument( - "--train_file", - help="Training set file, format is .xyz or .h5", - type=str, - required=False, - ) - parser.add_argument( - "--valid_file", - help="Validation set .xyz or .h5 file", - default=None, - type=str, - required=False, - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set .xyz pt .h5 file", - type=str, - ) - parser.add_argument( - "--test_dir", - help="Path to directory with test files named as test_*.h5", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--multi_processed_test", - help="Boolean value for whether the test data was multiprocessed", - type=str2bool, - default=False, - required=False, - ) - parser.add_argument( - "--num_workers", - help="Number of workers for data loading", - type=int, - default=0, - ) - parser.add_argument( - "--pin_memory", - help="Pin memory for data loading", - default=True, - type=str2bool, - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - 
required=False, - ) - parser.add_argument( - "--mean", - help="Mean energy per atom of training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--std", - help="Standard deviation of force components in the training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--statistics_file", - help="json file containing statistics of training set", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - required=False, - ) - - # Fine-tuning - parser.add_argument( - "--foundation_filter_elements", - help="Filter element during fine-tuning", - type=str2bool, - default=True, - required=False, - ) - parser.add_argument( - "--heads", - help="Dict of heads: containing individual files and E0s", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--multiheads_finetuning", - help="Boolean value for whether the model is multiheaded", - type=str2bool, - default=True, - ) - parser.add_argument( - "--foundation_head", - help="Name of the head to use for fine-tuning", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--weight_pt_head", - help="Weight of the pretrained head in the loss function", - type=float, - default=1.0, - ) - parser.add_argument( - "--num_samples_pt", - help="Number of samples in the pretrained head", - type=int, - default=10000, - ) - parser.add_argument( - "--force_mh_ft_lr", - help="Force the multiheaded fine-tuning to use arg_parser lr", - type=str2bool, - default=False, - ) - parser.add_argument( - "--subselect_pt", - help="Method to subselect the configurations of the pretraining set", - choices=["fps", "random"], - default="random", - ) - parser.add_argument( - "--filter_type_pt", - help="Filtering method for collecting the pretraining set", - choices=["none", "combinations", "inclusive", "exclusive"], - default="none", - ) - parser.add_argument( - "--pt_train_file", - help="Training set file for the pretrained head", - type=str, - default=None, - ) - parser.add_argument( - "--pt_valid_file", - help="Validation set file for the pretrained head", - type=str, - default=None, - ) - parser.add_argument( - "--foundation_model_elements", - help="Keep all elements of the foundation model during fine-tuning", - type=str2bool, - default=False, - ) - parser.add_argument( - "--keep_isolated_atoms", - help="Keep isolated atoms in the dataset, useful for transfer learning", - type=str2bool, - default=False, - ) - - # Keys - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default=DefaultKeys.ENERGY.value, - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default=DefaultKeys.FORCES.value, - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default=DefaultKeys.VIRIALS.value, - ) - parser.add_argument( - "--stress_key", - help="Key of reference stress in training xyz", - type=str, - default=DefaultKeys.STRESS.value, - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default=DefaultKeys.DIPOLE.value, - ) - parser.add_argument( - "--head_key", - help="Key of head in training xyz", - type=str, - default=DefaultKeys.HEAD.value, - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - 
default=DefaultKeys.CHARGES.value, - ) - parser.add_argument( - "--skip_evaluate_heads", - help="Comma-separated list of heads to skip during final evaluation", - type=str, - default="pt_head", - ) - - # Loss and optimization - parser.add_argument( - "--loss", - help="type of loss", - default="weighted", - choices=[ - "ef", - "weighted", - "forces_only", - "virials", - "stress", - "dipole", - "huber", - "universal", - "energy_forces_dipole", - "l1l2energyforces", - ], - ) - parser.add_argument( - "--forces_weight", help="weight of forces loss", type=float, default=100.0 - ) - parser.add_argument( - "--swa_forces_weight", - "--stage_two_forces_weight", - help="weight of forces loss after starting Stage Two (previously called swa)", - type=float, - default=100.0, - dest="swa_forces_weight", - ) - parser.add_argument( - "--energy_weight", help="weight of energy loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_energy_weight", - "--stage_two_energy_weight", - help="weight of energy loss after starting Stage Two (previously called swa)", - type=float, - default=1000.0, - dest="swa_energy_weight", - ) - parser.add_argument( - "--virials_weight", help="weight of virials loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_virials_weight", - "--stage_two_virials_weight", - help="weight of virials loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_virials_weight", - ) - parser.add_argument( - "--stress_weight", help="weight of stress loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_stress_weight", - "--stage_two_stress_weight", - help="weight of stress loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_stress_weight", - ) - parser.add_argument( - "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_dipole_weight", - "--stage_two_dipole_weight", - help="weight of dipoles after starting Stage Two (previously called swa)", - type=float, - default=1.0, - dest="swa_dipole_weight", - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--huber_delta", - help="delta parameter for huber loss", - type=float, - default=0.01, - ) - parser.add_argument( - "--optimizer", - help="Optimizer for parameter optimization", - type=str, - default="adam", - choices=["adam", "adamw", "schedulefree"], - ) - parser.add_argument( - "--beta", - help="Beta parameter for the optimizer", - type=float, - default=0.9, - ) - parser.add_argument("--batch_size", help="batch size", type=int, default=10) - parser.add_argument( - "--valid_batch_size", help="Validation batch size", type=int, default=10 - ) - parser.add_argument( - "--lr", help="Learning rate of optimizer", type=float, default=0.01 - ) - parser.add_argument( - "--swa_lr", - "--stage_two_lr", - help="Learning rate of optimizer in Stage Two (previously called swa)", - type=float, - default=1e-3, - dest="swa_lr", - ) - parser.add_argument( - "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 - ) - parser.add_argument( - "--amsgrad", - help="use amsgrad variant of optimizer", - action="store_true", - default=True, - ) - parser.add_argument( - "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" - ) - parser.add_argument( - "--lr_factor", help="Learning rate factor", type=float, 
default=0.8 - ) - parser.add_argument( - "--scheduler_patience", help="Learning rate factor", type=int, default=50 - ) - parser.add_argument( - "--lr_scheduler_gamma", - help="Gamma of learning rate scheduler", - type=float, - default=0.9993, - ) - parser.add_argument( - "--swa", - "--stage_two", - help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", - action="store_true", - default=False, - dest="swa", - ) - parser.add_argument( - "--start_swa", - "--start_stage_two", - help="Number of epochs before changing to Stage Two loss weights", - type=int, - default=None, - dest="start_swa", - ) - parser.add_argument( - "--lbfgs", - help="Switch to L-BFGS optimizer", - action="store_true", - default=False, - ) - parser.add_argument( - "--ema", - help="use Exponential Moving Average", - action="store_true", - default=False, - ) - parser.add_argument( - "--ema_decay", - help="Exponential Moving Average decay", - type=float, - default=0.99, - ) - parser.add_argument( - "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 - ) - parser.add_argument( - "--patience", - help="Maximum number of consecutive epochs of increasing loss", - type=int, - default=2048, - ) - parser.add_argument( - "--foundation_model", - help="Path to the foundation model for transfer learning", - type=str, - default=None, - ) - parser.add_argument( - "--foundation_model_readout", - help="Use readout of foundation model for transfer learning", - action="store_false", - default=True, - ) - parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=1 - ) - parser.add_argument( - "--keep_checkpoints", - help="keep all checkpoints", - action="store_true", - default=False, - ) - parser.add_argument( - "--save_all_checkpoints", - help="save all checkpoints", - action="store_true", - default=False, - ) - parser.add_argument( - "--restart_latest", - help="restart optimizer from latest checkpoint", - action="store_true", - default=False, - ) - parser.add_argument( - "--save_cpu", - help="Save a model to be loaded on cpu", - action="store_true", - default=False, - ) - parser.add_argument( - "--clip_grad", - help="Gradient Clipping Value", - type=check_float_or_none, - default=10.0, - ) - parser.add_argument( - "--dry_run", - help="Run all steps upto training to test settings.", - action="store_true", - default=False, - ) - # option for cuequivariance acceleration - parser.add_argument( - "--enable_cueq", - help="Enable cuequivariance acceleration", - type=str2bool, - default=False, - ) - # options for using Weights and Biases for experiment tracking - # to install see https://wandb.ai - parser.add_argument( - "--wandb", - help="Use Weights and Biases for experiment tracking", - action="store_true", - default=False, - ) - parser.add_argument( - "--wandb_dir", - help="An absolute path to a directory where Weights and Biases metadata will be stored", - type=str, - default=None, - ) - parser.add_argument( - "--wandb_project", - help="Weights and Biases project name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_entity", - help="Weights and Biases entity name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_name", - help="Weights and Biases experiment name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_log_hypers", - help="The hyperparameters to log in Weights and Biases", - type=list, - default=[ - "num_channels", - "max_L", - 
"correlation", - "lr", - "swa_lr", - "weight_decay", - "batch_size", - "max_num_epochs", - "start_swa", - "energy_weight", - "forces_weight", - ], - ) - return parser - - -def build_preprocess_arg_parser() -> argparse.ArgumentParser: - try: - import configargparse - - parser = configargparse.ArgumentParser( - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - "--config", - type=str, - is_config_file=True, - help="config file to aggregate options", - ) - except ImportError: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--train_file", - help="Training set h5 file", - type=str, - default=None, - required=True, - ) - parser.add_argument( - "--valid_file", - help="Training set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--num_process", - help="The user defined number of processes to use, as well as the number of files created.", - type=int, - default=int(os.cpu_count() / 4), - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--work_dir", - help="set directory for all files and folders", - type=str, - default=".", - ) - parser.add_argument( - "--h5_prefix", - help="Prefix for h5 files when saving", - type=str, - default="", - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default=DefaultKeys.ENERGY.value, - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default=DefaultKeys.FORCES.value, - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default=DefaultKeys.VIRIALS.value, - ) - parser.add_argument( - "--stress_key", - help="Key of reference stress in training xyz", - type=str, - default=DefaultKeys.STRESS.value, - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default=DefaultKeys.DIPOLE.value, - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - default=DefaultKeys.CHARGES.value, - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--compute_statistics", - help="Compute statistics for the dataset", - action="store_true", - default=False, - ) - parser.add_argument( - "--batch_size", - help="batch size to compute average number of neighbours", - type=int, - default=16, - ) - - parser.add_argument( - "--scaling", - help="type of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--shuffle", - help="Shuffle the training dataset", - type=str2bool, - 
default=True, - ) - parser.add_argument( - "--seed", - help="Random seed for splitting training and validation sets", - type=int, - default=123, - ) - parser.add_argument( - "--head_key", - help="Key of head in training xyz", - type=str, - default=DefaultKeys.HEAD.value, - ) - parser.add_argument( - "--heads", - help="Dict of heads: containing individual files and E0s", - type=str, - default=None, - required=False, - ) - return parser - - -def check_float_or_none(value: str) -> Optional[float]: - try: - return float(value) - except ValueError: - if value != "None": - raise argparse.ArgumentTypeError( - f"{value} is an invalid value (float or None)" - ) from None - return None - - -def str2bool(value): - if isinstance(value, bool): - return value - if value.lower() in ("yes", "true", "t", "y", "1"): - return True - if value.lower() in ("no", "false", "f", "n", "0"): - return False - raise argparse.ArgumentTypeError("Boolean value expected.") +########################################################################################### +# Parsing functionalities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import os +from typing import Optional + +from .default_keys import DefaultKeys + + +def build_default_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to aggregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Name and seed + parser.add_argument("--name", help="experiment name", required=True) + parser.add_argument("--seed", help="random seed", type=int, default=123) + + # Directories + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--log_dir", help="directory for log files", type=str, default=None + ) + parser.add_argument( + "--model_dir", help="directory for final model", type=str, default=None + ) + parser.add_argument( + "--checkpoints_dir", + help="directory for checkpoint files", + type=str, + default=None, + ) + parser.add_argument( + "--results_dir", help="directory for results", type=str, default=None + ) + parser.add_argument( + "--downloads_dir", help="directory for downloads", type=str, default=None + ) + + # Device and logging + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda", "mps", "xpu"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--distributed", + help="train in multi-GPU data parallel mode", + action="store_true", + default=False, + ) + parser.add_argument("--log_level", help="log level", type=str, default="INFO") + + parser.add_argument( + "--plot", + help="Plot results of training", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--plot_frequency", + help="Set plotting frequency: '0' for only at the end or an integer N to plot every N epochs.", + type=int, + default="0", + ) + + parser.add_argument( + "--error_table", + 
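# The choices below combine the metric (RMSE vs. MAE), the
+ # normalization (per-atom vs. total) and optional stress/virial or
+ # dipole columns in the final table.
+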
help="Type of error table produced at the end of the training", + type=str, + choices=[ + "PerAtomRMSE", + "TotalRMSE", + "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", + "PerAtomMAE", + "TotalMAE", + "DipoleRMSE", + "DipoleMAE", + "EnergyDipoleRMSE", + ], + default="PerAtomRMSE", + ) + + # Model + parser.add_argument( + "--model", + help="model type", + default="MACE", + choices=[ + "BOTNet", + "MACE", + "ScaleShiftMACE", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + ], + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--radial_type", + help="type of radial basis functions", + type=str, + default="bessel", + choices=["bessel", "gaussian", "chebyshev"], + ) + parser.add_argument( + "--num_radial_basis", + help="number of radial basis functions", + type=int, + default=8, + ) + parser.add_argument( + "--num_cutoff_basis", + help="number of basis functions for smooth cutoff", + type=int, + default=5, + ) + parser.add_argument( + "--pair_repulsion", + help="use pair repulsion term with ZBL potential", + action="store_true", + default=False, + ) + parser.add_argument( + "--distance_transform", + help="use distance transform for radial basis functions", + default="None", + choices=["None", "Agnesi", "Soft"], + ) + parser.add_argument( + "--interaction", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticAttResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--interaction_first", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 + ) + parser.add_argument( + "--correlation", help="correlation order at each layer", type=int, default=3 + ) + parser.add_argument( + "--num_interactions", help="number of interactions", type=int, default=2 + ) + parser.add_argument( + "--MLP_irreps", + help="hidden irreps of the MLP in last readout", + type=str, + default="16x0e", + ) + parser.add_argument( + "--radial_MLP", + help="width of the radial MLP", + type=str, + default="[64, 64, 64]", + ) + parser.add_argument( + "--hidden_irreps", + help="irreps for hidden node states", + type=str, + default=None, + ) + # add option to specify irreps by channel number and max L + parser.add_argument( + "--num_channels", + help="number of embedding channels", + type=int, + default=None, + ) + parser.add_argument( + "--max_L", + help="max L equivariance of the message", + type=int, + default=None, + ) + parser.add_argument( + "--gate", + help="non linearity for last readout", + type=str, + default="silu", + choices=["silu", "tanh", "abs", "None"], + ) + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--avg_num_neighbors", + help="normalization factor for the message", + type=float, + default=1, + ) + parser.add_argument( + "--compute_avg_num_neighbors", + help="normalization 
factor for the message", + type=str2bool, + default=True, + ) + parser.add_argument( + "--compute_stress", + help="Select True to compute stress", + type=str2bool, + default=False, + ) + parser.add_argument( + "--compute_forces", + help="Select True to compute forces", + type=str2bool, + default=True, + ) + + # Dataset + parser.add_argument( + "--train_file", + help="Training set file, format is .xyz or .h5", + type=str, + required=False, + ) + parser.add_argument( + "--valid_file", + help="Validation set .xyz or .h5 file", + default=None, + type=str, + required=False, + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set .xyz or .h5 file", + type=str, + ) + parser.add_argument( + "--test_dir", + help="Path to directory with test files named as test_*.h5", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multi_processed_test", + help="Boolean value for whether the test data was multiprocessed", + type=str2bool, + default=False, + required=False, + ) + parser.add_argument( + "--num_workers", + help="Number of workers for data loading", + type=int, + default=0, + ) + parser.add_argument( + "--pin_memory", + help="Pin memory for data loading", + default=True, + type=str2bool, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--mean", + help="Mean energy per atom of training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--std", + help="Standard deviation of force components in the training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--statistics_file", + help="json file containing statistics of training set", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + + # Fine-tuning + parser.add_argument( + "--foundation_filter_elements", + help="Filter elements during fine-tuning", + type=str2bool, + default=True, + required=False, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multiheads_finetuning", + help="Boolean value for whether the model is multiheaded", + type=str2bool, + default=True, + ) + parser.add_argument( + "--foundation_head", + help="Name of the head to use for fine-tuning", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--weight_pt_head", + help="Weight of the pretrained head in the loss function", + type=float, + default=1.0, + ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=10000, + ) + parser.add_argument( + "--force_mh_ft_lr", + help="Force the multiheaded fine-tuning to use arg_parser lr", + type=str2bool, + default=False, + ) + parser.add_argument( + "--subselect_pt", + help="Method to subselect the configurations of the pretraining set", + choices=["fps", "random"], + default="random", + ) + parser.add_argument( + "--filter_type_pt", + help="Filtering method for collecting the pretraining set", + choices=["none", "combinations", "inclusive", "exclusive"], + default="none", + ) + parser.add_argument( + "--pt_train_file", + help="Training set file for the
pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--pt_valid_file", + help="Validation set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_elements", + help="Keep all elements of the foundation model during fine-tuning", + type=str2bool, + default=False, + ) + parser.add_argument( + "--keep_isolated_atoms", + help="Keep isolated atoms in the dataset, useful for transfer learning", + type=str2bool, + default=False, + ) + + # Keys + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default=DefaultKeys.ENERGY.value, + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default=DefaultKeys.FORCES.value, + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default=DefaultKeys.VIRIALS.value, + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default=DefaultKeys.STRESS.value, + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default=DefaultKeys.DIPOLE.value, + ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default=DefaultKeys.HEAD.value, + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default=DefaultKeys.CHARGES.value, + ) + parser.add_argument( + "--skip_evaluate_heads", + help="Comma-separated list of heads to skip during final evaluation", + type=str, + default="pt_head", + ) + + # Loss and optimization + parser.add_argument( + "--loss", + help="type of loss", + default="weighted", + choices=[ + "ef", + "weighted", + "forces_only", + "virials", + "stress", + "dipole", + "huber", + "universal", + "energy_forces_dipole", + "l1l2energyforces", + ], + ) + parser.add_argument( + "--forces_weight", help="weight of forces loss", type=float, default=100.0 + ) + parser.add_argument( + "--swa_forces_weight", + "--stage_two_forces_weight", + help="weight of forces loss after starting Stage Two (previously called swa)", + type=float, + default=100.0, + dest="swa_forces_weight", + ) + parser.add_argument( + "--energy_weight", help="weight of energy loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_energy_weight", + "--stage_two_energy_weight", + help="weight of energy loss after starting Stage Two (previously called swa)", + type=float, + default=1000.0, + dest="swa_energy_weight", + ) + parser.add_argument( + "--virials_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_virials_weight", + "--stage_two_virials_weight", + help="weight of virials loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_virials_weight", + ) + parser.add_argument( + "--stress_weight", help="weight of stress loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_stress_weight", + "--stage_two_stress_weight", + help="weight of stress loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_stress_weight", + ) + parser.add_argument( + "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_dipole_weight", + "--stage_two_dipole_weight", + help="weight of dipoles after starting Stage Two (previously called swa)", + type=float, + default=1.0, + 
dest="swa_dipole_weight", + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--huber_delta", + help="delta parameter for huber loss", + type=float, + default=0.01, + ) + parser.add_argument( + "--optimizer", + help="Optimizer for parameter optimization", + type=str, + default="adam", + choices=["adam", "adamw", "schedulefree"], + ) + parser.add_argument( + "--beta", + help="Beta parameter for the optimizer", + type=float, + default=0.9, + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=10) + parser.add_argument( + "--valid_batch_size", help="Validation batch size", type=int, default=10 + ) + parser.add_argument( + "--lr", help="Learning rate of optimizer", type=float, default=0.01 + ) + parser.add_argument( + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", + ) + parser.add_argument( + "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 + ) + parser.add_argument( + "--amsgrad", + help="use amsgrad variant of optimizer", + action="store_true", + default=True, + ) + parser.add_argument( + "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" + ) + parser.add_argument( + "--lr_factor", help="Learning rate factor", type=float, default=0.8 + ) + parser.add_argument( + "--scheduler_patience", help="Learning rate factor", type=int, default=50 + ) + parser.add_argument( + "--lr_scheduler_gamma", + help="Gamma of learning rate scheduler", + type=float, + default=0.9993, + ) + parser.add_argument( + "--swa", + "--stage_two", + help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", + action="store_true", + default=False, + dest="swa", + ) + parser.add_argument( + "--start_swa", + "--start_stage_two", + help="Number of epochs before changing to Stage Two loss weights", + type=int, + default=None, + dest="start_swa", + ) + parser.add_argument( + "--lbfgs", + help="Switch to L-BFGS optimizer", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema", + help="use Exponential Moving Average", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema_decay", + help="Exponential Moving Average decay", + type=float, + default=0.99, + ) + parser.add_argument( + "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 + ) + parser.add_argument( + "--patience", + help="Maximum number of consecutive epochs of increasing loss", + type=int, + default=2048, + ) + parser.add_argument( + "--foundation_model", + help="Path to the foundation model for transfer learning", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_readout", + help="Use readout of foundation model for transfer learning", + action="store_false", + default=True, + ) + parser.add_argument( + "--eval_interval", help="evaluate model every epochs", type=int, default=1 + ) + parser.add_argument( + "--keep_checkpoints", + help="keep all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_all_checkpoints", + help="save all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--restart_latest", + help="restart optimizer from latest checkpoint", + action="store_true", + default=False, + ) 
+ parser.add_argument( + "--save_cpu", + help="Save a model to be loaded on cpu", + action="store_true", + default=False, + ) + parser.add_argument( + "--clip_grad", + help="Gradient Clipping Value", + type=check_float_or_none, + default=10.0, + ) + parser.add_argument( + "--dry_run", + help="Run all steps up to training to test settings.", + action="store_true", + default=False, + ) + # option for cuequivariance acceleration + parser.add_argument( + "--enable_cueq", + help="Enable cuequivariance acceleration", + type=str2bool, + default=False, + ) + # options for using Weights and Biases for experiment tracking + # to install see https://wandb.ai + parser.add_argument( + "--wandb", + help="Use Weights and Biases for experiment tracking", + action="store_true", + default=False, + ) + parser.add_argument( + "--wandb_dir", + help="An absolute path to a directory where Weights and Biases metadata will be stored", + type=str, + default=None, + ) + parser.add_argument( + "--wandb_project", + help="Weights and Biases project name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_entity", + help="Weights and Biases entity name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_name", + help="Weights and Biases experiment name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_log_hypers", + help="The hyperparameters to log in Weights and Biases", + type=list, + default=[ + "num_channels", + "max_L", + "correlation", + "lr", + "swa_lr", + "weight_decay", + "batch_size", + "max_num_epochs", + "start_swa", + "energy_weight", + "forces_weight", + ], + ) + return parser + + +def build_preprocess_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to aggregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--train_file", + help="Training set h5 file", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--valid_file", + help="Validation set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--num_process", + help="The user defined number of processes to use, as well as the number of files created.", + type=int, + default=int(os.cpu_count() / 4), + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--h5_prefix", + help="Prefix for h5 files when saving", + type=str, + default="", + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default=DefaultKeys.ENERGY.value, + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in
training xyz", + type=str, + default=DefaultKeys.FORCES.value, + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default=DefaultKeys.VIRIALS.value, + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default=DefaultKeys.STRESS.value, + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default=DefaultKeys.DIPOLE.value, + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default=DefaultKeys.CHARGES.value, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--compute_statistics", + help="Compute statistics for the dataset", + action="store_true", + default=False, + ) + parser.add_argument( + "--batch_size", + help="batch size to compute average number of neighbours", + type=int, + default=16, + ) + + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--shuffle", + help="Shuffle the training dataset", + type=str2bool, + default=True, + ) + parser.add_argument( + "--seed", + help="Random seed for splitting training and validation sets", + type=int, + default=123, + ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default=DefaultKeys.HEAD.value, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + return parser + + +def check_float_or_none(value: str) -> Optional[float]: + try: + return float(value) + except ValueError: + if value != "None": + raise argparse.ArgumentTypeError( + f"{value} is an invalid value (float or None)" + ) from None + return None + + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + if value.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py b/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py index 21a23ff8ebcec1d271edf30adf7bbc07bf375a88..be714b26edc2b4d39bb64ccc0c6e270e6c656e5b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py @@ -1,122 +1,122 @@ -import logging -import os - -from e3nn import o3 - - -def check_args(args): - """ - Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing - the (potentially) modified args and a list of log messages. 
- """ - log_messages = [] - - # Directories - # Use work_dir for all other directories as well, unless they were specified by the user - if args.log_dir is None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir is None: - args.model_dir = args.work_dir - if args.checkpoints_dir is None: - args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir is None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir is None: - args.downloads_dir = os.path.join(args.work_dir, "downloads") - - # Model - # Check if hidden_irreps, num_channels and max_L are consistent - if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif ( - args.hidden_irreps is not None - and args.num_channels is not None - and args.max_L is not None - ): - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - log_messages.append( - ( - "All of hidden_irreps, num_channels and max_L are specified", - logging.WARNING, - ) - ) - log_messages.append( - ( - f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", - logging.WARNING, - ) - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.hidden_irreps is not None: - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - args.num_channels = list( - {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} - )[0] - args.max_L = o3.Irreps(args.hidden_irreps).lmax - elif args.max_L is not None and args.num_channels is None: - assert args.max_L >= 0, "max_L must be non-negative integer" - args.num_channels = 128 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - elif args.max_L is None and args.num_channels is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - args.max_L = 1 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - - # Loss and optimization - # Check Stage Two loss start - if args.start_swa is not None: - args.swa = True - log_messages.append( - ( - "Stage Two is activated as start_stage_two was defined", - logging.INFO, - ) - ) - - if args.swa: - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - if args.start_swa > args.max_num_epochs: - log_messages.append( - ( - f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", 
- logging.WARNING, - ) - ) - log_messages.append( - ( - "Stage Two will not start, as start_stage_two > max_num_epochs", - logging.WARNING, - ) - ) - args.swa = False - - return args, log_messages +import logging +import os + +from e3nn import o3 + + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + # Loss and optimization + # Check 
Stage Two loss start + if args.start_swa is not None: + args.swa = True + log_messages.append( + ( + "Stage Two is activated as start_stage_two was defined", + logging.INFO, + ) + ) + + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, + ) + ) + args.swa = False + + return args, log_messages diff --git a/mace-bench/3rdparty/mace/mace/tools/cg.py b/mace-bench/3rdparty/mace/mace/tools/cg.py index 5570be080268d2a191923e5255562a058aa39e9a..471adac786519ea860673687fcff40fbcf08faba 100644 --- a/mace-bench/3rdparty/mace/mace/tools/cg.py +++ b/mace-bench/3rdparty/mace/mace/tools/cg.py @@ -1,211 +1,211 @@ -########################################################################################### -# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) -# Authors: Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import collections -import itertools -import os -from typing import Iterator, List, Union - -import numpy as np -import torch -from e3nn import o3 - -try: - import cuequivariance as cue - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in ( - "1", - "true", - "yes", - "y", -) - -_TP = collections.namedtuple("_TP", "op, args") -_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") - - -def _wigner_nj( - irrepss: List[o3.Irreps], - normalization: str = "component", - filter_ir_mid=None, - dtype=None, -): - irrepss = [o3.Irreps(irreps) for irreps in irrepss] - if filter_ir_mid is not None: - filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] - - if len(irrepss) == 1: - (irreps,) = irrepss - ret = [] - e = torch.eye(irreps.dim, dtype=dtype) - i = 0 - for mul, ir in irreps: - for _ in range(mul): - sl = slice(i, i + ir.dim) - ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] - i += ir.dim - return ret - - *irrepss_left, irreps_right = irrepss - ret = [] - for ir_left, path_left, C_left in _wigner_nj( - irrepss_left, - normalization=normalization, - filter_ir_mid=filter_ir_mid, - dtype=dtype, - ): - i = 0 - for mul, ir in irreps_right: - for ir_out in ir_left * ir: - if filter_ir_mid is not None and ir_out not in filter_ir_mid: - continue - - C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) - if normalization == "component": - C *= ir_out.dim**0.5 - if normalization == "norm": - C *= ir_left.dim**0.5 * ir.dim**0.5 - - C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) - C = C.reshape( - ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim - ) - for u in range(mul): - E = torch.zeros( - ir_out.dim, - *(irreps.dim for irreps in irrepss_left), - irreps_right.dim, - dtype=dtype, - ) - sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) - E[..., sl] = C - ret += [ - ( - ir_out, - _TP( - op=(ir_left, ir, ir_out), - args=( - path_left, - _INPUT(len(irrepss_left), sl.start, sl.stop), - ), - ), - E, - ) - ] - i += mul * ir.dim - return sorted(ret, key=lambda x: x[0]) - - -def U_matrix_real( - irreps_in: Union[str, o3.Irreps], - irreps_out: Union[str, o3.Irreps], - correlation: int, - normalization: 
str = "component", - filter_ir_mid=None, - dtype=None, - use_cueq_cg=None, -): - irreps_out = o3.Irreps(irreps_out) - irrepss = [o3.Irreps(irreps_in)] * correlation - - if correlation == 4: - filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] - - if use_cueq_cg is None: - use_cueq_cg = USE_CUEQ_CG - if use_cueq_cg and CUET_AVAILABLE: - return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation) - - try: - wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) - except NotImplementedError as e: - if CUET_AVAILABLE: - return compute_U_cueq( - irreps_in, irreps_out=irreps_out, correlation=correlation - ) - raise NotImplementedError( - "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" - ) from e - - current_ir = wigners[0][0] - out = [] - stack = torch.tensor([]) - - for ir, _, base_o3 in wigners: - if ir in irreps_out and ir == current_ir: - stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) - last_ir = current_ir - elif ir in irreps_out and ir != current_ir: - if len(stack) != 0: - out += [last_ir, stack] - stack = base_o3.squeeze().unsqueeze(-1) - current_ir, last_ir = ir, ir - else: - current_ir = ir - out += [last_ir, stack] - return out - - -if CUET_AVAILABLE: - - def compute_U_cueq(irreps_in, irreps_out, correlation=2): - U = [] - irreps_in = cue.Irreps(O3_e3nn, str(irreps_in)) - irreps_out = cue.Irreps(O3_e3nn, str(irreps_out)) - for _, ir in irreps_out: - ir_str = str(ir) - U.append(ir_str) - U_matrix = cue.reduced_symmetric_tensor_product_basis( - irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul - ).array - U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1) - if ir.dim == 1: - U_matrix = U_matrix[0] - U.append(torch.tensor(U_matrix)) - return U - - class O3_e3nn(cue.O3): - def __mul__( # pylint: disable=no-self-argument - rep1: "O3_e3nn", rep2: "O3_e3nn" - ) -> Iterator["O3_e3nn"]: - return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] - - @classmethod - def clebsch_gordan( - cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" - ) -> np.ndarray: - rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) - - if rep1.p * rep2.p == rep3.p: - return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( - rep3.dim - ) - return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - - def __lt__( # pylint: disable=no-self-argument - rep1: "O3_e3nn", rep2: "O3_e3nn" - ) -> bool: - rep2 = rep1._from(rep2) - return (rep1.l, rep1.p) < (rep2.l, rep2.p) - - @classmethod - def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): - yield O3_e3nn(l=l, p=1 * (-1) ** l) - yield O3_e3nn(l=l, p=-1 * (-1) ** l) - -else: - - class O3_e3nn: - pass - - print( - "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." 
- ) +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import collections +import itertools +import os +from typing import Iterator, List, Union + +import numpy as np +import torch +from e3nn import o3 + +try: + import cuequivariance as cue + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in ( + "1", + "true", + "yes", + "y", +) + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, + use_cueq_cg=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + + if correlation == 4: + filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] + + if use_cueq_cg is None: + use_cueq_cg = USE_CUEQ_CG + if use_cueq_cg and CUET_AVAILABLE: + return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation) + + try: + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + except NotImplementedError as e: + if CUET_AVAILABLE: + return compute_U_cueq( + irreps_in, irreps_out=irreps_out, correlation=correlation + ) + raise NotImplementedError( + "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" + ) from e + + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == 
current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out + + +if CUET_AVAILABLE: + + def compute_U_cueq(irreps_in, irreps_out, correlation=2): + U = [] + irreps_in = cue.Irreps(O3_e3nn, str(irreps_in)) + irreps_out = cue.Irreps(O3_e3nn, str(irreps_out)) + for _, ir in irreps_out: + ir_str = str(ir) + U.append(ir_str) + U_matrix = cue.reduced_symmetric_tensor_product_basis( + irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul + ).array + U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1) + if ir.dim == 1: + U_matrix = U_matrix[0] + U.append(torch.tensor(U_matrix)) + return U + + class O3_e3nn(cue.O3): + def __mul__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> Iterator["O3_e3nn"]: + return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] + + @classmethod + def clebsch_gordan( + cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> bool: + rep2 = rep1._from(rep2) + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator["O3_e3nn"]: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + +else: + + class O3_e3nn: + pass + + print( + "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." 
+ ) diff --git a/mace-bench/3rdparty/mace/mace/tools/checkpoint.py b/mace-bench/3rdparty/mace/mace/tools/checkpoint.py index 2925140be6391ed1bbf42de95e277304b37ed3c7..81161cccda0245b8fb75291f38ebb8742fd29a4b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/checkpoint.py +++ b/mace-bench/3rdparty/mace/mace/tools/checkpoint.py @@ -1,227 +1,227 @@ -########################################################################################### -# Checkpointing -# Authors: Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import os -import re -from typing import Dict, List, Optional, Tuple - -import torch - -from .torch_tools import TensorDict - -Checkpoint = Dict[str, TensorDict] - - -@dataclasses.dataclass -class CheckpointState: - model: torch.nn.Module - optimizer: torch.optim.Optimizer - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR - - -class CheckpointBuilder: - @staticmethod - def create_checkpoint(state: CheckpointState) -> Checkpoint: - return { - "model": state.model.state_dict(), - "optimizer": state.optimizer.state_dict(), - "lr_scheduler": state.lr_scheduler.state_dict(), - } - - @staticmethod - def load_checkpoint( - state: CheckpointState, checkpoint: Checkpoint, strict: bool - ) -> None: - state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore - state.optimizer.load_state_dict(checkpoint["optimizer"]) - state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - -@dataclasses.dataclass -class CheckpointPathInfo: - path: str - tag: str - epochs: int - swa: bool - - -class CheckpointIO: - def __init__( - self, directory: str, tag: str, keep: bool = False, swa_start: int = None - ) -> None: - self.directory = directory - self.tag = tag - self.keep = keep - self.old_path: Optional[str] = None - self.swa_start = swa_start - - self._epochs_string = "_epoch-" - self._filename_extension = "pt" - - def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs >= swa_start: - return ( - self.tag - + self._epochs_string - + str(epochs) - + "_swa" - + "." - + self._filename_extension - ) - return ( - self.tag - + self._epochs_string - + str(epochs) - + "." 
- + self._filename_extension - ) - - def _list_file_paths(self) -> List[str]: - if not os.path.isdir(self.directory): - return [] - all_paths = [ - os.path.join(self.directory, f) for f in os.listdir(self.directory) - ] - return [path for path in all_paths if os.path.isfile(path)] - - def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]: - filename = os.path.basename(path) - regex = re.compile( - rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$" - ) - regex2 = re.compile( - rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$" - ) - match = regex.match(filename) - match2 = regex2.match(filename) - swa = False - if not match: - if not match2: - return None - match = match2 - swa = True - - return CheckpointPathInfo( - path=path, - tag=match.group("tag"), - epochs=int(match.group("epochs")), - swa=swa, - ) - - def _get_latest_checkpoint_path(self, swa) -> Optional[str]: - all_file_paths = self._list_file_paths() - checkpoint_info_list = [ - self._parse_checkpoint_path(path) for path in all_file_paths - ] - selected_checkpoint_info_list = [ - info for info in checkpoint_info_list if info and info.tag == self.tag - ] - - if len(selected_checkpoint_info_list) == 0: - logging.warning( - f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'" - ) - return None - - selected_checkpoint_info_list_swa = [] - selected_checkpoint_info_list_no_swa = [] - - for ckp in selected_checkpoint_info_list: - if ckp.swa: - selected_checkpoint_info_list_swa.append(ckp) - else: - selected_checkpoint_info_list_no_swa.append(ckp) - if swa: - try: - latest_checkpoint_info = max( - selected_checkpoint_info_list_swa, key=lambda info: info.epochs - ) - except ValueError: - logging.warning( - "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
- ) - else: - latest_checkpoint_info = max( - selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs - ) - return latest_checkpoint_info.path - - def save( - self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False - ) -> None: - if not self.keep and self.old_path and not keep_last: - logging.debug(f"Deleting old checkpoint file: {self.old_path}") - os.remove(self.old_path) - - filename = self._get_checkpoint_filename(epochs, self.swa_start) - path = os.path.join(self.directory, filename) - logging.debug(f"Saving checkpoint: {path}") - os.makedirs(self.directory, exist_ok=True) - torch.save(obj=checkpoint, f=path) - self.old_path = path - - def load_latest( - self, swa: Optional[bool] = False, device: Optional[torch.device] = None - ) -> Optional[Tuple[Checkpoint, int]]: - path = self._get_latest_checkpoint_path(swa=swa) - if path is None: - return None - - return self.load(path, device=device) - - def load( - self, path: str, device: Optional[torch.device] = None - ) -> Tuple[Checkpoint, int]: - checkpoint_info = self._parse_checkpoint_path(path) - - if checkpoint_info is None: - raise RuntimeError(f"Cannot find path '{path}'") - - logging.info(f"Loading checkpoint: {checkpoint_info.path}") - return ( - torch.load(f=checkpoint_info.path, map_location=device), - checkpoint_info.epochs, - ) - - -class CheckpointHandler: - def __init__(self, *args, **kwargs) -> None: - self.io = CheckpointIO(*args, **kwargs) - self.builder = CheckpointBuilder() - - def save( - self, state: CheckpointState, epochs: int, keep_last: bool = False - ) -> None: - checkpoint = self.builder.create_checkpoint(state) - self.io.save(checkpoint, epochs, keep_last) - - def load_latest( - self, - state: CheckpointState, - swa: Optional[bool] = False, - device: Optional[torch.device] = None, - strict=False, - ) -> Optional[int]: - result = self.io.load_latest(swa=swa, device=device) - if result is None: - return None - - checkpoint, epochs = result - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs - - def load( - self, - state: CheckpointState, - path: str, - strict=False, - device: Optional[torch.device] = None, - ) -> int: - checkpoint, epochs = self.io.load(path, device=device) - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs +########################################################################################### +# Checkpointing +# Authors: Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import os +import re +from typing import Dict, List, Optional, Tuple + +import torch + +from .torch_tools import TensorDict + +Checkpoint = Dict[str, TensorDict] + + +@dataclasses.dataclass +class CheckpointState: + model: torch.nn.Module + optimizer: torch.optim.Optimizer + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR + + +class CheckpointBuilder: + @staticmethod + def create_checkpoint(state: CheckpointState) -> Checkpoint: + return { + "model": state.model.state_dict(), + "optimizer": state.optimizer.state_dict(), + "lr_scheduler": state.lr_scheduler.state_dict(), + } + + @staticmethod + def load_checkpoint( + state: CheckpointState, checkpoint: Checkpoint, strict: bool + ) -> None: + state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore + state.optimizer.load_state_dict(checkpoint["optimizer"]) + 
state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + +@dataclasses.dataclass +class CheckpointPathInfo: + path: str + tag: str + epochs: int + swa: bool + + +class CheckpointIO: + def __init__( + self, directory: str, tag: str, keep: bool = False, swa_start: int = None + ) -> None: + self.directory = directory + self.tag = tag + self.keep = keep + self.old_path: Optional[str] = None + self.swa_start = swa_start + + self._epochs_string = "_epoch-" + self._filename_extension = "pt" + + def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: + if swa_start is not None and epochs >= swa_start: + return ( + self.tag + + self._epochs_string + + str(epochs) + + "_swa" + + "." + + self._filename_extension + ) + return ( + self.tag + + self._epochs_string + + str(epochs) + + "." + + self._filename_extension + ) + + def _list_file_paths(self) -> List[str]: + if not os.path.isdir(self.directory): + return [] + all_paths = [ + os.path.join(self.directory, f) for f in os.listdir(self.directory) + ] + return [path for path in all_paths if os.path.isfile(path)] + + def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]: + filename = os.path.basename(path) + regex = re.compile( + rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$" + ) + regex2 = re.compile( + rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$" + ) + match = regex.match(filename) + match2 = regex2.match(filename) + swa = False + if not match: + if not match2: + return None + match = match2 + swa = True + + return CheckpointPathInfo( + path=path, + tag=match.group("tag"), + epochs=int(match.group("epochs")), + swa=swa, + ) + + def _get_latest_checkpoint_path(self, swa) -> Optional[str]: + all_file_paths = self._list_file_paths() + checkpoint_info_list = [ + self._parse_checkpoint_path(path) for path in all_file_paths + ] + selected_checkpoint_info_list = [ + info for info in checkpoint_info_list if info and info.tag == self.tag + ] + + if len(selected_checkpoint_info_list) == 0: + logging.warning( + f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'" + ) + return None + + selected_checkpoint_info_list_swa = [] + selected_checkpoint_info_list_no_swa = [] + + for ckp in selected_checkpoint_info_list: + if ckp.swa: + selected_checkpoint_info_list_swa.append(ckp) + else: + selected_checkpoint_info_list_no_swa.append(ckp) + if swa: + try: + latest_checkpoint_info = max( + selected_checkpoint_info_list_swa, key=lambda info: info.epochs + ) + except ValueError: + logging.warning( + "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
+ ) + else: + latest_checkpoint_info = max( + selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs + ) + return latest_checkpoint_info.path + + def save( + self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False + ) -> None: + if not self.keep and self.old_path and not keep_last: + logging.debug(f"Deleting old checkpoint file: {self.old_path}") + os.remove(self.old_path) + + filename = self._get_checkpoint_filename(epochs, self.swa_start) + path = os.path.join(self.directory, filename) + logging.debug(f"Saving checkpoint: {path}") + os.makedirs(self.directory, exist_ok=True) + torch.save(obj=checkpoint, f=path) + self.old_path = path + + def load_latest( + self, swa: Optional[bool] = False, device: Optional[torch.device] = None + ) -> Optional[Tuple[Checkpoint, int]]: + path = self._get_latest_checkpoint_path(swa=swa) + if path is None: + return None + + return self.load(path, device=device) + + def load( + self, path: str, device: Optional[torch.device] = None + ) -> Tuple[Checkpoint, int]: + checkpoint_info = self._parse_checkpoint_path(path) + + if checkpoint_info is None: + raise RuntimeError(f"Cannot find path '{path}'") + + logging.info(f"Loading checkpoint: {checkpoint_info.path}") + return ( + torch.load(f=checkpoint_info.path, map_location=device), + checkpoint_info.epochs, + ) + + +class CheckpointHandler: + def __init__(self, *args, **kwargs) -> None: + self.io = CheckpointIO(*args, **kwargs) + self.builder = CheckpointBuilder() + + def save( + self, state: CheckpointState, epochs: int, keep_last: bool = False + ) -> None: + checkpoint = self.builder.create_checkpoint(state) + self.io.save(checkpoint, epochs, keep_last) + + def load_latest( + self, + state: CheckpointState, + swa: Optional[bool] = False, + device: Optional[torch.device] = None, + strict=False, + ) -> Optional[int]: + result = self.io.load_latest(swa=swa, device=device) + if result is None: + return None + + checkpoint, epochs = result + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs + + def load( + self, + state: CheckpointState, + path: str, + strict=False, + device: Optional[torch.device] = None, + ) -> int: + checkpoint, epochs = self.io.load(path, device=device) + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs diff --git a/mace-bench/3rdparty/mace/mace/tools/compile.py b/mace-bench/3rdparty/mace/mace/tools/compile.py index 59f7450e69968411a0c7579eaa0898c2812a769b..03282067380d9de08a0af41a66edd38f4ddc973c 100644 --- a/mace-bench/3rdparty/mace/mace/tools/compile.py +++ b/mace-bench/3rdparty/mace/mace/tools/compile.py @@ -1,95 +1,95 @@ -from contextlib import contextmanager -from functools import wraps -from typing import Callable, Tuple - -try: - import torch._dynamo as dynamo -except ImportError: - dynamo = None -from e3nn import get_optimization_defaults, set_optimization_defaults -from torch import autograd, nn -from torch.fx import symbolic_trace - -ModuleFactory = Callable[..., nn.Module] -TypeTuple = Tuple[type, ...] 
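Before the function bodies below: the `prepare` transform in this hunk wraps a module *factory*, so e3nn's fx code generation is disabled while the module is being built, and registered submodules are symbolically traced afterwards; that combination is what makes the result `torch.compile`-friendly. A minimal sketch of the intended call pattern, assuming the vendored import path shown in this diff and a hypothetical `make_model` factory standing in for a real MACE model class:

```python
import torch
from torch import nn

# import path assumes the vendored layout mace/tools/compile.py from this diff
from mace.tools.compile import prepare


def make_model() -> nn.Module:
    # hypothetical stand-in for a MACE model constructor
    return nn.Linear(4, 1)


model = prepare(make_model)()  # factory runs with e3nn fx codegen disabled,
                               # then registered submodules are traced by simplify()
compiled = torch.compile(model)  # module is now Dynamo-friendly
print(compiled(torch.randn(2, 4)).shape)  # torch.Size([2, 1])
```

Wrapping at the factory level, rather than patching an already-built module, is what allows `disable_e3nn_codegen` to take effect during construction.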
- - -@contextmanager -def disable_e3nn_codegen(): - """Context manager that disables the legacy PyTorch code generation used in e3nn.""" - init_val = get_optimization_defaults()["jit_script_fx"] - set_optimization_defaults(jit_script_fx=False) - yield - set_optimization_defaults(jit_script_fx=init_val) - - -def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: - """Function transform that prepares a MACE module for torch.compile - - Args: - func (ModuleFactory): A function that creates an nn.Module - allow_autograd (bool, optional): Force inductor compiler to inline call to - `torch.autograd.grad`. Defaults to True. - - Returns: - ModuleFactory: Decorated function that creates a torch.compile compatible module - """ - if allow_autograd: - dynamo.allow_in_graph(autograd.grad) - else: - dynamo.disallow_in_graph(autograd.grad) - - @wraps(func) - def wrapper(*args, **kwargs): - with disable_e3nn_codegen(): - model = func(*args, **kwargs) - - model = simplify(model) - return model - - return wrapper - - -_SIMPLIFY_REGISTRY = set() - - -def simplify_if_compile(module: nn.Module) -> nn.Module: - """Decorator to register a module for symbolic simplification - - The decorated module will be simplifed using `torch.fx.symbolic_trace`. - This constrains the module to not have any dynamic control flow, see: - - https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing - - Args: - module (nn.Module): the module to register - - Returns: - nn.Module: registered module - """ - _SIMPLIFY_REGISTRY.add(module) - return module - - -def simplify(module: nn.Module) -> nn.Module: - """Recursively searches for registered modules to simplify with - `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. - - Modules are registered with the `simplify_if_compile` decorator and - - Args: - module (nn.Module): the module to simplify - - Returns: - nn.Module: the simplified module - """ - simplify_types = tuple(_SIMPLIFY_REGISTRY) - - for name, child in module.named_children(): - if isinstance(child, simplify_types): - traced = symbolic_trace(child) - setattr(module, name, traced) - else: - simplify(child) - - return module +from contextlib import contextmanager +from functools import wraps +from typing import Callable, Tuple + +try: + import torch._dynamo as dynamo +except ImportError: + dynamo = None +from e3nn import get_optimization_defaults, set_optimization_defaults +from torch import autograd, nn +from torch.fx import symbolic_trace + +ModuleFactory = Callable[..., nn.Module] +TypeTuple = Tuple[type, ...] + + +@contextmanager +def disable_e3nn_codegen(): + """Context manager that disables the legacy PyTorch code generation used in e3nn.""" + init_val = get_optimization_defaults()["jit_script_fx"] + set_optimization_defaults(jit_script_fx=False) + yield + set_optimization_defaults(jit_script_fx=init_val) + + +def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: + """Function transform that prepares a MACE module for torch.compile + + Args: + func (ModuleFactory): A function that creates an nn.Module + allow_autograd (bool, optional): Force inductor compiler to inline call to + `torch.autograd.grad`. Defaults to True. 
+ + Returns: + ModuleFactory: Decorated function that creates a torch.compile compatible module + """ + if allow_autograd: + dynamo.allow_in_graph(autograd.grad) + else: + dynamo.disallow_in_graph(autograd.grad) + + @wraps(func) + def wrapper(*args, **kwargs): + with disable_e3nn_codegen(): + model = func(*args, **kwargs) + + model = simplify(model) + return model + + return wrapper + + +_SIMPLIFY_REGISTRY = set() + + +def simplify_if_compile(module: nn.Module) -> nn.Module: + """Decorator to register a module for symbolic simplification + + The decorated module will be simplified using `torch.fx.symbolic_trace`. + This constrains the module to not have any dynamic control flow, see: + + https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing + + Args: + module (nn.Module): the module to register + + Returns: + nn.Module: registered module + """ + _SIMPLIFY_REGISTRY.add(module) + return module + + +def simplify(module: nn.Module) -> nn.Module: + """Recursively searches for registered modules to simplify with + `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. + + Modules are registered with the `simplify_if_compile` decorator and are replaced + by their symbolically traced versions when `simplify` encounters them. + + Args: + module (nn.Module): the module to simplify + + Returns: + nn.Module: the simplified module + """ + simplify_types = tuple(_SIMPLIFY_REGISTRY) + + for name, child in module.named_children(): + if isinstance(child, simplify_types): + traced = symbolic_trace(child) + setattr(module, name, traced) + else: + simplify(child) + + return module diff --git a/mace-bench/3rdparty/mace/mace/tools/default_keys.py b/mace-bench/3rdparty/mace/mace/tools/default_keys.py index f0629483581f47f4acc960c8f4e91df33e7fab0a..769867dffb95f1ba599944688df7f5470289539e 100644 --- a/mace-bench/3rdparty/mace/mace/tools/default_keys.py +++ b/mace-bench/3rdparty/mace/mace/tools/default_keys.py @@ -1,21 +1,21 @@ -from __future__ import annotations - -from enum import Enum - - -class DefaultKeys(Enum): - ENERGY = "REF_energy" - FORCES = "REF_forces" - STRESS = "REF_stress" - VIRIALS = "REF_virials" - DIPOLE = "dipole" - HEAD = "head" - CHARGES = "REF_charges" - - @staticmethod - def keydict() -> dict[str, str]: - key_dict = {} - for member in DefaultKeys: - key_name = f"{member.name.lower()}_key" - key_dict[key_name] = member.value - return key_dict +from __future__ import annotations + +from enum import Enum + + +class DefaultKeys(Enum): + ENERGY = "REF_energy" + FORCES = "REF_forces" + STRESS = "REF_stress" + VIRIALS = "REF_virials" + DIPOLE = "dipole" + HEAD = "head" + CHARGES = "REF_charges" + + @staticmethod + def keydict() -> dict[str, str]: + key_dict = {} + for member in DefaultKeys: + key_name = f"{member.name.lower()}_key" + key_dict[key_name] = member.value + return key_dict diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py index fb7c72a4cf30e637398526039c74d35750e08a86..5163777df311ba3342680b1f8eec7004b7cb6110 100644 --- a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py @@ -1,3 +1,3 @@ -from .lmdb_dataset_tools import AseDBDataset - -__all__ = ["AseDBDataset"] +from .lmdb_dataset_tools import AseDBDataset + +__all__ = ["AseDBDataset"] diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc deleted file mode 100644
index e9ea4f9d9d0a340ef61b4b146cfb85cc56ecf628..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 1a9234ca0cdf6d1a333108228eaf3d0fb41ca264..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc deleted file mode 100644 index 0a8a893bb9bc8b2bfb1fbb1b3a1c45ac184835e7..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc deleted file mode 100644 index 0e2a72d88f0e8e0082f2d81eea9f9e095b525f2e..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/lmdb_dataset_tools.py b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/lmdb_dataset_tools.py index 59dee1878fd594dd288c3103bb9aed343071174a..f0c0ca3e681b44888a1ad421f5fd4a6665ba82fb 100644 --- a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/lmdb_dataset_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/lmdb_dataset_tools.py @@ -1,954 +1,954 @@ -""" -This module contains the AseDBDataset class and its dependencies. -It is extracted from the fairchem codebase and adapted to remove dependencies on fairchem. - -Original code copyright: -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
-""" - -from __future__ import annotations - -import bisect -import logging -import os -import zlib -from abc import ABC, abstractmethod - -try: - from functools import cache, cached_property -except ImportError: - from functools import cached_property, lru_cache - - cache = lru_cache(maxsize=None) -from glob import glob -from pathlib import Path -from typing import Any, Callable, TypeVar - -import ase -import ase.db.core -import ase.db.row -import ase.io -import lmdb -import numpy as np -import orjson -import torch - -# Type variable for generic dataset return type -T_co = TypeVar("T_co", covariant=True) - - -def rename_data_object_keys(data_object, key_mapping: dict[str, str | list[str]]): - """Rename data object keys - - Args: - data_object: data object - key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} - - new_key can be a list of new keys, for example, - prev_key: energy - new_key: [common_energy, oc20_energy] - - This is currently required when we use a single target/label for multiple tasks - """ - for _property in key_mapping: - # catch for test data not containing labels - if _property in data_object: - list_of_new_keys = key_mapping[_property] - if isinstance(list_of_new_keys, str): - list_of_new_keys = [list_of_new_keys] - for new_property in list_of_new_keys: - if new_property == _property: - continue - assert new_property not in data_object - data_object[new_property] = data_object[_property] - if _property not in list_of_new_keys: - del data_object[_property] - return data_object - - -def apply_one_tags( - atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False -): - """ - This function will apply tags of 1 to an ASE atoms object. - It is used as an atoms_transform in the datasets contained in this file. - - Certain models will treat atoms differently depending on their tags. - For example, GemNet-OC by default will only compute triplet and quadruplet interactions - for atoms with non-zero tags. This model throws an error if there are no tagged atoms. - For this reason, the default behavior is to tag atoms in structures with no tags. - - args: - skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms - - skip_always (bool): Do not apply any tags. 
This arg exists so that this function can be disabled - without needing to pass a callable (which is currently difficult to do with main.py) - """ - if skip_always: - return atoms - - if np.all(atoms.get_tags() == 0) or not skip_if_nonzero: - atoms.set_tags(np.ones(len(atoms))) - - return atoms - - -class UnsupportedDatasetError(ValueError): - pass - - -class BaseDataset(ABC): - """Base Dataset class for all ASE datasets.""" - - def __init__(self, config: dict): - """Initialize - - Args: - config (dict): dataset configuration - """ - self.config = config - self.paths = [] - - if "src" in self.config: - if isinstance(config["src"], str): - self.paths = [Path(self.config["src"])] - else: - self.paths = tuple(Path(path) for path in sorted(config["src"])) - - self.lin_ref = None - if self.config.get("lin_ref", False): - lin_ref = torch.tensor( - np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - ) - self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) - - def __len__(self) -> int: - return self.num_samples - - def metadata_hasattr(self, attr) -> bool: - return attr in self._metadata - - @cached_property - def indices(self): - return np.arange(self.num_samples, dtype=int) - - @cached_property - def _metadata(self) -> dict[str, np.ndarray]: - # logic to read metadata file here - metadata_npzs = [] - if self.config.get("metadata_path", None) is not None: - metadata_npzs.append( - np.load(self.config["metadata_path"], allow_pickle=True) - ) - - else: - for path in self.paths: - if path.is_file(): - metadata_file = path.parent / "metadata.npz" - else: - metadata_file = path / "metadata.npz" - if metadata_file.is_file(): - metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) - - if len(metadata_npzs) == 0: - logging.warning( - f"Could not find dataset metadata.npz files in '{self.paths}'" - ) - return {} - - metadata = { - field: np.concatenate([metadata[field] for metadata in metadata_npzs]) - for field in metadata_npzs[0] - } - - assert np.issubdtype( - metadata["natoms"].dtype, np.integer - ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}" - assert metadata["natoms"].shape[0] == len( - self - ), "Loaded metadata and dataset size mismatch." - - return metadata - - def get_metadata(self, attr, idx): - if attr in self._metadata: - metadata_attr = self._metadata[attr] - if isinstance(idx, list): - return [metadata_attr[_idx] for _idx in idx] - return metadata_attr[idx] - return None - - -class Subset(BaseDataset): - """A subset that also takes metadata if given.""" - - def __init__( - self, - dataset: BaseDataset, - indices: list[int], - metadata: dict[str, np.ndarray], - ) -> None: - super().__init__(dataset.config) - self.dataset = dataset - self.metadata = metadata - self.indices = indices - self.num_samples = len(indices) - self.config = dataset.config - - @cached_property - def _metadata(self) -> dict[str, np.ndarray]: - return self.dataset._metadata # pylint: disable=protected-access - - def get_metadata(self, attr, idx): - if isinstance(idx, list): - return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) - return self.dataset.get_metadata(attr, self.indices[idx]) - - -class LMDBDatabase(ase.db.core.Database): - """ - This module is modified from the ASE db json backend - and is thus licensed under the corresponding LGPL2.1 license. 
- - The ASE notice for the LGPL2.1 license is available here: - https://gitlab.com/ase/ase/-/blob/master/LICENSE - """ - - def __init__( # pylint: disable=keyword-arg-before-vararg - self, - filename: str | Path | None = None, - create_indices: bool = True, - use_lock_file: bool = False, - serial: bool = False, - readonly: bool = False, # Moved after *args to make it keyword-only - *args, - **kwargs, - ) -> None: - """ - For the most part, this is identical to the standard ase db initiation - arguments, except that we add a readonly flag. - """ - super().__init__( - Path(filename), - create_indices, - use_lock_file, - serial, - *args, - **kwargs, - ) - - # Add a readonly mode for when we're only training - # to make sure there's no parallel locks - self.readonly = readonly - - if self.readonly: - # Open a new env - self.env = lmdb.open( - str(self.filename), - subdir=False, - meminit=False, - map_async=True, - readonly=True, - lock=False, - ) - - # Open a transaction and keep it open for fast read/writes! - self.txn = self.env.begin(write=False) - - else: - # Open a new env with write access - self.env = lmdb.open( - str(self.filename), - map_size=1099511627776 * 2, - subdir=False, - meminit=False, - map_async=True, - ) - - self.txn = self.env.begin(write=True) - - # Load all ids based on keys in the DB. - self.ids = [] - self.deleted_ids = [] - self._load_ids() - - def __enter__(self) -> "LMDBDatabase": - return self - - def __exit__(self, exc_type, exc_value, tb) -> None: - self.close() - - def close(self) -> None: - # Close the lmdb environment and transaction - self.txn.commit() - self.env.close() - - def _write( - self, - atoms: ase.Atoms | ase.db.row.AtomsRow, - key_value_pairs: dict, - data: dict | None, - id: int | None = None, # pylint: disable=redefined-builtin - ) -> None: - # Call parent method with the original parameter name - super()._write(atoms, key_value_pairs, data) - - mtime = ase.db.core.now() - - if isinstance(atoms, ase.db.row.AtomsRow): - row = atoms - else: - row = ase.db.row.AtomsRow(atoms) - row.ctime = mtime - row.user = os.getenv("USER") - - dct = {} - for key in row.__dict__: - # Use getattr to avoid accessing protected member directly - if key[0] == "_" or key == "id" or key in getattr(row, "_keys", []): - continue - dct[key] = row[key] - - dct["mtime"] = mtime - - if key_value_pairs: - dct["key_value_pairs"] = key_value_pairs - - if data: - dct["data"] = data - - constraints = row.get("constraints") - if constraints: - dct["constraints"] = [constraint.todict() for constraint in constraints] - - # json doesn't like Cell objects, so make it an array - dct["cell"] = np.asarray(dct["cell"]) - - if id is None: - id = self._nextid - nextid = id + 1 - else: - data = self.txn.get(f"{id}".encode("ascii")) - assert data is not None - - # Add the new entry - self.txn.put( - f"{id}".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - # only append if idx is not in ids - if id not in self.ids: - self.ids.append(id) - self.txn.put( - "nextid".encode("ascii"), - zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - # check if id is in removed ids and remove accordingly - if id in self.deleted_ids: - self.deleted_ids.remove(id) - self._write_deleted_ids() - - return id - - def _update( - self, - idx: int, - key_value_pairs: dict | None = None, - data: dict | None = None, - ): - # hack this to play nicely with ASE code - row = self._get_row(idx, include_data=True) - if data is not None or key_value_pairs is 
not None: - self._write( - atoms=row, key_value_pairs=key_value_pairs, data=data, id=idx - ) # Fixed E1123 by using id=idx - - def _write_deleted_ids(self): - self.txn.put( - "deleted_ids".encode("ascii"), - zlib.compress( - orjson.dumps(self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY) - ), - ) - - def delete(self, ids: list[int]) -> None: - for idx in ids: - self.txn.delete(f"{idx}".encode("ascii")) - self.ids.remove(idx) - - self.deleted_ids += ids - self._write_deleted_ids() - - def _get_row(self, idx: int, include_data: bool = True): - if idx is None: - assert len(self.ids) == 1 - idx = self.ids[0] - data = self.txn.get(f"{idx}".encode("ascii")) - - if data is not None: - dct = orjson.loads(zlib.decompress(data)) - else: - raise KeyError(f"Id {idx} missing from the database!") - - if not include_data: - dct.pop("data", None) - - dct["id"] = idx - return ase.db.row.AtomsRow(dct) - - def _get_row_by_index(self, index: int, include_data: bool = True): - """Auxiliary function to get the ith entry, rather than a specific id""" - data = self.txn.get(f"{self.ids[index]}".encode("ascii")) - - if data is not None: - dct = orjson.loads(zlib.decompress(data)) - else: - raise KeyError(f"Id {id} missing from the database!") - - if not include_data: - dct.pop("data", None) - - dct["id"] = id - return ase.db.row.AtomsRow(dct) - - def _select( - self, - keys, - cmps: list[tuple[str, str, str]], - explain: bool = False, - _verbosity: int = 0, # Unused parameter marked with underscore - limit: int | None = None, - offset: int = 0, - sort: str | None = None, - include_data: bool = True, - _columns: str = "all", # Unused parameter marked with underscore - ): - if explain: - yield {"explain": (0, 0, 0, "scan table")} - return - - if sort is not None: - if sort[0] == "-": - reverse = True - sort = sort[1:] - else: - reverse = False - - rows = [] - missing = [] - for row in self._select(keys, cmps): - key = row.get(sort) - if key is None: - missing.append((0, row)) - else: - rows.append((key, row)) - - rows.sort(reverse=reverse, key=lambda x: x[0]) - rows += missing - - if limit: - rows = rows[offset : offset + limit] - for _, row in rows: - yield row - return - - if not limit: - limit = -offset - 1 - - cmps = [(key, ase.db.core.ops[op], val) for key, op, val in cmps] - n = 0 - for idx in self.ids: - if n - offset == limit: - return - row = self._get_row(idx, include_data=include_data) - - for key in keys: - if key not in row: - break - else: - for key, op, val in cmps: - if isinstance(key, int): - value = np.equal(row.numbers, key).sum() - else: - value = row.get(key) - if key == "pbc": - assert op in [ase.db.core.ops["="], ase.db.core.ops["!="]] - value = "".join("FT"[x] for x in value) - if value is None or not op(value, val): - break - else: - if n >= offset: - yield row - n += 1 - - @property - def metadata(self): - """Override abstract metadata method from Database class.""" - return self.db_metadata - - @property - def db_metadata(self): - """Load the metadata from the DB if present""" - if self._metadata is None: - metadata = self.txn.get("metadata".encode("ascii")) - if metadata is None: - self._metadata = {} - else: - self._metadata = orjson.loads(zlib.decompress(metadata)) - - return self._metadata.copy() - - @db_metadata.setter - def db_metadata(self, dct): - self._metadata = dct - - # Put the updated metadata dictionary - self.txn.put( - "metadata".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - @property - def _nextid(self): - """Get the id of 
the next row to be written""" - # Get the nextid - nextid_data = self.txn.get("nextid".encode("ascii")) - if nextid_data: - return orjson.loads(zlib.decompress(nextid_data)) - return 1 # Removed unnecessary else (R1705) - - def count(self, selection=None, **kwargs) -> int: - """Count rows. - - See the select() method for the selection syntax. Use db.count() or - len(db) to count all rows. - """ - if selection is not None: - n = 0 - for _row in self.select(selection, **kwargs): - n += 1 - return n - return len(self.ids) - - def _load_ids(self) -> None: - """Load ids from the DB - - Since ASE db ids are mostly 1-N integers, but can be missing entries - if ids have been deleted. To save space and operating under the assumption - that there will probably not be many deletions in most OCP datasets, - we just store the deleted ids. - """ - # Load the deleted ids - deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) - if deleted_ids_data is not None: - self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) - - # Reconstruct the full id list - self.ids = [i for i in range(1, self._nextid) if i not in set(self.deleted_ids)] - - -# Placeholder for AtomsToGraphs class -# This is a minimal implementation without the full functionality -class AtomsToGraphs: - """Enhanced AtomsToGraphs implementation with proper property handling.""" - - def __init__( - self, - r_edges=False, - r_pbc=True, - r_energy=False, - r_forces=False, - r_stress=False, - r_data_keys=None, - **kwargs, - ): - self.r_edges = r_edges - self.r_pbc = r_pbc - self.r_energy = r_energy - self.r_forces = r_forces - self.r_stress = r_stress - self.r_data_keys = r_data_keys or {} - self.kwargs = kwargs - - def convert(self, atoms, sid=None): - """ - Convert ASE atoms to graph data format with proper property handling. 
- """ - from mace.tools.torch_geometric.data import Data - - # Create a minimal data object with required properties - data = Data() - - # Set positions - data.pos = torch.tensor(atoms.get_positions(), dtype=torch.float) - - # Set atomic numbers - data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long) - - # Set cell if available - if atoms.cell is not None: - data.cell = torch.tensor(atoms.get_cell(), dtype=torch.float) - - # Set PBC if requested - if self.r_pbc: - data.pbc = torch.tensor(atoms.get_pbc(), dtype=torch.bool) - - # Set energy if requested - if self.r_energy: - energy = self._get_property(atoms, "energy") - if energy is not None: - data.energy = torch.tensor(energy, dtype=torch.float) - - # Set forces if requested - if self.r_forces: - forces = self._get_property(atoms, "forces") - if forces is not None: - data.forces = torch.tensor(forces, dtype=torch.float) - - # Set stress if requested - if self.r_stress: - stress = self._get_property(atoms, "stress") - if stress is not None: - data.stress = torch.tensor(stress, dtype=torch.float) - - # Set sid if provided - if sid is not None: - data.sid = sid - - return data - - def _get_property(self, atoms, prop_name): - """Get property from atoms, checking custom names first then standard methods.""" - # Check if we have a custom name for this property - custom_name = self.r_data_keys.get(prop_name) - - # Try custom name in info dict - if custom_name and custom_name in atoms.info: - return atoms.info[custom_name] - - # Try custom name in arrays dict - if custom_name and custom_name in atoms.arrays: - return atoms.arrays[custom_name] - - # Try standard name in info dict - if prop_name in atoms.info: - return atoms.info[prop_name] - - # Try standard name in arrays dict - if prop_name in atoms.arrays: - return atoms.arrays[prop_name] - - # Try standard ASE methods - method_map = { - "energy": "get_potential_energy", - "forces": "get_forces", - "stress": "get_stress", - } - - if prop_name in method_map and hasattr(atoms, method_map[prop_name]): - try: - method = getattr(atoms, method_map[prop_name]) - return method() - except ( - AttributeError, - RuntimeError, - ) as exc: # Fixed W0718 by specifying exceptions - logging.debug(f"Error getting property {prop_name}: {exc}") - # Removed unnecessary pass (W0107) - - return None - - -# Placeholder for DataTransforms class -class DataTransforms: - """Minimal implementation of DataTransforms to satisfy dependencies.""" - - def __init__(self, transforms_config=None): - self.transforms_config = transforms_config or {} - - def __call__(self, data): - """Apply transforms to data""" - # No transforms applied in this minimal implementation - return data - - -class AseAtomsDataset(BaseDataset, ABC): - """ - This is an abstract Dataset that includes helpful utilities for turning - ASE atoms objects into OCP-usable data objects. This should not be instantiated directly - as get_atoms_object and load_dataset_get_ids are not implemented in this base class. - - Derived classes must add at least two things: - self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object - - self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads - of the dataset and importantly must return a list of all possible identifiers that can be passed into - self.get_atoms_object(id) - - Identifiers need not be any particular type. 
- """ - - def __init__( - self, - config: dict, - atoms_transform: Callable[[ase.Atoms, Any], ase.Atoms] = apply_one_tags, - ) -> None: - super().__init__(config) - - a2g_args = config.get("a2g_args", {}) or {} - - # set default to False if not set by user, assuming otf_graph will be used - if "r_edges" not in a2g_args: - a2g_args["r_edges"] = False - - # Make sure we always include PBC info in the resulting atoms objects - a2g_args["r_pbc"] = True - self.a2g = AtomsToGraphs(**a2g_args) - - self.key_mapping = self.config.get("key_mapping", None) - self.transforms = DataTransforms(self.config.get("transforms", {})) - - self.atoms_transform = atoms_transform - - if self.config.get("keep_in_memory", False): - self.__getitem__ = cache(self.__getitem__) - - self.ids = self._load_dataset_get_ids(config) - self.num_samples = len(self.ids) - - if len(self.ids) == 0: - raise ValueError( - rf"No valid ase data found! \n" - f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" - ) - - def __getitem__(self, idx): # pylint: disable=method-hidden - # Handle slicing - if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(len(self)))] - - # Get atoms object via derived class method - atoms = self.get_atoms(self.ids[idx]) - - # Transform atoms object - if self.atoms_transform is not None: - atoms = self.atoms_transform( - atoms, **self.config.get("atoms_transform_args", {}) - ) - - sid = atoms.info.get("sid", self.ids[idx]) - fid = atoms.info.get("fid", torch.tensor([0])) - - # Convert to data object - data_object = self.a2g.convert(atoms, sid) - data_object.fid = fid - data_object.natoms = len(atoms) - - # apply linear reference - if self.a2g.r_energy is True and self.lin_ref is not None: - data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()]) - - # Transform data object - data_object = self.transforms(data_object) - - if self.key_mapping is not None: - data_object = rename_data_object_keys(data_object, self.key_mapping) - - if self.config.get("include_relaxed_energy", False): - data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx]) - - return data_object - - @abstractmethod - def get_atoms(self, idx: str | int) -> ase.Atoms: - # This function should return an ASE atoms object. - raise NotImplementedError( - "Returns an ASE atoms object. Derived classes should implement this function." - ) - - @abstractmethod - def _load_dataset_get_ids(self, config): - # This function should return a list of ids that can be used to index into the database - raise NotImplementedError( - "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." - ) - - def get_relaxed_energy(self, identifier): - raise NotImplementedError( - "Reading relaxed energy from trajectory or file is not implemented with this dataset. " - "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in " - "the r_data_keys argument under a2g_args." 
- ) - - def get_metadata(self, attr, idx): - # try the parent method - metadata = super().get_metadata(attr, idx) - if metadata is not None: - return metadata - # try to resolve it here - if attr != "natoms": - return None - if isinstance(idx, (list, np.ndarray)): - return np.array([self.get_metadata(attr, i) for i in idx]) - return len(self.get_atoms(idx)) - - -class AseDBDataset(AseAtomsDataset): - """ - This Dataset connects to an ASE Database, allowing the storage of atoms objects - with a variety of backends including JSON, SQLite, and database server options. - """ - - def _load_dataset_get_ids(self, config: dict) -> list[int]: - if isinstance(config["src"], list): - filepaths = [] - for path in sorted(config["src"]): - if os.path.isdir(path): - filepaths.extend(sorted(glob(f"{path}/*"))) - elif os.path.isfile(path): - filepaths.append(path) - else: - raise RuntimeError(f"Error reading dataset in {path}!") - elif os.path.isfile(config["src"]): - filepaths = [config["src"]] - elif os.path.isdir(config["src"]): - filepaths = sorted(glob(f'{config["src"]}/*')) - else: - filepaths = sorted(glob(config["src"])) - - self.dbs = [] - - for path in filepaths: - try: - self.dbs.append(self.connect_db(path, config.get("connect_args", {}))) - except ValueError: - logging.debug( - f"Tried to connect to {path} but it's not an ASE database!" - ) - - self.select_args = config.get("select_args", {}) - if self.select_args is None: - self.select_args = {} - - # Get all unique IDs from the databases - self.db_ids = [] - for db in self.dbs: - if hasattr(db, "ids") and self.select_args == {}: - self.db_ids.append(db.ids) - else: - # this is the slow alternative - self.db_ids.append([row.id for row in db.select(**self.select_args)]) - - idlens = [len(ids) for ids in self.db_ids] - self._idlen_cumulative = np.cumsum(idlens).tolist() - - return list(range(sum(idlens))) - - def get_atoms(self, idx: int) -> ase.Atoms: - """Get atoms object corresponding to datapoint idx. 
- Args: - idx (int): index in dataset - - Returns: - atoms: ASE atoms corresponding to datapoint idx - """ - # Figure out which db this should be indexed from - db_idx = bisect.bisect(self._idlen_cumulative, idx) - - # Extract index of element within that db - el_idx = idx - if db_idx != 0: - el_idx = idx - self._idlen_cumulative[db_idx - 1] - assert el_idx >= 0 - - # Use a wrapper method to avoid protected access warning - atoms_row = self.get_row_from_db(db_idx, el_idx) - - # Convert to atoms object - atoms = atoms_row.toatoms() - - # Put data back into atoms info - if isinstance(atoms_row.data, dict): - atoms.info.update(atoms_row.data) - - # Add key-value pairs directly to atoms.info - if hasattr(atoms_row, "key_value_pairs") and atoms_row.key_value_pairs: - atoms.info.update(atoms_row.key_value_pairs) - - # Create a SinglePointCalculator to attach energy, forces and stress to atoms - calc_kwargs = {} - - # Check for energy, forces, stress in atoms_row and store in info & calc_kwargs - for prop in ["energy", "forces", "stress", "free_energy"]: - if hasattr(atoms_row, prop) and getattr(atoms_row, prop) is not None: - value = getattr(atoms_row, prop) - calc_kwargs[prop] = value - atoms.info[prop] = value - - # If we have custom data mappings, copy the standard properties to the custom names - a2g_args = self.config.get("a2g_args", {}) or {} - r_data_keys = a2g_args.get("r_data_keys", {}) - if r_data_keys: - # Map from standard names to custom names (in reverse of how they'll be used) - for custom_key, standard_key in r_data_keys.items(): - if standard_key in atoms.info: - atoms.info[custom_key] = atoms.info[standard_key] - elif standard_key in atoms.arrays: - atoms.arrays[custom_key] = atoms.arrays[standard_key] - - # Create calculator if we have any properties - if calc_kwargs: - from ase.calculators.singlepoint import SinglePointCalculator - - calc = SinglePointCalculator(atoms, **calc_kwargs) - atoms.calc = calc - - return atoms - - def get_row_from_db(self, db_idx, el_idx): - """Get a row from the database at the given indices.""" - db = self.dbs[db_idx] - row_id = self.db_ids[db_idx][el_idx] - if isinstance(db, LMDBDatabase): - return db._get_row(row_id) # pylint: disable=protected-access - return db.get(row_id) - - @staticmethod - def connect_db( - address: str | Path, connect_args: dict | None = None - ) -> ase.db.core.Database: - if connect_args is None: - connect_args = {} - db_type = connect_args.get("type", "extract_from_name") - if db_type in ("lmdb", "aselmdb") or ( - db_type == "extract_from_name" - and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") - ): - return LMDBDatabase(address, readonly=True, **connect_args) - - return ase.db.connect(address, **connect_args) - - def __del__(self): - for db in self.dbs: - if hasattr(db, "close"): - db.close() - - def sample_property_metadata( - self, - ) -> dict: # Removed unused argument num_samples (W0613) - """ - Sample property metadata from the database. - - This method was previously using the copy module which is now removed. - """ - logging.warning( - "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" - ) - if self.dbs[0].metadata == {}: - return {} - - # Fixed unnecessary comprehension (R1721) - return dict(self.dbs[0].metadata.items()) +""" +This module contains the AseDBDataset class and its dependencies. +It is extracted from the fairchem codebase and adapted to remove dependencies on fairchem. + +Original code copyright: +Copyright (c) Meta, Inc. 
and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import bisect +import logging +import os +import zlib +from abc import ABC, abstractmethod + +try: + from functools import cache, cached_property +except ImportError: + from functools import cached_property, lru_cache + + cache = lru_cache(maxsize=None) +from glob import glob +from pathlib import Path +from typing import Any, Callable, TypeVar + +import ase +import ase.db.core +import ase.db.row +import ase.io +import lmdb +import numpy as np +import orjson +import torch + +# Type variable for generic dataset return type +T_co = TypeVar("T_co", covariant=True) + + +def rename_data_object_keys(data_object, key_mapping: dict[str, str | list[str]]): + """Rename data object keys + + Args: + data_object: data object + key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} + + new_key can be a list of new keys, for example, + prev_key: energy + new_key: [common_energy, oc20_energy] + + This is currently required when we use a single target/label for multiple tasks + """ + for _property in key_mapping: + # catch for test data not containing labels + if _property in data_object: + list_of_new_keys = key_mapping[_property] + if isinstance(list_of_new_keys, str): + list_of_new_keys = [list_of_new_keys] + for new_property in list_of_new_keys: + if new_property == _property: + continue + assert new_property not in data_object + data_object[new_property] = data_object[_property] + if _property not in list_of_new_keys: + del data_object[_property] + return data_object + + +def apply_one_tags( + atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False +): + """ + This function will apply tags of 1 to an ASE atoms object. + It is used as an atoms_transform in the datasets contained in this file. + + Certain models will treat atoms differently depending on their tags. + For example, GemNet-OC by default will only compute triplet and quadruplet interactions + for atoms with non-zero tags. This model throws an error if there are no tagged atoms. + For this reason, the default behavior is to tag atoms in structures with no tags. + + args: + skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms + + skip_always (bool): Do not apply any tags. 
This arg exists so that this function can be disabled + without needing to pass a callable (which is currently difficult to do with main.py) + """ + if skip_always: + return atoms + + if np.all(atoms.get_tags() == 0) or not skip_if_nonzero: + atoms.set_tags(np.ones(len(atoms))) + + return atoms + + +class UnsupportedDatasetError(ValueError): + pass + + +class BaseDataset(ABC): + """Base Dataset class for all ASE datasets.""" + + def __init__(self, config: dict): + """Initialize + + Args: + config (dict): dataset configuration + """ + self.config = config + self.paths = [] + + if "src" in self.config: + if isinstance(config["src"], str): + self.paths = [Path(self.config["src"])] + else: + self.paths = tuple(Path(path) for path in sorted(config["src"])) + + self.lin_ref = None + if self.config.get("lin_ref", False): + lin_ref = torch.tensor( + np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] + ) + self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) + + def __len__(self) -> int: + return self.num_samples + + def metadata_hasattr(self, attr) -> bool: + return attr in self._metadata + + @cached_property + def indices(self): + return np.arange(self.num_samples, dtype=int) + + @cached_property + def _metadata(self) -> dict[str, np.ndarray]: + # logic to read metadata file here + metadata_npzs = [] + if self.config.get("metadata_path", None) is not None: + metadata_npzs.append( + np.load(self.config["metadata_path"], allow_pickle=True) + ) + + else: + for path in self.paths: + if path.is_file(): + metadata_file = path.parent / "metadata.npz" + else: + metadata_file = path / "metadata.npz" + if metadata_file.is_file(): + metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) + + if len(metadata_npzs) == 0: + logging.warning( + f"Could not find dataset metadata.npz files in '{self.paths}'" + ) + return {} + + metadata = { + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in metadata_npzs[0] + } + + assert np.issubdtype( + metadata["natoms"].dtype, np.integer + ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}" + assert metadata["natoms"].shape[0] == len( + self + ), "Loaded metadata and dataset size mismatch." + + return metadata + + def get_metadata(self, attr, idx): + if attr in self._metadata: + metadata_attr = self._metadata[attr] + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + return None + + +class Subset(BaseDataset): + """A subset that also takes metadata if given.""" + + def __init__( + self, + dataset: BaseDataset, + indices: list[int], + metadata: dict[str, np.ndarray], + ) -> None: + super().__init__(dataset.config) + self.dataset = dataset + self.metadata = metadata + self.indices = indices + self.num_samples = len(indices) + self.config = dataset.config + + @cached_property + def _metadata(self) -> dict[str, np.ndarray]: + return self.dataset._metadata # pylint: disable=protected-access + + def get_metadata(self, attr, idx): + if isinstance(idx, list): + return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) + return self.dataset.get_metadata(attr, self.indices[idx]) + + +class LMDBDatabase(ase.db.core.Database): + """ + This module is modified from the ASE db json backend + and is thus licensed under the corresponding LGPL2.1 license. 
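+
+    Relative to the standard backend, a readonly flag is added so that
+    training-time readers can open the environment without taking LMDB
+    write locks (see __init__ below).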
+
+    The ASE notice for the LGPL2.1 license is available here:
+    https://gitlab.com/ase/ase/-/blob/master/LICENSE
+    """
+
+    def __init__(  # pylint: disable=keyword-arg-before-vararg
+        self,
+        filename: str | Path | None = None,
+        create_indices: bool = True,
+        use_lock_file: bool = False,
+        serial: bool = False,
+        readonly: bool = False,  # Extra flag relative to the standard ase db signature
+        *args,
+        **kwargs,
+    ) -> None:
+        """
+        For the most part, this is identical to the standard ase db initiation
+        arguments, except that we add a readonly flag.
+        """
+        super().__init__(
+            Path(filename),
+            create_indices,
+            use_lock_file,
+            serial,
+            *args,
+            **kwargs,
+        )
+
+        # Add a readonly mode for when we're only training
+        # to make sure there's no parallel locks
+        self.readonly = readonly
+
+        if self.readonly:
+            # Open a new env
+            self.env = lmdb.open(
+                str(self.filename),
+                subdir=False,
+                meminit=False,
+                map_async=True,
+                readonly=True,
+                lock=False,
+            )
+
+            # Open a transaction and keep it open for fast read/writes!
+            self.txn = self.env.begin(write=False)
+
+        else:
+            # Open a new env with write access
+            self.env = lmdb.open(
+                str(self.filename),
+                map_size=1099511627776 * 2,
+                subdir=False,
+                meminit=False,
+                map_async=True,
+            )
+
+            self.txn = self.env.begin(write=True)
+
+        # Load all ids based on keys in the DB.
+        self.ids = []
+        self.deleted_ids = []
+        self._load_ids()
+
+    def __enter__(self) -> "LMDBDatabase":
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb) -> None:
+        self.close()
+
+    def close(self) -> None:
+        # Close the lmdb environment and transaction
+        self.txn.commit()
+        self.env.close()
+
+    def _write(
+        self,
+        atoms: ase.Atoms | ase.db.row.AtomsRow,
+        key_value_pairs: dict,
+        data: dict | None,
+        id: int | None = None,  # pylint: disable=redefined-builtin
+    ) -> int:
+        # Call parent method with the original parameter name
+        super()._write(atoms, key_value_pairs, data)
+
+        mtime = ase.db.core.now()
+
+        if isinstance(atoms, ase.db.row.AtomsRow):
+            row = atoms
+        else:
+            row = ase.db.row.AtomsRow(atoms)
+            row.ctime = mtime
+            row.user = os.getenv("USER")
+
+        dct = {}
+        for key in row.__dict__:
+            # Use getattr to avoid accessing protected member directly
+            if key[0] == "_" or key == "id" or key in getattr(row, "_keys", []):
+                continue
+            dct[key] = row[key]
+
+        dct["mtime"] = mtime
+
+        if key_value_pairs:
+            dct["key_value_pairs"] = key_value_pairs
+
+        if data:
+            dct["data"] = data
+
+        constraints = row.get("constraints")
+        if constraints:
+            dct["constraints"] = [constraint.todict() for constraint in constraints]
+
+        # json doesn't like Cell objects, so make it an array
+        dct["cell"] = np.asarray(dct["cell"])
+
+        if id is None:
+            id = self._nextid
+            nextid = id + 1
+        else:
+            data = self.txn.get(f"{id}".encode("ascii"))
+            assert data is not None
+            # Keep the stored nextid ahead of explicitly supplied ids
+            # (nextid is written below when the id is new).
+            nextid = max(id + 1, self._nextid)
+
+        # Add the new entry
+        self.txn.put(
+            f"{id}".encode("ascii"),
+            zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)),
+        )
+        # only append if idx is not in ids
+        if id not in self.ids:
+            self.ids.append(id)
+            self.txn.put(
+                "nextid".encode("ascii"),
+                zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)),
+            )
+        # check if id is in removed ids and remove accordingly
+        if id in self.deleted_ids:
+            self.deleted_ids.remove(id)
+            self._write_deleted_ids()
+
+        return id
+
+    def _update(
+        self,
+        idx: int,
+        key_value_pairs: dict | None = None,
+        data: dict | None = None,
+    ):
+        # hack this to play nicely with ASE code
+        row = self._get_row(idx, include_data=True)
+        if data is not None or key_value_pairs is
not None:
+            self._write(
+                atoms=row, key_value_pairs=key_value_pairs, data=data, id=idx
+            )  # Fixed E1123 by using id=idx
+
+    def _write_deleted_ids(self):
+        self.txn.put(
+            "deleted_ids".encode("ascii"),
+            zlib.compress(
+                orjson.dumps(self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY)
+            ),
+        )
+
+    def delete(self, ids: list[int]) -> None:
+        for idx in ids:
+            self.txn.delete(f"{idx}".encode("ascii"))
+            self.ids.remove(idx)
+
+        self.deleted_ids += ids
+        self._write_deleted_ids()
+
+    def _get_row(self, idx: int, include_data: bool = True):
+        if idx is None:
+            assert len(self.ids) == 1
+            idx = self.ids[0]
+        data = self.txn.get(f"{idx}".encode("ascii"))
+
+        if data is not None:
+            dct = orjson.loads(zlib.decompress(data))
+        else:
+            raise KeyError(f"Id {idx} missing from the database!")
+
+        if not include_data:
+            dct.pop("data", None)
+
+        dct["id"] = idx
+        return ase.db.row.AtomsRow(dct)
+
+    def _get_row_by_index(self, index: int, include_data: bool = True):
+        """Auxiliary function to get the ith entry, rather than a specific id"""
+        row_id = self.ids[index]
+        data = self.txn.get(f"{row_id}".encode("ascii"))
+
+        if data is not None:
+            dct = orjson.loads(zlib.decompress(data))
+        else:
+            raise KeyError(f"Id {row_id} missing from the database!")
+
+        if not include_data:
+            dct.pop("data", None)
+
+        dct["id"] = row_id
+        return ase.db.row.AtomsRow(dct)
+
+    def _select(
+        self,
+        keys,
+        cmps: list[tuple[str, str, str]],
+        explain: bool = False,
+        _verbosity: int = 0,  # Unused parameter marked with underscore
+        limit: int | None = None,
+        offset: int = 0,
+        sort: str | None = None,
+        include_data: bool = True,
+        _columns: str = "all",  # Unused parameter marked with underscore
+    ):
+        if explain:
+            yield {"explain": (0, 0, 0, "scan table")}
+            return
+
+        if sort is not None:
+            if sort[0] == "-":
+                reverse = True
+                sort = sort[1:]
+            else:
+                reverse = False
+
+            rows = []
+            missing = []
+            for row in self._select(keys, cmps):
+                key = row.get(sort)
+                if key is None:
+                    missing.append((0, row))
+                else:
+                    rows.append((key, row))
+
+            rows.sort(reverse=reverse, key=lambda x: x[0])
+            rows += missing
+
+            if limit:
+                rows = rows[offset : offset + limit]
+            for _, row in rows:
+                yield row
+            return
+
+        if not limit:
+            limit = -offset - 1
+
+        cmps = [(key, ase.db.core.ops[op], val) for key, op, val in cmps]
+        n = 0
+        for idx in self.ids:
+            if n - offset == limit:
+                return
+            row = self._get_row(idx, include_data=include_data)
+
+            for key in keys:
+                if key not in row:
+                    break
+            else:
+                for key, op, val in cmps:
+                    if isinstance(key, int):
+                        value = np.equal(row.numbers, key).sum()
+                    else:
+                        value = row.get(key)
+                    if key == "pbc":
+                        assert op in [ase.db.core.ops["="], ase.db.core.ops["!="]]
+                        value = "".join("FT"[x] for x in value)
+                    if value is None or not op(value, val):
+                        break
+                else:
+                    if n >= offset:
+                        yield row
+                    n += 1
+
+    @property
+    def metadata(self):
+        """Override abstract metadata method from Database class."""
+        return self.db_metadata
+
+    @property
+    def db_metadata(self):
+        """Load the metadata from the DB if present"""
+        if self._metadata is None:
+            metadata = self.txn.get("metadata".encode("ascii"))
+            if metadata is None:
+                self._metadata = {}
+            else:
+                self._metadata = orjson.loads(zlib.decompress(metadata))
+
+        return self._metadata.copy()
+
+    @db_metadata.setter
+    def db_metadata(self, dct):
+        self._metadata = dct
+
+        # Put the updated metadata dictionary
+        self.txn.put(
+            "metadata".encode("ascii"),
+            zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)),
+        )
+
+    @property
+    def _nextid(self):
+        """Get the id of
the next row to be written"""
+        # Get the nextid
+        nextid_data = self.txn.get("nextid".encode("ascii"))
+        if nextid_data:
+            return orjson.loads(zlib.decompress(nextid_data))
+        return 1  # Removed unnecessary else (R1705)
+
+    def count(self, selection=None, **kwargs) -> int:
+        """Count rows.
+
+        See the select() method for the selection syntax. Use db.count() or
+        len(db) to count all rows.
+        """
+        if selection is not None:
+            n = 0
+            for _row in self.select(selection, **kwargs):
+                n += 1
+            return n
+        return len(self.ids)
+
+    def _load_ids(self) -> None:
+        """Load ids from the DB
+
+        ASE db ids are mostly contiguous 1-N integers, but entries can be
+        missing if ids have been deleted. To save space, and operating under
+        the assumption that there will probably not be many deletions in most
+        OCP datasets, we just store the deleted ids.
+        """
+        # Load the deleted ids
+        deleted_ids_data = self.txn.get("deleted_ids".encode("ascii"))
+        if deleted_ids_data is not None:
+            self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data))
+
+        # Reconstruct the full id list
+        deleted = set(self.deleted_ids)
+        self.ids = [i for i in range(1, self._nextid) if i not in deleted]
+
+
+# Placeholder for AtomsToGraphs class
+# This is a minimal implementation without the full functionality
+class AtomsToGraphs:
+    """Enhanced AtomsToGraphs implementation with proper property handling."""
+
+    def __init__(
+        self,
+        r_edges=False,
+        r_pbc=True,
+        r_energy=False,
+        r_forces=False,
+        r_stress=False,
+        r_data_keys=None,
+        **kwargs,
+    ):
+        self.r_edges = r_edges
+        self.r_pbc = r_pbc
+        self.r_energy = r_energy
+        self.r_forces = r_forces
+        self.r_stress = r_stress
+        self.r_data_keys = r_data_keys or {}
+        self.kwargs = kwargs
+
+    def convert(self, atoms, sid=None):
+        """
+        Convert ASE atoms to graph data format with proper property handling.
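+
+        Property lookup order (see _get_property below): the custom key from
+        r_data_keys, then atoms.info, then atoms.arrays, then the standard
+        ASE getter methods.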
+ """ + from mace.tools.torch_geometric.data import Data + + # Create a minimal data object with required properties + data = Data() + + # Set positions + data.pos = torch.tensor(atoms.get_positions(), dtype=torch.float) + + # Set atomic numbers + data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long) + + # Set cell if available + if atoms.cell is not None: + data.cell = torch.tensor(atoms.get_cell(), dtype=torch.float) + + # Set PBC if requested + if self.r_pbc: + data.pbc = torch.tensor(atoms.get_pbc(), dtype=torch.bool) + + # Set energy if requested + if self.r_energy: + energy = self._get_property(atoms, "energy") + if energy is not None: + data.energy = torch.tensor(energy, dtype=torch.float) + + # Set forces if requested + if self.r_forces: + forces = self._get_property(atoms, "forces") + if forces is not None: + data.forces = torch.tensor(forces, dtype=torch.float) + + # Set stress if requested + if self.r_stress: + stress = self._get_property(atoms, "stress") + if stress is not None: + data.stress = torch.tensor(stress, dtype=torch.float) + + # Set sid if provided + if sid is not None: + data.sid = sid + + return data + + def _get_property(self, atoms, prop_name): + """Get property from atoms, checking custom names first then standard methods.""" + # Check if we have a custom name for this property + custom_name = self.r_data_keys.get(prop_name) + + # Try custom name in info dict + if custom_name and custom_name in atoms.info: + return atoms.info[custom_name] + + # Try custom name in arrays dict + if custom_name and custom_name in atoms.arrays: + return atoms.arrays[custom_name] + + # Try standard name in info dict + if prop_name in atoms.info: + return atoms.info[prop_name] + + # Try standard name in arrays dict + if prop_name in atoms.arrays: + return atoms.arrays[prop_name] + + # Try standard ASE methods + method_map = { + "energy": "get_potential_energy", + "forces": "get_forces", + "stress": "get_stress", + } + + if prop_name in method_map and hasattr(atoms, method_map[prop_name]): + try: + method = getattr(atoms, method_map[prop_name]) + return method() + except ( + AttributeError, + RuntimeError, + ) as exc: # Fixed W0718 by specifying exceptions + logging.debug(f"Error getting property {prop_name}: {exc}") + # Removed unnecessary pass (W0107) + + return None + + +# Placeholder for DataTransforms class +class DataTransforms: + """Minimal implementation of DataTransforms to satisfy dependencies.""" + + def __init__(self, transforms_config=None): + self.transforms_config = transforms_config or {} + + def __call__(self, data): + """Apply transforms to data""" + # No transforms applied in this minimal implementation + return data + + +class AseAtomsDataset(BaseDataset, ABC): + """ + This is an abstract Dataset that includes helpful utilities for turning + ASE atoms objects into OCP-usable data objects. This should not be instantiated directly + as get_atoms_object and load_dataset_get_ids are not implemented in this base class. + + Derived classes must add at least two things: + self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object + + self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads + of the dataset and importantly must return a list of all possible identifiers that can be passed into + self.get_atoms_object(id) + + Identifiers need not be any particular type. 
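+
+    A minimal config needs only a "src" entry pointing at the data; optional
+    keys handled here include "a2g_args", "key_mapping", "transforms",
+    "atoms_transform_args", "keep_in_memory", and "include_relaxed_energy".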
+
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        atoms_transform: Callable[[ase.Atoms, Any], ase.Atoms] = apply_one_tags,
+    ) -> None:
+        super().__init__(config)
+
+        a2g_args = config.get("a2g_args", {}) or {}
+
+        # set default to False if not set by user, assuming otf_graph will be used
+        if "r_edges" not in a2g_args:
+            a2g_args["r_edges"] = False
+
+        # Make sure we always include PBC info in the resulting atoms objects
+        a2g_args["r_pbc"] = True
+        self.a2g = AtomsToGraphs(**a2g_args)
+
+        self.key_mapping = self.config.get("key_mapping", None)
+        self.transforms = DataTransforms(self.config.get("transforms", {}))
+
+        self.atoms_transform = atoms_transform
+
+        if self.config.get("keep_in_memory", False):
+            self.__getitem__ = cache(self.__getitem__)
+
+        self.ids = self._load_dataset_get_ids(config)
+        self.num_samples = len(self.ids)
+
+        if len(self.ids) == 0:
+            raise ValueError(
+                f"No valid ase data found!\n"
+                f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}"
+            )
+
+    def __getitem__(self, idx):  # pylint: disable=method-hidden
+        # Handle slicing
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+
+        # Get atoms object via derived class method
+        atoms = self.get_atoms(self.ids[idx])
+
+        # Transform atoms object
+        if self.atoms_transform is not None:
+            atoms = self.atoms_transform(
+                atoms, **self.config.get("atoms_transform_args", {})
+            )
+
+        sid = atoms.info.get("sid", self.ids[idx])
+        fid = atoms.info.get("fid", torch.tensor([0]))
+
+        # Convert to data object
+        data_object = self.a2g.convert(atoms, sid)
+        data_object.fid = fid
+        data_object.natoms = len(atoms)
+
+        # apply linear reference
+        if self.a2g.r_energy is True and self.lin_ref is not None:
+            data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()])
+
+        # Transform data object
+        data_object = self.transforms(data_object)
+
+        if self.key_mapping is not None:
+            data_object = rename_data_object_keys(data_object, self.key_mapping)
+
+        if self.config.get("include_relaxed_energy", False):
+            data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx])
+
+        return data_object
+
+    @abstractmethod
+    def get_atoms(self, idx: str | int) -> ase.Atoms:
+        # This function should return an ASE atoms object.
+        raise NotImplementedError(
+            "Returns an ASE atoms object. Derived classes should implement this function."
+        )
+
+    @abstractmethod
+    def _load_dataset_get_ids(self, config):
+        # This function should return a list of ids that can be used to index into the database
+        raise NotImplementedError(
+            "Every ASE dataset needs to declare a function to load the dataset and return a list of ids."
+        )
+
+    def get_relaxed_energy(self, identifier):
+        raise NotImplementedError(
+            "Reading relaxed energy from trajectory or file is not implemented with this dataset. "
+            "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in "
+            "the r_data_keys argument under a2g_args."
+ ) + + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + + +class AseDBDataset(AseAtomsDataset): + """ + This Dataset connects to an ASE Database, allowing the storage of atoms objects + with a variety of backends including JSON, SQLite, and database server options. + """ + + def _load_dataset_get_ids(self, config: dict) -> list[int]: + if isinstance(config["src"], list): + filepaths = [] + for path in sorted(config["src"]): + if os.path.isdir(path): + filepaths.extend(sorted(glob(f"{path}/*"))) + elif os.path.isfile(path): + filepaths.append(path) + else: + raise RuntimeError(f"Error reading dataset in {path}!") + elif os.path.isfile(config["src"]): + filepaths = [config["src"]] + elif os.path.isdir(config["src"]): + filepaths = sorted(glob(f'{config["src"]}/*')) + else: + filepaths = sorted(glob(config["src"])) + + self.dbs = [] + + for path in filepaths: + try: + self.dbs.append(self.connect_db(path, config.get("connect_args", {}))) + except ValueError: + logging.debug( + f"Tried to connect to {path} but it's not an ASE database!" + ) + + self.select_args = config.get("select_args", {}) + if self.select_args is None: + self.select_args = {} + + # Get all unique IDs from the databases + self.db_ids = [] + for db in self.dbs: + if hasattr(db, "ids") and self.select_args == {}: + self.db_ids.append(db.ids) + else: + # this is the slow alternative + self.db_ids.append([row.id for row in db.select(**self.select_args)]) + + idlens = [len(ids) for ids in self.db_ids] + self._idlen_cumulative = np.cumsum(idlens).tolist() + + return list(range(sum(idlens))) + + def get_atoms(self, idx: int) -> ase.Atoms: + """Get atoms object corresponding to datapoint idx. 
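+
+        The row's data and key_value_pairs are merged into atoms.info, and any
+        stored energy, forces, stress, or free_energy are re-attached via a
+        SinglePointCalculator.
+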
+ Args: + idx (int): index in dataset + + Returns: + atoms: ASE atoms corresponding to datapoint idx + """ + # Figure out which db this should be indexed from + db_idx = bisect.bisect(self._idlen_cumulative, idx) + + # Extract index of element within that db + el_idx = idx + if db_idx != 0: + el_idx = idx - self._idlen_cumulative[db_idx - 1] + assert el_idx >= 0 + + # Use a wrapper method to avoid protected access warning + atoms_row = self.get_row_from_db(db_idx, el_idx) + + # Convert to atoms object + atoms = atoms_row.toatoms() + + # Put data back into atoms info + if isinstance(atoms_row.data, dict): + atoms.info.update(atoms_row.data) + + # Add key-value pairs directly to atoms.info + if hasattr(atoms_row, "key_value_pairs") and atoms_row.key_value_pairs: + atoms.info.update(atoms_row.key_value_pairs) + + # Create a SinglePointCalculator to attach energy, forces and stress to atoms + calc_kwargs = {} + + # Check for energy, forces, stress in atoms_row and store in info & calc_kwargs + for prop in ["energy", "forces", "stress", "free_energy"]: + if hasattr(atoms_row, prop) and getattr(atoms_row, prop) is not None: + value = getattr(atoms_row, prop) + calc_kwargs[prop] = value + atoms.info[prop] = value + + # If we have custom data mappings, copy the standard properties to the custom names + a2g_args = self.config.get("a2g_args", {}) or {} + r_data_keys = a2g_args.get("r_data_keys", {}) + if r_data_keys: + # Map from standard names to custom names (in reverse of how they'll be used) + for custom_key, standard_key in r_data_keys.items(): + if standard_key in atoms.info: + atoms.info[custom_key] = atoms.info[standard_key] + elif standard_key in atoms.arrays: + atoms.arrays[custom_key] = atoms.arrays[standard_key] + + # Create calculator if we have any properties + if calc_kwargs: + from ase.calculators.singlepoint import SinglePointCalculator + + calc = SinglePointCalculator(atoms, **calc_kwargs) + atoms.calc = calc + + return atoms + + def get_row_from_db(self, db_idx, el_idx): + """Get a row from the database at the given indices.""" + db = self.dbs[db_idx] + row_id = self.db_ids[db_idx][el_idx] + if isinstance(db, LMDBDatabase): + return db._get_row(row_id) # pylint: disable=protected-access + return db.get(row_id) + + @staticmethod + def connect_db( + address: str | Path, connect_args: dict | None = None + ) -> ase.db.core.Database: + if connect_args is None: + connect_args = {} + db_type = connect_args.get("type", "extract_from_name") + if db_type in ("lmdb", "aselmdb") or ( + db_type == "extract_from_name" + and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") + ): + return LMDBDatabase(address, readonly=True, **connect_args) + + return ase.db.connect(address, **connect_args) + + def __del__(self): + for db in self.dbs: + if hasattr(db, "close"): + db.close() + + def sample_property_metadata( + self, + ) -> dict: # Removed unused argument num_samples (W0613) + """ + Sample property metadata from the database. + + This method was previously using the copy module which is now removed. + """ + logging.warning( + "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" 
+ ) + if self.dbs[0].metadata == {}: + return {} + + # Fixed unnecessary comprehension (R1721) + return dict(self.dbs[0].metadata.items()) diff --git a/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py b/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py index f76aa90c67f71786a0e96814602755315e58f8cf..8df0b0d1f7309995ea1417264190d5f7749c2acd 100644 --- a/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py @@ -1,204 +1,204 @@ -import torch - -from mace.tools.utils import AtomicNumberTable - - -def load_foundations_elements( - model: torch.nn.Module, - model_foundations: torch.nn.Module, - table: AtomicNumberTable, - load_readout=False, - use_shift=True, - use_scale=True, - max_L=2, -): - """ - Load the foundations of a model into a model for fine-tuning. - """ - assert model_foundations.r_max == model.r_max - z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - model_heads = model.heads - new_z_table = table - num_species_foundations = len(z_table.zs) - num_channels_foundation = ( - model_foundations.node_embedding.linear.weight.shape[0] - // num_species_foundations - ) - indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] - num_radial = model.radial_embedding.out_dim - num_species = len(indices_weights) - max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access - model.node_embedding.linear.weight = torch.nn.Parameter( - model_foundations.node_embedding.linear.weight.view( - num_species_foundations, -1 - )[indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": - model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( - model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() - ) - for i in range(int(model.num_interactions)): - model.interactions[i].linear_up.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear_up.weight.clone() - ) - model.interactions[i].avg_num_neighbors = model_foundations.interactions[ - i - ].avg_num_neighbors - for j in range(4): # Assuming 4 layers in conv_tp_weights, - layer_name = f"layer{j}" - if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() - ) - ) - else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) - ) - - model.interactions[i].linear.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear.weight.clone() - ) - if model.interactions[i].__class__.__name__ in [ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ]: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - num_species_foundations, - num_channels_foundation, - )[:, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - else: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - (max_ell + 1), - num_species_foundations, - num_channels_foundation, - )[:, :, indices_weights, :] - .flatten() - 
.clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.interactions[i].__class__.__name__ in [ - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ]: - # Assuming only 1 layer in density_fn - getattr(model.interactions[i].density_fn, "layer0").weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].density_fn, - "layer0", - ).weight.clone() - ) - ) - # Transferring products - for i in range(2): # Assuming 2 products modules - max_range = max_L + 1 if i == 0 else 1 - for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) - ) - - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) - ) - - model.products[i].linear.weight = torch.nn.Parameter( - model_foundations.products[i].linear.weight.clone() - ) - - if load_readout: - # Transferring readouts - model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() - model_readouts_zero_linear_weight = ( - model_foundations.readouts[0] - .linear.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_heads)) - .flatten() - .clone() - ) - model.readouts[0].linear.weight = torch.nn.Parameter( - model_readouts_zero_linear_weight - ) - - shape_input_1 = ( - model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - ) - shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() - model_readouts_one_linear_1_weight = ( - model_foundations.readouts[1] - .linear_1.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_heads)) - .flatten() - .clone() - ) - model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_readouts_one_linear_1_weight - ) - model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() - model_readouts_one_linear_2_weight = model_foundations.readouts[ - 1 - ].linear_2.weight.view(shape_input_1, -1).repeat( - len(model_heads), len(model_heads) - ).flatten().clone() / ( - ((shape_input_1) / (shape_output_1)) ** 0.5 - ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_readouts_one_linear_2_weight - ) - if model_foundations.scale_shift is not None: - if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( - len(model_heads) - ).clone() - if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( - len(model_heads) - ).clone() - return model - - -def load_foundations( - model, - model_foundations, -): - for name, param in model_foundations.named_parameters(): - if name in model.state_dict().keys(): - if "readouts" not in name: - model.state_dict()[name].copy_(param) - return model +import torch + +from mace.tools.utils import AtomicNumberTable + + +def load_foundations_elements( + model: torch.nn.Module, + model_foundations: torch.nn.Module, + table: AtomicNumberTable, + load_readout=False, + use_shift=True, + use_scale=True, + max_L=2, +): + """ + Load the foundations of a model into a model for fine-tuning. 
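+
+    Node-embedding, interaction, and product weights are copied from
+    model_foundations into model, with species remapped through `table`
+    (the embedding and skip-connection weights are also rescaled for the
+    reduced element set); readouts and the scale/shift block are
+    transferred only when the corresponding flags are set.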
+ """ + assert model_foundations.r_max == model.r_max + z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + model_heads = model.heads + new_z_table = table + num_species_foundations = len(z_table.zs) + num_channels_foundation = ( + model_foundations.node_embedding.linear.weight.shape[0] + // num_species_foundations + ) + indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] + num_radial = model.radial_embedding.out_dim + num_species = len(indices_weights) + max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access + model.node_embedding.linear.weight = torch.nn.Parameter( + model_foundations.node_embedding.linear.weight.view( + num_species_foundations, -1 + )[indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": + model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( + model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() + ) + for i in range(int(model.num_interactions)): + model.interactions[i].linear_up.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear_up.weight.clone() + ) + model.interactions[i].avg_num_neighbors = model_foundations.interactions[ + i + ].avg_num_neighbors + for j in range(4): # Assuming 4 layers in conv_tp_weights, + layer_name = f"layer{j}" + if j == 0: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() + ) + ) + else: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) + ) + + model.interactions[i].linear.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear.weight.clone() + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + num_species_foundations, + num_channels_foundation, + )[:, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + else: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + (max_ell + 1), + num_species_foundations, + num_channels_foundation, + )[:, :, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ).weight.clone() + ) + ) + # Transferring products + for i in range(2): # Assuming 2 products modules + max_range = max_L + 1 if i == 0 else 1 + for j in range(max_range): # Assuming 3 contractions in symmetric_contractions + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] 
+ .clone() + ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) + + model.products[i].linear.weight = torch.nn.Parameter( + model_foundations.products[i].linear.weight.clone() + ) + + if load_readout: + # Transferring readouts + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[0].linear.weight = torch.nn.Parameter( + model_readouts_zero_linear_weight + ) + + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[1].linear_1.weight = torch.nn.Parameter( + model_readouts_one_linear_1_weight + ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_heads), len(model_heads) + ).flatten().clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 + ) + model.readouts[1].linear_2.weight = torch.nn.Parameter( + model_readouts_one_linear_2_weight + ) + if model_foundations.scale_shift is not None: + if use_scale: + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_heads) + ).clone() + if use_shift: + model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( + len(model_heads) + ).clone() + return model + + +def load_foundations( + model, + model_foundations, +): + for name, param in model_foundations.named_parameters(): + if name in model.state_dict().keys(): + if "readouts" not in name: + model.state_dict()[name].copy_(param) + return model diff --git a/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py b/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py index e5775245a1ebf96f6c53cc2f28208da9d565e513..c9de08b9f34abb957f3407fedd32e13ea68779d7 100644 --- a/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py @@ -1,265 +1,265 @@ -import ast -import logging - -import numpy as np -from e3nn import o3 - -from mace import modules -from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model -from mace.tools.utils import AtomicNumberTable - - -def configure_model( - args, - train_loader, - atomic_energies, - model_foundation=None, - heads=None, - z_table=None, - head_configs=None, -): - # Selecting outputs - compute_virials = args.loss == "virials" - compute_stress = args.loss in ("stress", "huber", "universal") - - if compute_virials: - args.compute_virials = True - args.error_table = "PerAtomRMSEstressvirials" - elif compute_stress: - args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" - - output_args = { - "energy": args.compute_energy, - "forces": args.compute_forces, 
- "virials": compute_virials, - "stress": compute_stress, - "dipoles": args.compute_dipole, - } - logging.info( - f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" - ) - logging.info("===========MODEL DETAILS===========") - - if args.scaling == "no_scaling": - args.std = 1.0 - if head_configs is not None: - for head_config in head_configs: - head_config.std = 1.0 - logging.info("No scaling selected") - - if ( - head_configs is not None - and args.std is not None - and not isinstance(args.std, list) - ): - atomic_inter_scale = [] - for head_config in head_configs: - if hasattr(head_config, "std") and head_config.std is not None: - atomic_inter_scale.append(head_config.std) - elif args.std is not None: - atomic_inter_scale.append( - args.std if isinstance(args.std, float) else 1.0 - ) - args.std = atomic_inter_scale - - elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": - args.mean, args.std = modules.scaling_classes[args.scaling]( - train_loader, atomic_energies - ) - - # Build model - if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading FOUNDATION model") - model_config_foundation = extract_config_mace_model(model_foundation) - model_config_foundation["atomic_energies"] = atomic_energies - - if args.foundation_model_elements: - foundation_z_table = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - model_config_foundation["atomic_numbers"] = foundation_z_table.zs - model_config_foundation["num_elements"] = len(foundation_z_table) - z_table = foundation_z_table - logging.info( - f"Using all elements from foundation model: {foundation_z_table.zs}" - ) - else: - model_config_foundation["atomic_numbers"] = z_table.zs - model_config_foundation["num_elements"] = len(z_table) - logging.info(f"Using filtered elements: {z_table.zs}") - - args.max_L = model_config_foundation["hidden_irreps"].lmax - - if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": - model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) - else: - model_config_foundation["atomic_inter_shift"] = ( - _determine_atomic_inter_shift(args.mean, heads) - ) - model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) - args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] - args.model = "FoundationMACE" - model_config_foundation["heads"] = heads - model_config = model_config_foundation - - logging.info("Model configuration extracted from foundation model") - logging.info("Using universal loss function for fine-tuning") - logging.info( - f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" - ) - logging.info( - f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" - ) - logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" - ) - else: - logging.info("Building model") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" - ) - logging.info( - 
f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" - ) - logging.info( - f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" - ) - logging.info( - f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {args.distance_transform}" - ) - - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - logging.info(f"Hidden irreps: {args.hidden_irreps}") - - model_config = dict( - r_max=args.r_max, - num_bessel=args.num_radial_basis, - num_polynomial_cutoff=args.num_cutoff_basis, - max_ell=args.max_ell, - interaction_cls=modules.interaction_classes[args.interaction], - num_interactions=args.num_interactions, - num_elements=len(z_table), - hidden_irreps=o3.Irreps(args.hidden_irreps), - atomic_energies=atomic_energies, - avg_num_neighbors=args.avg_num_neighbors, - atomic_numbers=z_table.zs, - ) - model_config_foundation = None - - model = _build_model(args, model_config, model_config_foundation, heads) - - if model_foundation is not None: - model = load_foundations_elements( - model, - model_foundation, - z_table, - load_readout=args.foundation_filter_elements, - max_L=args.max_L, - ) - - return model, output_args - - -def _determine_atomic_inter_shift(mean, heads): - if isinstance(mean, np.ndarray): - if mean.size == 1: - return mean.item() - if mean.size == len(heads): - return mean.tolist() - logging.info("Mean not in correct format, using default value of 0.0") - return [0.0] * len(heads) - if isinstance(mean, list) and len(mean) == len(heads): - return mean - if isinstance(mean, float): - return [mean] * len(heads) - logging.info("Mean not in correct format, using default value of 0.0") - return [0.0] * len(heads) - - -def _build_model( - args, model_config, model_config_foundation, heads -): # pylint: disable=too-many-return-statements - if args.model == "MACE": - if args.interaction_first not in [ - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - ]: - args.interaction_first = "RealAgnosticInteractionBlock" - return modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=[0.0] * len(heads), - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - if args.model == "ScaleShiftMACE": - return modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - if args.model == "FoundationMACE": - return modules.ScaleShiftMACE(**model_config_foundation) - if args.model == "ScaleShiftBOTNet": - # say it 
is deprecated - raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead") - if args.model == "BOTNet": - raise RuntimeError("BOTNet is deprecated, use MACE instead") - if args.model == "AtomicDipolesMACE": - assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" - assert ( - args.error_table == "DipoleRMSE" - ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" - return modules.AtomicDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - if args.model == "EnergyDipolesMACE": - assert ( - args.loss == "energy_forces_dipole" - ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" - assert ( - args.error_table == "EnergyDipoleRMSE" - ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" - return modules.EnergyDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - raise RuntimeError(f"Unknown model: '{args.model}'") +import ast +import logging + +import numpy as np +from e3nn import o3 + +from mace import modules +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model +from mace.tools.utils import AtomicNumberTable + + +def configure_model( + args, + train_loader, + atomic_energies, + model_foundation=None, + heads=None, + z_table=None, + head_configs=None, +): + # Selecting outputs + compute_virials = args.loss == "virials" + compute_stress = args.loss in ("stress", "huber", "universal") + + if compute_virials: + args.compute_virials = True + args.error_table = "PerAtomRMSEstressvirials" + elif compute_stress: + args.compute_stress = True + args.error_table = "PerAtomRMSEstressvirials" + + output_args = { + "energy": args.compute_energy, + "forces": args.compute_forces, + "virials": compute_virials, + "stress": compute_stress, + "dipoles": args.compute_dipole, + } + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") + + if args.scaling == "no_scaling": + args.std = 1.0 + if head_configs is not None: + for head_config in head_configs: + head_config.std = 1.0 + logging.info("No scaling selected") + + if ( + head_configs is not None + and args.std is not None + and not isinstance(args.std, list) + ): + atomic_inter_scale = [] + for head_config in head_configs: + if hasattr(head_config, "std") and head_config.std is not None: + atomic_inter_scale.append(head_config.std) + elif args.std is not None: + atomic_inter_scale.append( + args.std if isinstance(args.std, float) else 1.0 + ) + args.std = atomic_inter_scale + + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) + + # Build model + if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: + logging.info("Loading FOUNDATION model") + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + + if args.foundation_model_elements: + foundation_z_table = 
AtomicNumberTable(
+                [int(z) for z in model_foundation.atomic_numbers]
+            )
+            model_config_foundation["atomic_numbers"] = foundation_z_table.zs
+            model_config_foundation["num_elements"] = len(foundation_z_table)
+            z_table = foundation_z_table
+            logging.info(
+                f"Using all elements from foundation model: {foundation_z_table.zs}"
+            )
+        else:
+            model_config_foundation["atomic_numbers"] = z_table.zs
+            model_config_foundation["num_elements"] = len(z_table)
+            logging.info(f"Using filtered elements: {z_table.zs}")
+
+        args.max_L = model_config_foundation["hidden_irreps"].lmax
+
+        if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE":
+            model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads)
+        else:
+            model_config_foundation["atomic_inter_shift"] = (
+                _determine_atomic_inter_shift(args.mean, heads)
+            )
+        model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
+        args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
+        args.model = "FoundationMACE"
+        model_config_foundation["heads"] = heads
+        model_config = model_config_foundation
+
+        logging.info("Model configuration extracted from foundation model")
+        logging.info("Using universal loss function for fine-tuning")
+        logging.info(
+            f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']}"
+        )
+        logging.info(
+            f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}"
+        )
+        logging.info(
+            f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)"
+        )
+        logging.info(
+            f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}"
+        )
+    else:
+        logging.info("Building model")
+        logging.info(
+            f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})"
+        )
+        logging.info(
+            f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}"
+        )
+        logging.info(
+            f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions"
+        )
+        logging.info(
+            f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)"
+        )
+        logging.info(
+            f"Distance transform for radial basis functions: {args.distance_transform}"
+        )
+
+        assert (
+            len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
+        ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
+
+        logging.info(f"Hidden irreps: {args.hidden_irreps}")
+
+        model_config = dict(
+            r_max=args.r_max,
+            num_bessel=args.num_radial_basis,
+            num_polynomial_cutoff=args.num_cutoff_basis,
+            max_ell=args.max_ell,
+            interaction_cls=modules.interaction_classes[args.interaction],
+            num_interactions=args.num_interactions,
+            num_elements=len(z_table),
+            hidden_irreps=o3.Irreps(args.hidden_irreps),
+            atomic_energies=atomic_energies,
+            avg_num_neighbors=args.avg_num_neighbors,
+            atomic_numbers=z_table.zs,
+        )
+        model_config_foundation = None
+
+    model = _build_model(args, model_config, model_config_foundation, heads)
+
+    if model_foundation is not None:
+        model = load_foundations_elements(
+            model,
+            model_foundation,
+            z_table,
+            load_readout=args.foundation_filter_elements,
+            max_L=args.max_L,
+        )
+
+    return model, output_args
+
+
+def _determine_atomic_inter_shift(mean, heads):
+    if isinstance(mean, np.ndarray):
+        if mean.size == 1:
+            return mean.item()
+        if mean.size == len(heads):
+            return mean.tolist()
+        logging.info("Mean not in correct format, using default value of 0.0")
+        return [0.0] * len(heads)
+    if isinstance(mean, list) and len(mean) == len(heads):
+        return mean
+    if isinstance(mean, float):
+        return [mean] * len(heads)
+    logging.info("Mean not in correct format, using default value of 0.0")
+    return [0.0] * len(heads)
+
+
+def _build_model(
+    args, model_config, model_config_foundation, heads
+):  # pylint: disable=too-many-return-statements
+    if args.model == "MACE":
+        if args.interaction_first not in [
+            "RealAgnosticInteractionBlock",
+            "RealAgnosticDensityInteractionBlock",
+        ]:
+            args.interaction_first = "RealAgnosticInteractionBlock"
+        return modules.ScaleShiftMACE(
+            **model_config,
+            pair_repulsion=args.pair_repulsion,
+            distance_transform=args.distance_transform,
+            correlation=args.correlation,
+            gate=modules.gate_dict[args.gate],
+            interaction_cls_first=modules.interaction_classes[args.interaction_first],
+            MLP_irreps=o3.Irreps(args.MLP_irreps),
+            atomic_inter_scale=args.std,
+            atomic_inter_shift=[0.0] * len(heads),
+            radial_MLP=ast.literal_eval(args.radial_MLP),
+            radial_type=args.radial_type,
+            heads=heads,
+        )
+    if args.model == "ScaleShiftMACE":
+        return modules.ScaleShiftMACE(
+            **model_config,
+            pair_repulsion=args.pair_repulsion,
+            distance_transform=args.distance_transform,
+            correlation=args.correlation,
+            gate=modules.gate_dict[args.gate],
+            interaction_cls_first=modules.interaction_classes[args.interaction_first],
+            MLP_irreps=o3.Irreps(args.MLP_irreps),
+            atomic_inter_scale=args.std,
+            atomic_inter_shift=args.mean,
+            radial_MLP=ast.literal_eval(args.radial_MLP),
+            radial_type=args.radial_type,
+            heads=heads,
+        )
+    if args.model == "FoundationMACE":
+        return modules.ScaleShiftMACE(**model_config_foundation)
+    if args.model == "ScaleShiftBOTNet":
+        raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead")
+    if args.model == "BOTNet":
+        raise RuntimeError("BOTNet is deprecated, use MACE instead")
+    if args.model == "AtomicDipolesMACE":
+        assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model"
+        assert (
+            args.error_table == "DipoleRMSE"
+        ), "Use error_table DipoleRMSE with AtomicDipolesMACE model"
+        return modules.AtomicDipolesMACE(
+            **model_config,
+            correlation=args.correlation,
+            gate=modules.gate_dict[args.gate],
+            interaction_cls_first=modules.interaction_classes[
+                "RealAgnosticInteractionBlock"
+            ],
+            MLP_irreps=o3.Irreps(args.MLP_irreps),
+        )
+    if args.model == "EnergyDipolesMACE":
+        assert (
+            args.loss == "energy_forces_dipole"
+        ), "Use energy_forces_dipole loss with EnergyDipolesMACE model"
+        assert (
+            args.error_table == "EnergyDipoleRMSE"
+        ), "Use error_table EnergyDipoleRMSE with EnergyDipolesMACE model"
+        return modules.EnergyDipolesMACE(
+            **model_config,
+            correlation=args.correlation,
+            gate=modules.gate_dict[args.gate],
+            interaction_cls_first=modules.interaction_classes[
+                "RealAgnosticInteractionBlock"
+            ],
+            MLP_irreps=o3.Irreps(args.MLP_irreps),
+        )
+    raise RuntimeError(f"Unknown model: '{args.model}'")
diff --git a/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py b/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py
index
1a12416e9170471dabc49fea82d4cde2a2b239ae..f321af390f8d45a63b68a62c1a95f49d88fa7756 100644 --- a/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py @@ -1,200 +1,200 @@ -import argparse -import ast -import dataclasses -import logging -import os -import urllib.request -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch - -from mace.cli.fine_tuning_select import ( - FilteringType, - SelectionSettings, - SubselectType, - select_samples, -) -from mace.data import KeySpecification -from mace.tools.scripts_utils import SubsetCollection, get_dataset_from_xyz - - -@dataclasses.dataclass -class HeadConfig: - head_name: str - key_specification: KeySpecification - train_file: Optional[Union[str, List[str]]] = None - valid_file: Optional[Union[str, List[str]]] = None - test_file: Optional[str] = None - test_dir: Optional[str] = None - E0s: Optional[Any] = None - statistics_file: Optional[str] = None - valid_fraction: Optional[float] = None - config_type_weights: Optional[Dict[str, float]] = None - keep_isolated_atoms: Optional[bool] = None - atomic_numbers: Optional[Union[List[int], List[str]]] = None - mean: Optional[float] = None - std: Optional[float] = None - avg_num_neighbors: Optional[float] = None - compute_avg_num_neighbors: Optional[bool] = None - collections: Optional[SubsetCollection] = None - train_loader: Optional[torch.utils.data.DataLoader] = None - z_table: Optional[Any] = None - atomic_energies_dict: Optional[Dict[str, float]] = None - - -def dict_head_to_dataclass( - head: Dict[str, Any], head_name: str, args: argparse.Namespace -) -> HeadConfig: - """Convert head dictionary to HeadConfig dataclass.""" - # parser+head args that have no defaults but are required - if (args.train_file is None) and (head.get("train_file", None) is None): - raise ValueError( - "train file is not set in the head config yaml or via command line args" - ) - - return HeadConfig( - head_name=head_name, - train_file=head.get("train_file", args.train_file), - valid_file=head.get("valid_file", args.valid_file), - test_file=head.get("test_file", None), - test_dir=head.get("test_dir", None), - E0s=head.get("E0s", args.E0s), - statistics_file=head.get("statistics_file", args.statistics_file), - valid_fraction=head.get("valid_fraction", args.valid_fraction), - config_type_weights=head.get("config_type_weights", args.config_type_weights), - compute_avg_num_neighbors=head.get( - "compute_avg_num_neighbors", args.compute_avg_num_neighbors - ), - atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), - mean=head.get("mean", args.mean), - std=head.get("std", args.std), - avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - key_specification=head["key_specification"], - keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), - ) - - -def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: - """Prepare a default head from args.""" - return { - "Default": { - "train_file": args.train_file, - "valid_file": args.valid_file, - "test_file": args.test_file, - "test_dir": args.test_dir, - "E0s": args.E0s, - "statistics_file": args.statistics_file, - "key_specification": args.key_specification, - "valid_fraction": args.valid_fraction, - "config_type_weights": args.config_type_weights, - "keep_isolated_atoms": args.keep_isolated_atoms, - } - } - - -def prepare_pt_head( - args: argparse.Namespace, - pt_keyspec: KeySpecification, - 
foundation_model_num_neighbours: float, -) -> Dict[str, Any]: - """Prepare a pretraining head from args.""" - if ( - args.foundation_model in ["small", "medium", "large"] - or args.pt_train_file == "mp" - ): - logging.info( - "Using foundation model for multiheads finetuning with Materials Project data" - ) - pt_keyspec.update( - info_keys={"energy": "energy", "stress": "stress"}, - arrays_keys={"forces": "forces"}, - ) - pt_head = { - "train_file": "mp", - "E0s": "foundation", - "statistics_file": None, - "key_specification": pt_keyspec, - "avg_num_neighbors": foundation_model_num_neighbours, - "compute_avg_num_neighbors": False, - } - else: - pt_head = { - "train_file": args.pt_train_file, - "valid_file": args.pt_valid_file, - "E0s": "foundation", - "statistics_file": args.statistics_file, - "valid_fraction": args.valid_fraction, - "key_specification": pt_keyspec, - "avg_num_neighbors": foundation_model_num_neighbours, - "keep_isolated_atoms": args.keep_isolated_atoms, - "compute_avg_num_neighbors": False, - } - - return pt_head - - -def assemble_mp_data( - args: argparse.Namespace, - head_config_pt: HeadConfig, - tag: str, -) -> SubsetCollection: - """Assemble Materials Project data for fine-tuning.""" - try: - checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" - cache_dir = ( - Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace" - ) - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_dataset_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP structures for finetuning") - _, http_msg = urllib.request.urlretrieve( - checkpoint_url, cached_dataset_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Dataset download failed, please check the URL {checkpoint_url}" - ) - logging.info(f"Materials Project dataset to {cached_dataset_path}") - output = f"mp_finetuning-{tag}.xyz" - atomic_numbers = ( - ast.literal_eval(args.atomic_numbers) - if args.atomic_numbers is not None - else None - ) - settings = SelectionSettings( - configs_pt=cached_dataset_path, - output=f"mp_finetuning-{tag}.xyz", - atomic_numbers=atomic_numbers, - num_samples=args.num_samples_pt, - seed=args.seed, - head_pt="pbe_mp", - weight_pt=args.weight_pt_head, - filtering_type=FilteringType(args.filter_type_pt), - subselect=SubselectType(args.subselect_pt), - default_dtype=args.default_dtype, - ) - select_samples(settings) - head_config_pt.train_file = [output] - collections_mp, _ = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=output, - valid_path=None, - valid_fraction=args.valid_fraction, - config_type_weights=None, - test_path=None, - seed=args.seed, - key_specification=head_config_pt.key_specification, - head_name="pt_head", - keep_isolated_atoms=args.keep_isolated_atoms, - ) - return collections_mp - except Exception as exc: - raise RuntimeError( - "Model or descriptors download failed and no local model found" - ) from exc +import argparse +import ast +import dataclasses +import logging +import os +import urllib.request +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch + +from mace.cli.fine_tuning_select import ( + FilteringType, + SelectionSettings, + SubselectType, + select_samples, +) +from mace.data import KeySpecification +from mace.tools.scripts_utils 
import SubsetCollection, get_dataset_from_xyz + + +@dataclasses.dataclass +class HeadConfig: + head_name: str + key_specification: KeySpecification + train_file: Optional[Union[str, List[str]]] = None + valid_file: Optional[Union[str, List[str]]] = None + test_file: Optional[str] = None + test_dir: Optional[str] = None + E0s: Optional[Any] = None + statistics_file: Optional[str] = None + valid_fraction: Optional[float] = None + config_type_weights: Optional[Dict[str, float]] = None + keep_isolated_atoms: Optional[bool] = None + atomic_numbers: Optional[Union[List[int], List[str]]] = None + mean: Optional[float] = None + std: Optional[float] = None + avg_num_neighbors: Optional[float] = None + compute_avg_num_neighbors: Optional[bool] = None + collections: Optional[SubsetCollection] = None + train_loader: Optional[torch.utils.data.DataLoader] = None + z_table: Optional[Any] = None + atomic_energies_dict: Optional[Dict[str, float]] = None + + +def dict_head_to_dataclass( + head: Dict[str, Any], head_name: str, args: argparse.Namespace +) -> HeadConfig: + """Convert head dictionary to HeadConfig dataclass.""" + # parser+head args that have no defaults but are required + if (args.train_file is None) and (head.get("train_file", None) is None): + raise ValueError( + "train file is not set in the head config yaml or via command line args" + ) + + return HeadConfig( + head_name=head_name, + train_file=head.get("train_file", args.train_file), + valid_file=head.get("valid_file", args.valid_file), + test_file=head.get("test_file", None), + test_dir=head.get("test_dir", None), + E0s=head.get("E0s", args.E0s), + statistics_file=head.get("statistics_file", args.statistics_file), + valid_fraction=head.get("valid_fraction", args.valid_fraction), + config_type_weights=head.get("config_type_weights", args.config_type_weights), + compute_avg_num_neighbors=head.get( + "compute_avg_num_neighbors", args.compute_avg_num_neighbors + ), + atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), + mean=head.get("mean", args.mean), + std=head.get("std", args.std), + avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), + key_specification=head["key_specification"], + keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), + ) + + +def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: + """Prepare a default head from args.""" + return { + "Default": { + "train_file": args.train_file, + "valid_file": args.valid_file, + "test_file": args.test_file, + "test_dir": args.test_dir, + "E0s": args.E0s, + "statistics_file": args.statistics_file, + "key_specification": args.key_specification, + "valid_fraction": args.valid_fraction, + "config_type_weights": args.config_type_weights, + "keep_isolated_atoms": args.keep_isolated_atoms, + } + } + + +def prepare_pt_head( + args: argparse.Namespace, + pt_keyspec: KeySpecification, + foundation_model_num_neighbours: float, +) -> Dict[str, Any]: + """Prepare a pretraining head from args.""" + if ( + args.foundation_model in ["small", "medium", "large"] + or args.pt_train_file == "mp" + ): + logging.info( + "Using foundation model for multiheads finetuning with Materials Project data" + ) + pt_keyspec.update( + info_keys={"energy": "energy", "stress": "stress"}, + arrays_keys={"forces": "forces"}, + ) + pt_head = { + "train_file": "mp", + "E0s": "foundation", + "statistics_file": None, + "key_specification": pt_keyspec, + "avg_num_neighbors": foundation_model_num_neighbours, + "compute_avg_num_neighbors": False, + } + 
else:
+        pt_head = {
+            "train_file": args.pt_train_file,
+            "valid_file": args.pt_valid_file,
+            "E0s": "foundation",
+            "statistics_file": args.statistics_file,
+            "valid_fraction": args.valid_fraction,
+            "key_specification": pt_keyspec,
+            "avg_num_neighbors": foundation_model_num_neighbours,
+            "keep_isolated_atoms": args.keep_isolated_atoms,
+            "compute_avg_num_neighbors": False,
+        }
+
+    return pt_head
+
+
+def assemble_mp_data(
+    args: argparse.Namespace,
+    head_config_pt: HeadConfig,
+    tag: str,
+) -> SubsetCollection:
+    """Assemble Materials Project data for fine-tuning."""
+    try:
+        checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz"
+        cache_dir = (
+            Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
+        )
+        checkpoint_url_name = "".join(
+            c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
+        )
+        cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}"
+        if not os.path.isfile(cached_dataset_path):
+            os.makedirs(cache_dir, exist_ok=True)
+            # download and save to disk
+            logging.info("Downloading MP structures for finetuning")
+            _, http_msg = urllib.request.urlretrieve(
+                checkpoint_url, cached_dataset_path
+            )
+            if "Content-Type: text/html" in http_msg:
+                raise RuntimeError(
+                    f"Dataset download failed, please check the URL {checkpoint_url}"
+                )
+        logging.info(f"Materials Project dataset cached at {cached_dataset_path}")
+        output = f"mp_finetuning-{tag}.xyz"
+        atomic_numbers = (
+            ast.literal_eval(args.atomic_numbers)
+            if args.atomic_numbers is not None
+            else None
+        )
+        settings = SelectionSettings(
+            configs_pt=cached_dataset_path,
+            output=output,
+            atomic_numbers=atomic_numbers,
+            num_samples=args.num_samples_pt,
+            seed=args.seed,
+            head_pt="pbe_mp",
+            weight_pt=args.weight_pt_head,
+            filtering_type=FilteringType(args.filter_type_pt),
+            subselect=SubselectType(args.subselect_pt),
+            default_dtype=args.default_dtype,
+        )
+        select_samples(settings)
+        head_config_pt.train_file = [output]
+        collections_mp, _ = get_dataset_from_xyz(
+            work_dir=args.work_dir,
+            train_path=output,
+            valid_path=None,
+            valid_fraction=args.valid_fraction,
+            config_type_weights=None,
+            test_path=None,
+            seed=args.seed,
+            key_specification=head_config_pt.key_specification,
+            head_name="pt_head",
+            keep_isolated_atoms=args.keep_isolated_atoms,
+        )
+        return collections_mp
+    except Exception as exc:
+        raise RuntimeError(
+            "Model or descriptors download failed and no local model found"
+        ) from exc
diff --git a/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py b/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py
index cb1d5683d87fbb75144ec2b6ef10fc2888f54093..ce37e0edc345f7367eac550af822e6ee03c7e9f4 100644
--- a/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py
+++ b/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py
@@ -1,217 +1,217 @@
-import logging
-import os
-from pathlib import Path
-from typing import Any, List, Optional, Union
-
-import torch
-from torch.utils.data import ConcatDataset
-
-from mace import data
-from mace.tools.scripts_utils import check_path_ase_read
-from mace.tools.torch_geometric.dataset import Dataset
-from mace.tools.utils import AtomicNumberTable
-
-
-def normalize_file_paths(file_paths: Union[str, List[str]]) -> List[str]:
-    """
-    Normalize file paths to a list format.
- - Args: - file_paths: Either a string or a list of strings representing file paths - - Returns: - A list of file paths - """ - if isinstance(file_paths, str): - return [file_paths] - if isinstance(file_paths, list): - return file_paths - raise ValueError(f"Unexpected file paths format: {type(file_paths)}") - - -def load_dataset_for_path( - file_path: Union[str, Path, List[str]], - r_max: float, - z_table: AtomicNumberTable, - heads: List[str], - head_config: Any, - collection: Optional[Any] = None, -) -> Union[Dataset, List]: - """ - Load a dataset from a file path based on its format. - - Args: - file_path: Path to the dataset file - r_max: Cutoff radius - z_table: Atomic number table - heads: List of head names - head_name: Current head name - **kwargs: Additional arguments - - Returns: - Loaded dataset - """ - if isinstance(file_path, list): - if len(file_path) == 1: - file_path = file_path[0] - if isinstance(file_path, list): - is_ase_readable = all(check_path_ase_read(p) for p in file_path) - if not is_ase_readable: - raise ValueError( - "Not all paths in the list are ASE readable, not supported" - ) - if isinstance(file_path, str): - is_ase_readable = check_path_ase_read(file_path) - - if is_ase_readable: - assert ( - collection is not None - ), "Collection must be provided for ASE readable files" - return [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=r_max, heads=heads - ) - for config in collection - ] - - filepath = Path(file_path) - if filepath.is_dir(): - - if filepath.name.endswith("_lmdb") or any( - f.endswith(".lmdb") or f.endswith(".aselmdb") for f in os.listdir(filepath) - ): - logging.info(f"Loading LMDB dataset from {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - h5_files = list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5")) - if h5_files: - logging.info(f"Loading HDF5 dataset from directory {file_path}") - try: - return data.dataset_from_sharded_hdf5( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - except Exception as e: - logging.error(f"Error loading sharded HDF5 dataset: {e}") - raise - - if "lmdb" in str(filepath).lower() or "aselmdb" in str(filepath).lower(): - logging.info(f"Loading LMDB dataset based on path name: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - logging.info(f"Attempting to load directory as HDF5 dataset: {file_path}") - try: - return data.dataset_from_sharded_hdf5( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - except Exception as e: - logging.error(f"Error loading as sharded HDF5: {e}") - raise - - suffix = filepath.suffix.lower() - if suffix in (".h5", ".hdf5"): - logging.info(f"Loading single HDF5 file: {file_path}") - return data.HDF5Dataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - if suffix in (".lmdb", ".aselmdb", ".db"): - logging.info(f"Loading single LMDB file: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - logging.info(f"Attempting to load as LMDB: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - -def combine_datasets(datasets, head_name): - """ - Combine multiple datasets which might 
be of different types. - - Args: - datasets: List of datasets (can be mixed types) - head_name: Name of the current head - - Returns: - Combined dataset - """ - if not datasets: - return [] - - if all(isinstance(ds, list) for ds in datasets): - logging.info(f"Combining {len(datasets)} list datasets for head '{head_name}'") - return [item for sublist in datasets for item in sublist] - - if all(not isinstance(ds, list) for ds in datasets): - logging.info( - f"Combining {len(datasets)} Dataset objects for head '{head_name}'" - ) - return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] - - logging.info(f"Converting mixed dataset types for head '{head_name}'") - - try: - all_items = [] - for ds in datasets: - if isinstance(ds, list): - all_items.extend(ds) - else: - all_items.extend([ds[i] for i in range(len(ds))]) - return all_items - except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to convert mixed datasets to list: {e}") - - try: - dataset_objects = [] - for ds in datasets: - if isinstance(ds, list): - from torch.utils.data import TensorDataset - - # Convert list to a Dataset - dataset_objects.append( - TensorDataset(*[torch.tensor([i]) for i in range(len(ds))]) - ) - else: - dataset_objects.append(ds) - return ConcatDataset(dataset_objects) - except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to convert mixed datasets to ConcatDataset: {e}") - - logging.warning( - "Could not combine datasets of different types. Using only the first dataset." - ) - return datasets[0] +import logging +import os +from pathlib import Path +from typing import Any, List, Optional, Union + +import torch +from torch.utils.data import ConcatDataset + +from mace import data +from mace.tools.scripts_utils import check_path_ase_read +from mace.tools.torch_geometric.dataset import Dataset +from mace.tools.utils import AtomicNumberTable + + +def normalize_file_paths(file_paths: Union[str, List[str]]) -> List[str]: + """ + Normalize file paths to a list format. + + Args: + file_paths: Either a string or a list of strings representing file paths + + Returns: + A list of file paths + """ + if isinstance(file_paths, str): + return [file_paths] + if isinstance(file_paths, list): + return file_paths + raise ValueError(f"Unexpected file paths format: {type(file_paths)}") + + +def load_dataset_for_path( + file_path: Union[str, Path, List[str]], + r_max: float, + z_table: AtomicNumberTable, + heads: List[str], + head_config: Any, + collection: Optional[Any] = None, +) -> Union[Dataset, List]: + """ + Load a dataset from a file path based on its format. 
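+
+    Dispatches on the path type: ASE-readable files (and lists of them) are
+    converted from the pre-loaded ``collection``; directories are probed for
+    LMDB/ASE-LMDB and sharded HDF5 layouts; single files dispatch on their
+    suffix, falling back to LMDB.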
+
+    Args:
+        file_path: Path to the dataset file
+        r_max: Cutoff radius
+        z_table: Atomic number table
+        heads: List of head names
+        head_config: Configuration of the current head (its head_name selects the data)
+        collection: Pre-loaded configurations, required for ASE-readable files
+
+    Returns:
+        Loaded dataset
+    """
+    if isinstance(file_path, list):
+        if len(file_path) == 1:
+            file_path = file_path[0]
+    if isinstance(file_path, list):
+        is_ase_readable = all(check_path_ase_read(p) for p in file_path)
+        if not is_ase_readable:
+            raise ValueError(
+                "Not all paths in the list are ASE readable, not supported"
+            )
+    if isinstance(file_path, str):
+        is_ase_readable = check_path_ase_read(file_path)
+
+    if is_ase_readable:
+        assert (
+            collection is not None
+        ), "Collection must be provided for ASE readable files"
+        return [
+            data.AtomicData.from_config(
+                config, z_table=z_table, cutoff=r_max, heads=heads
+            )
+            for config in collection
+        ]
+
+    filepath = Path(file_path)
+    if filepath.is_dir():
+
+        if filepath.name.endswith("_lmdb") or any(
+            f.endswith(".lmdb") or f.endswith(".aselmdb") for f in os.listdir(filepath)
+        ):
+            logging.info(f"Loading LMDB dataset from {file_path}")
+            return data.LMDBDataset(
+                file_path,
+                r_max=r_max,
+                z_table=z_table,
+                heads=heads,
+                head=head_config.head_name,
+            )
+
+        h5_files = list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5"))
+        if h5_files:
+            logging.info(f"Loading HDF5 dataset from directory {file_path}")
+            try:
+                return data.dataset_from_sharded_hdf5(
+                    file_path,
+                    r_max=r_max,
+                    z_table=z_table,
+                    heads=heads,
+                    head=head_config.head_name,
+                )
+            except Exception as e:
+                logging.error(f"Error loading sharded HDF5 dataset: {e}")
+                raise
+
+        if "lmdb" in str(filepath).lower() or "aselmdb" in str(filepath).lower():
+            logging.info(f"Loading LMDB dataset based on path name: {file_path}")
+            return data.LMDBDataset(
+                file_path,
+                r_max=r_max,
+                z_table=z_table,
+                heads=heads,
+                head=head_config.head_name,
+            )
+
+        logging.info(f"Attempting to load directory as HDF5 dataset: {file_path}")
+        try:
+            return data.dataset_from_sharded_hdf5(
+                file_path,
+                r_max=r_max,
+                z_table=z_table,
+                heads=heads,
+                head=head_config.head_name,
+            )
+        except Exception as e:
+            logging.error(f"Error loading as sharded HDF5: {e}")
+            raise
+
+    suffix = filepath.suffix.lower()
+    if suffix in (".h5", ".hdf5"):
+        logging.info(f"Loading single HDF5 file: {file_path}")
+        return data.HDF5Dataset(
+            file_path,
+            r_max=r_max,
+            z_table=z_table,
+            heads=heads,
+            head=head_config.head_name,
+        )
+
+    if suffix in (".lmdb", ".aselmdb", ".db"):
+        logging.info(f"Loading single LMDB file: {file_path}")
+        return data.LMDBDataset(
+            file_path,
+            r_max=r_max,
+            z_table=z_table,
+            heads=heads,
+            head=head_config.head_name,
+        )
+
+    logging.info(f"Attempting to load as LMDB: {file_path}")
+    return data.LMDBDataset(
+        file_path,
+        r_max=r_max,
+        z_table=z_table,
+        heads=heads,
+        head=head_config.head_name,
+    )
+
+
+def combine_datasets(datasets, head_name):
+    """
+    Combine multiple datasets which might be of different types.
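+
+    Tries, in order: flattening when every input is a list, ``ConcatDataset``
+    when none are, conversion of mixed inputs to a flat list, and finally a
+    fallback to the first dataset with a warning.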
+ + Args: + datasets: List of datasets (can be mixed types) + head_name: Name of the current head + + Returns: + Combined dataset + """ + if not datasets: + return [] + + if all(isinstance(ds, list) for ds in datasets): + logging.info(f"Combining {len(datasets)} list datasets for head '{head_name}'") + return [item for sublist in datasets for item in sublist] + + if all(not isinstance(ds, list) for ds in datasets): + logging.info( + f"Combining {len(datasets)} Dataset objects for head '{head_name}'" + ) + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] + + logging.info(f"Converting mixed dataset types for head '{head_name}'") + + try: + all_items = [] + for ds in datasets: + if isinstance(ds, list): + all_items.extend(ds) + else: + all_items.extend([ds[i] for i in range(len(ds))]) + return all_items + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to convert mixed datasets to list: {e}") + + try: + dataset_objects = [] + for ds in datasets: + if isinstance(ds, list): + from torch.utils.data import TensorDataset + + # Convert list to a Dataset + dataset_objects.append( + TensorDataset(*[torch.tensor([i]) for i in range(len(ds))]) + ) + else: + dataset_objects.append(ds) + return ConcatDataset(dataset_objects) + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to convert mixed datasets to ConcatDataset: {e}") + + logging.warning( + "Could not combine datasets of different types. Using only the first dataset." + ) + return datasets[0] diff --git a/mace-bench/3rdparty/mace/mace/tools/scatter.py b/mace-bench/3rdparty/mace/mace/tools/scatter.py index cf7a5ec7c4bf362738bd5d8be1e2dab48b7ee35e..7e1139a999e5d741b0b57f500d1d349a10092db9 100644 --- a/mace-bench/3rdparty/mace/mace/tools/scatter.py +++ b/mace-bench/3rdparty/mace/mace/tools/scatter.py @@ -1,112 +1,112 @@ -"""basic scatter_sum operations from torch_scatter from -https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py -Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. -PyTorch plans to move these features into the main repo, but until then, -to make installation simpler, we need this pure python set of wrappers -that don't require installing PyTorch C++ extensions. -See https://github.com/pytorch/pytorch/issues/63780. 
-""" - -from typing import Optional - -import torch - - -def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand_as(other) - return src - - -def scatter_sum( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - reduce: str = "sum", -) -> torch.Tensor: - assert reduce == "sum" # for now, TODO - index = _broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -def scatter_std( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - unbiased: bool = True, -) -> torch.Tensor: - if out is not None: - dim_size = out.size(dim) - - if dim < 0: - dim = src.dim() + dim - - count_dim = dim - if index.dim() <= dim: - count_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, count_dim, dim_size=dim_size) - - index = _broadcast(index, src, dim) - tmp = scatter_sum(src, index, dim, dim_size=dim_size) - count = _broadcast(count, tmp, dim).clamp(1) - mean = tmp.div(count) - - var = src - mean.gather(dim, index) - var = var * var - out = scatter_sum(var, index, dim, out, dim_size) - - if unbiased: - count = count.sub(1).clamp_(1) - out = out.div(count + 1e-6).sqrt() - - return out - - -def scatter_mean( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, -) -> torch.Tensor: - out = scatter_sum(src, index, dim, out, dim_size) - dim_size = out.size(dim) - - index_dim = dim - if index_dim < 0: - index_dim = index_dim + src.dim() - if index.dim() <= index_dim: - index_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, index_dim, None, dim_size) - count[count < 1] = 1 - count = _broadcast(count, out, dim) - if out.is_floating_point(): - out.true_divide_(count) - else: - out.div_(count, rounding_mode="floor") - return out +"""basic scatter_sum operations from torch_scatter from +https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py +Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. +PyTorch plans to move these features into the main repo, but until then, +to make installation simpler, we need this pure python set of wrappers +that don't require installing PyTorch C++ extensions. +See https://github.com/pytorch/pytorch/issues/63780. 
+""" + +from typing import Optional + +import torch + + +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> torch.Tensor: + assert reduce == "sum" # for now, TODO + index = _broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_std( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> torch.Tensor: + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clamp(1) + mean = tmp.div(count) + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count = count.sub(1).clamp_(1) + out = out.div(count + 1e-6).sqrt() + + return out + + +def scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = _broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode="floor") + return out diff --git a/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py b/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py index 91fc674142817a0d741208552a1a6de776afa5d8..bb7f79efe84cf8e3c81237bdb34e31f28fc9ca13 100644 --- a/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py @@ -1,888 +1,888 @@ -########################################################################################### -# Training utils -# Authors: David Kovacs, Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse -import ast -import dataclasses -import json -import logging -import os -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed -from e3nn import o3 -from torch.optim.swa_utils import SWALR, AveragedModel - -from mace import data, 
modules, tools -from mace.data import KeySpecification -from mace.tools.train import SWAContainer - - -@dataclasses.dataclass -class SubsetCollection: - train: data.Configurations - valid: data.Configurations - tests: List[Tuple[str, data.Configurations]] - - -def log_dataset_contents(dataset: data.Configurations, dataset_name: str) -> None: - log_string = f"{dataset_name} [" - for prop_name in dataset[0].properties.keys(): - if prop_name == "dipole": - log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " - else: - log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " - log_string = log_string[:-2] + "]" - logging.info(log_string) - - -def get_dataset_from_xyz( - work_dir: str, - train_path: Union[str, List[str]], - valid_path: Optional[Union[str, List[str]]], - valid_fraction: float, - key_specification: KeySpecification, - config_type_weights: Optional[Dict] = None, - test_path: Optional[Union[str, List[str]]] = None, - seed: int = 1234, - keep_isolated_atoms: bool = False, - head_name: str = "Default", -) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: - """ - Load training, validation, and test datasets from xyz files. - - Args: - work_dir: Working directory for saving split information - train_path: Path or list of paths to training xyz files - valid_path: Path or list of paths to validation xyz files - valid_fraction: Fraction of training data to use for validation if valid_path is None - config_type_weights: Dictionary of weights for each configuration type - key_specification: KeySpecification object for loading data - test_path: Path or list of paths to test xyz files - seed: Random seed for train/validation split - keep_isolated_atoms: Whether to keep isolated atoms in the dataset - head_name: Name of the head for multi-head models - - Returns: - Tuple containing: - - SubsetCollection with train, valid, and test configurations - - Dictionary of atomic energies (or None if not available) - """ - # Convert input paths to lists if they're not already - train_paths = [train_path] if isinstance(train_path, str) else train_path - valid_paths = ( - [valid_path] - if isinstance(valid_path, str) and valid_path is not None - else valid_path - ) - test_paths = ( - [test_path] - if isinstance(test_path, str) and test_path is not None - else test_path - ) - - # Initialize collections and atomic energies tracking - all_train_configs = [] - all_valid_configs = [] - all_test_configs = [] - - # For tracking atomic energies across files - atomic_energies_values = {} # Element Z -> list of energy values - atomic_energies_counts = {} # Element Z -> count of files with this element - - # Process training files - for i, path in enumerate(train_paths): - logging.debug(f"Loading training file: {path}") - ae_dict, train_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=True, # Extract from all files to average - keep_isolated_atoms=keep_isolated_atoms, - head_name=head_name, - ) - all_train_configs.extend(train_configs) - - # Track atomic energies from each file for averaging - if ae_dict: - for element, energy in ae_dict.items(): - if element not in atomic_energies_values: - atomic_energies_values[element] = [] - atomic_energies_counts[element] = 0 - - atomic_energies_values[element].append(energy) - atomic_energies_counts[element] += 1 - - 
log_dataset_contents(train_configs, f"Training set {i+1}/{len(train_paths)}") - - # Log total training set info - log_dataset_contents(all_train_configs, "Total Training set") - - # Process validation files if provided - if valid_paths: - for i, path in enumerate(valid_paths): - _, valid_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=False, - head_name=head_name, - ) - all_valid_configs.extend(valid_configs) - log_dataset_contents( - valid_configs, f"Validation set {i+1}/{len(valid_paths)}" - ) - - # Log total validation set info - log_dataset_contents(all_valid_configs, "Total Validation set") - train_configs = all_train_configs - valid_configs = all_valid_configs - else: - # Split training data if no validation files are provided - logging.info("No validation set provided, splitting training data instead.") - train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed, work_dir - ) - log_dataset_contents(train_configs, "Random Split Training set") - log_dataset_contents(valid_configs, "Random Split Validation set") - - test_configs_by_type = [] - if test_paths: - for i, path in enumerate(test_paths): - _, test_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=False, - head_name=head_name, - ) - all_test_configs.extend(test_configs) - - log_dataset_contents(test_configs, f"Test set {i+1}/{len(test_paths)}") - - # Create list of tuples (config_type, list(Atoms)) - test_configs_by_type = data.test_config_types(all_test_configs) - log_dataset_contents(all_test_configs, "Total Test set") - - atomic_energies_dict = {} - for element, values in atomic_energies_values.items(): - if atomic_energies_counts[element] > 1: - atomic_energies_dict[element] = sum(values) / len(values) - logging.debug( - f"Element {element} found in {atomic_energies_counts[element]} files. Using average E0: {atomic_energies_dict[element]:.6f} eV" - ) - else: - atomic_energies_dict[element] = values[0] - logging.debug( - f"Element {element} found in 1 file. 
Using E0: {atomic_energies_dict[element]:.6f} eV" - ) - - return ( - SubsetCollection( - train=train_configs, valid=valid_configs, tests=test_configs_by_type - ), - atomic_energies_dict if atomic_energies_dict else None, - ) - - -def get_config_type_weights(ct_weights): - """ - Get config type weights from command line argument - """ - try: - config_type_weights = ast.literal_eval(ct_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} - return config_type_weights - - -def print_git_commit(): - try: - import git - - repo = git.Repo(search_parent_directories=True) - commit = repo.head.commit.hexsha - logging.debug(f"Current Git commit: {commit}") - return commit - except Exception as e: # pylint: disable=W0703 - logging.debug(f"Error accessing Git repository: {e}") - return "None" - - -def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: - if model.__class__.__name__ != "ScaleShiftMACE": - return {"error": "Model is not a ScaleShiftMACE model"} - - def radial_to_name(radial_type): - if radial_type == "BesselBasis": - return "bessel" - if radial_type == "GaussianBasis": - return "gaussian" - if radial_type == "ChebychevBasis": - return "chebyshev" - return radial_type - - def radial_to_transform(radial): - if not hasattr(radial, "distance_transform"): - return None - if radial.distance_transform.__class__.__name__ == "AgnesiTransform": - return "Agnesi" - if radial.distance_transform.__class__.__name__ == "SoftTransform": - return "Soft" - return radial.distance_transform.__class__.__name__ - - scale = model.scale_shift.scale - shift = model.scale_shift.shift - heads = model.heads if hasattr(model, "heads") else ["default"] - model_mlp_irreps = ( - o3.Irreps(str(model.readouts[-1].hidden_irreps)) - if model.num_interactions.item() > 1 - else 1 - ) - mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e") - try: - correlation = ( - len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 - ) - except AttributeError: - correlation = model.products[0].symmetric_contractions.contraction_degree - config = { - "r_max": model.r_max.item(), - "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), - "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), - "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access - "interaction_cls": model.interactions[-1].__class__, - "interaction_cls_first": model.interactions[0].__class__, - "num_interactions": model.num_interactions.item(), - "num_elements": len(model.atomic_numbers), - "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), - "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1), - "gate": ( - model.readouts[-1] # pylint: disable=protected-access - .non_linearity._modules["acts"][0] - .f - if model.num_interactions.item() > 1 - else None - ), - "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), - "avg_num_neighbors": model.interactions[0].avg_num_neighbors, - "atomic_numbers": model.atomic_numbers, - "correlation": correlation, - "radial_type": radial_to_name( - model.radial_embedding.bessel_fn.__class__.__name__ - ), - "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], - "pair_repulsion": hasattr(model, "pair_repulsion_fn"), - "distance_transform": radial_to_transform(model.radial_embedding), - 
"atomic_inter_scale": scale.cpu().numpy(), - "atomic_inter_shift": shift.cpu().numpy(), - "heads": heads, - } - return config - - -def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: - return extract_model( - torch.load(f=f, map_location=map_location), map_location=map_location - ) - - -def remove_pt_head( - model: torch.nn.Module, head_to_keep: Optional[str] = None -) -> torch.nn.Module: - """Converts a multihead MACE model to a single head model by removing the pretraining head. - - Args: - model (ScaleShiftMACE): The multihead MACE model to convert - head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. - - Returns: - ScaleShiftMACE: A new MACE model with only the specified head - - Raises: - ValueError: If the model is not a multihead model or if the specified head is not found - """ - if not hasattr(model, "heads") or len(model.heads) <= 1: - raise ValueError("Model must be a multihead model with more than one head") - - # Get index of head to keep - if head_to_keep is None: - # Find first non-PT head - try: - head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") - except StopIteration as e: - raise ValueError("No non-PT head found in model") from e - else: - try: - head_idx = model.heads.index(head_to_keep) - except ValueError as e: - raise ValueError(f"Head {head_to_keep} not found in model") from e - - # Extract config and modify for single head - model_config = extract_config_mace_model(model) - model_config["heads"] = [model.heads[head_idx]] - model_config["atomic_energies"] = ( - model.atomic_energies_fn.atomic_energies[head_idx] - .unsqueeze(0) - .detach() - .cpu() - .numpy() - ) - model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() - model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() - mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) - # model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") - - new_model = model.__class__(**model_config) - state_dict = model.state_dict() - new_state_dict = {} - - for name, param in state_dict.items(): - if "atomic_energies" in name: - new_state_dict[name] = param[head_idx : head_idx + 1] - elif "scale" in name or "shift" in name: - new_state_dict[name] = param[head_idx : head_idx + 1] - elif "readouts" in name: - channels_per_head = param.shape[0] // len(model.heads) - start_idx = head_idx * channels_per_head - end_idx = start_idx + channels_per_head - if "linear_2.weight" in name: - end_idx = start_idx + channels_per_head // 2 - # if ( - # "readouts.0.linear.weight" in name - # or "readouts.1.linear_2.weight" in name - # ): - # new_state_dict[name] = param[start_idx:end_idx] / ( - # len(model.heads) ** 0.5 - # ) - if "readouts.0.linear.weight" in name: - new_state_dict[name] = param.reshape(-1, len(model.heads))[ - :, head_idx - ].flatten() - elif "readouts.1.linear_1.weight" in name: - new_state_dict[name] = param.reshape( - -1, len(model.heads), mlp_count_irreps - )[:, head_idx, :].flatten() - elif "readouts.1.linear_2.weight" in name: - new_state_dict[name] = param.reshape( - len(model.heads), -1, len(model.heads) - )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) - else: - new_state_dict[name] = param[start_idx:end_idx] - - else: - new_state_dict[name] = param - - # Load state dict into new model - new_model.load_state_dict(new_state_dict) - - return new_model - - -def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: - model_copy = 
model.__class__(**extract_config_mace_model(model))
-    model_copy.load_state_dict(model.state_dict())
-    return model_copy.to(map_location)
-
-
-def convert_to_json_format(dict_input):
-    for key, value in dict_input.items():
-        if isinstance(value, (np.ndarray, torch.Tensor)):
-            dict_input[key] = value.tolist()
-        # check if the value is a class and convert it to a string
-        elif hasattr(value, "__class__"):
-            dict_input[key] = str(value)
-    return dict_input
-
-
-def convert_from_json_format(dict_input):
-    dict_output = dict_input.copy()
-    if (
-        dict_input["interaction_cls"]
-        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-    ):
-        dict_output["interaction_cls"] = (
-            modules.blocks.RealAgnosticResidualInteractionBlock
-        )
-    if (
-        dict_input["interaction_cls"]
-        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-    ):
-        dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
-    if (
-        dict_input["interaction_cls_first"]
-        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-    ):
-        dict_output["interaction_cls_first"] = (
-            modules.blocks.RealAgnosticResidualInteractionBlock
-        )
-    if (
-        dict_input["interaction_cls_first"]
-        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-    ):
-        dict_output["interaction_cls_first"] = (
-            modules.blocks.RealAgnosticInteractionBlock
-        )
-    dict_output["r_max"] = float(dict_input["r_max"])
-    dict_output["num_bessel"] = int(dict_input["num_bessel"])
-    dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
-    dict_output["max_ell"] = int(dict_input["max_ell"])
-    dict_output["num_interactions"] = int(dict_input["num_interactions"])
-    dict_output["num_elements"] = int(dict_input["num_elements"])
-    dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
-    dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
-    dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
-    dict_output["gate"] = torch.nn.functional.silu
-    dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
-    dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
-    dict_output["correlation"] = int(dict_input["correlation"])
-    dict_output["radial_type"] = dict_input["radial_type"]
-    dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
-    dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
-    dict_output["distance_transform"] = dict_input["distance_transform"]
-    dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
-    dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
-
-    return dict_output
-
-
-def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
-    extra_files_extract = {"commit.txt": None, "config.json": None}
-    model_jit_load = torch.jit.load(
-        f, _extra_files=extra_files_extract, map_location=map_location
-    )
-    model_load_yaml = modules.ScaleShiftMACE(
-        **convert_from_json_format(json.loads(extra_files_extract["config.json"]))
-    )
-    model_load_yaml.load_state_dict(model_jit_load.state_dict())
-    return model_load_yaml.to(map_location)
-
-
-def get_atomic_energies(E0s, train_collection, z_table) -> dict:
-    if E0s is not None:
-        logging.info(
-            "Isolated Atomic Energies (E0s) not in training file, using command line argument"
-        )
-        if E0s.lower() == "average":
-            logging.info(
-                "Computing average Atomic Energies using least squares regression"
-            )
-            # catch the case where collections.train was not defined above
-            try:
-                assert train_collection is not None
-                atomic_energies_dict = data.compute_average_E0s(
-                    train_collection, z_table
-                )
-            except Exception as e:
-                raise RuntimeError(
-                    f"Could not compute average E0s if no training xyz given, error {e} occurred"
-                ) from e
-
else: - if E0s.endswith(".json"): - logging.info(f"Loading atomic energies from {E0s}") - with open(E0s, "r", encoding="utf-8") as f: - atomic_energies_dict = json.load(f) - atomic_energies_dict = { - int(key): value for key, value in atomic_energies_dict.items() - } - else: - try: - atomic_energies_eval = ast.literal_eval(E0s) - if not all( - isinstance(value, dict) - for value in atomic_energies_eval.values() - ): - atomic_energies_dict = atomic_energies_eval - else: - atomic_energies_dict = atomic_energies_eval - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occured" - ) from e - else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) - return atomic_energies_dict - - -def get_avg_num_neighbors(head_configs, args, train_loader, device): - if all(head_config.compute_avg_num_neighbors for head_config in head_configs): - logging.info("Computing average number of neighbors") - avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) - if args.distributed: - num_graphs = torch.tensor(len(train_loader.dataset)).to(device) - num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) - torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce( - num_neighbors, op=torch.distributed.ReduceOp.SUM - ) - avg_num_neighbors_out = (num_neighbors / num_graphs).item() - else: - avg_num_neighbors_out = avg_num_neighbors - else: - assert any( - head_config.avg_num_neighbors is not None for head_config in head_configs - ), "Average number of neighbors must be provided in the configuration" - avg_num_neighbors_out = max( - head_config.avg_num_neighbors - for head_config in head_configs - if head_config.avg_num_neighbors is not None - ) - if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: - logging.warning( - f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" - ) - else: - logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") - return avg_num_neighbors_out - - -def get_loss_fn( - args: argparse.Namespace, - dipole_only: bool, - compute_dipole: bool, -) -> torch.nn.Module: - if args.loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) - elif args.loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) - elif args.loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - virials_weight=args.virials_weight, - ) - elif args.loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - ) - elif args.loss == "huber": - loss_fn = modules.WeightedHuberEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "universal": - loss_fn = modules.UniversalLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "l1l2energyforces": - loss_fn = modules.WeightedEnergyForcesL1L2Loss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - ) - elif args.loss == "dipole": - assert ( - dipole_only is True 
- ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=args.dipole_weight, - ) - elif args.loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - dipole_weight=args.dipole_weight, - ) - else: - loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) - return loss_fn - - -def get_swa( - args: argparse.Namespace, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - swas: List[bool], - dipole_only: bool = False, -): - assert dipole_only is False, "Stage Two for dipole fitting not implemented" - swas.append(True) - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - else: - if args.start_swa >= args.max_num_epochs: - logging.warning( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" - ) - swas[-1] = False - if args.loss == "forces_only": - raise ValueError("Can not select Stage Two with forces only loss.") - if args.loss == "virials": - loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - virials_weight=args.swa_virials_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "stress": - loss_fn_energy = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "energy_forces_dipole": - loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( - args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - dipole_weight=args.swa_dipole_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "universal": - loss_fn_energy = modules.UniversalLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - huber_delta=args.huber_delta, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" - ) - else: - loss_fn_energy = modules.WeightedEnergyForcesLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" - ) - swa = SWAContainer( - model=AveragedModel(model), - scheduler=SWALR( - 
optimizer=optimizer, - swa_lr=args.swa_lr, - anneal_epochs=1, - anneal_strategy="linear", - ), - start=args.start_swa, - loss_fn=loss_fn_energy, - ) - return swa, swas - - -def get_params_options( - args: argparse.Namespace, model: torch.nn.Module -) -> Dict[str, Any]: - decay_interactions = {} - no_decay_interactions = {} - for name, param in model.interactions.named_parameters(): - if "linear.weight" in name or "skip_tp_full.weight" in name: - decay_interactions[name] = param - else: - no_decay_interactions[name] = param - - param_options = dict( - params=[ - { - "name": "embedding", - "params": model.node_embedding.parameters(), - "weight_decay": 0.0, - }, - { - "name": "interactions_decay", - "params": list(decay_interactions.values()), - "weight_decay": args.weight_decay, - }, - { - "name": "interactions_no_decay", - "params": list(no_decay_interactions.values()), - "weight_decay": 0.0, - }, - { - "name": "products", - "params": model.products.parameters(), - "weight_decay": args.weight_decay, - }, - { - "name": "readouts", - "params": model.readouts.parameters(), - "weight_decay": 0.0, - }, - ], - lr=args.lr, - amsgrad=args.amsgrad, - betas=(args.beta, 0.999), - ) - return param_options - - -def get_optimizer( - args: argparse.Namespace, param_options: Dict[str, Any] -) -> torch.optim.Optimizer: - if args.optimizer == "adamw": - optimizer = torch.optim.AdamW(**param_options) - elif args.optimizer == "schedulefree": - try: - from schedulefree import adamw_schedulefree - except ImportError as exc: - raise ImportError( - "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" - ) from exc - _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} - optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) - else: - optimizer = torch.optim.Adam(**param_options) - return optimizer - - -def setup_wandb(args: argparse.Namespace): - logging.info("Using Weights and Biases for logging") - import wandb - - wandb_config = {} - args_dict = vars(args) - - for key, value in args_dict.items(): - if isinstance(value, np.ndarray): - args_dict[key] = value.tolist() - - class CustomEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, KeySpecification): - return o.__dict__ - return super().default(o) - - args_dict_json = json.dumps(args_dict, cls=CustomEncoder) - for key in args.wandb_log_hypers: - wandb_config[key] = args_dict[key] - tools.init_wandb( - project=args.wandb_project, - entity=args.wandb_entity, - name=args.wandb_name, - config=wandb_config, - directory=args.wandb_dir, - ) - wandb.run.summary["params"] = args_dict_json - - -def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: - return [ - os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) - ] - - -def dict_to_array(input_data, heads): - if all(isinstance(value, np.ndarray) for value in input_data.values()): - return np.array([input_data[head] for head in heads]) - if not all(isinstance(value, dict) for value in input_data.values()): - return np.array([[input_data[head]] for head in heads]) - unique_keys = set() - for inner_dict in input_data.values(): - unique_keys.update(inner_dict.keys()) - unique_keys = list(unique_keys) - sorted_keys = sorted([int(key) for key in unique_keys]) - result_array = np.zeros((len(input_data), len(sorted_keys))) - for _, (head_name, inner_dict) in enumerate(input_data.items()): - for key, value in inner_dict.items(): - key_index = 
sorted_keys.index(int(key)) - head_index = heads.index(head_name) - result_array[head_index][key_index] = value - return result_array - - -class LRScheduler: - def __init__(self, optimizer, args) -> None: - self.scheduler = args.scheduler - self._optimizer_type = ( - args.optimizer - ) # Schedulefree does not need an optimizer but checkpoint handler does. - if args.scheduler == "ExponentialLR": - self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, gamma=args.lr_scheduler_gamma - ) - elif args.scheduler == "ReduceLROnPlateau": - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer=optimizer, - factor=args.lr_factor, - patience=args.scheduler_patience, - ) - else: - raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") - - def step(self, metrics=None, epoch=None): # pylint: disable=E1123 - if self._optimizer_type == "schedulefree": - return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary - if self.scheduler == "ExponentialLR": - self.lr_scheduler.step(epoch=epoch) - elif self.scheduler == "ReduceLROnPlateau": - self.lr_scheduler.step( # pylint: disable=E1123 - metrics=metrics, epoch=epoch - ) - - def __getattr__(self, name): - if name == "step": - return self.step - return getattr(self.lr_scheduler, name) - - -def check_folder_subfolder(folder_path): - entries = os.listdir(folder_path) - for entry in entries: - full_path = os.path.join(folder_path, entry) - if os.path.isdir(full_path): - return True - return False - - -def check_path_ase_read(filename: Optional[str]) -> bool: - if filename is None: - return False - filepath = Path(filename) - if filepath.is_dir(): - num_h5_files = len(list(filepath.glob("*.h5"))) - num_hdf5_files = len(list(filepath.glob("*.hdf5"))) - num_ldb_files = len(list(filepath.glob("*.lmdb"))) - num_aselmbd_files = len(list(filepath.glob("*.aselmdb"))) - num_mdb_files = len(list(filepath.glob("*.mdb"))) - if ( - num_h5_files - + num_hdf5_files - + num_ldb_files - + num_aselmbd_files - + num_mdb_files - == 0 - ): - # print all the files in the directory extension in the directory for debugging - for file in os.listdir(filepath): - print(file) - raise RuntimeError(f"No supported files found in directory '{filename}'") - return False - if filepath.suffix in (".h5", ".hdf5", ".lmdb", ".aselmdb", ".mdb"): - return False - return True - - -def dict_to_namespace(dictionary): - # Convert the dictionary into an argparse.Namespace - namespace = argparse.Namespace() - for key, value in dictionary.items(): - setattr(namespace, key, value) - return namespace +########################################################################################### +# Training utils +# Authors: David Kovacs, Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import ast +import dataclasses +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from e3nn import o3 +from torch.optim.swa_utils import SWALR, AveragedModel + +from mace import data, modules, tools +from mace.data import KeySpecification +from mace.tools.train import SWAContainer + + +@dataclasses.dataclass +class SubsetCollection: + train: data.Configurations + valid: data.Configurations + tests: List[Tuple[str, data.Configurations]] + + +def 
log_dataset_contents(dataset: data.Configurations, dataset_name: str) -> None: + log_string = f"{dataset_name} [" + for prop_name in dataset[0].properties.keys(): + if prop_name == "dipole": + log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " + else: + log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " + log_string = log_string[:-2] + "]" + logging.info(log_string) + + +def get_dataset_from_xyz( + work_dir: str, + train_path: Union[str, List[str]], + valid_path: Optional[Union[str, List[str]]], + valid_fraction: float, + key_specification: KeySpecification, + config_type_weights: Optional[Dict] = None, + test_path: Optional[Union[str, List[str]]] = None, + seed: int = 1234, + keep_isolated_atoms: bool = False, + head_name: str = "Default", +) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: + """ + Load training, validation, and test datasets from xyz files. + + Args: + work_dir: Working directory for saving split information + train_path: Path or list of paths to training xyz files + valid_path: Path or list of paths to validation xyz files + valid_fraction: Fraction of training data to use for validation if valid_path is None + config_type_weights: Dictionary of weights for each configuration type + key_specification: KeySpecification object for loading data + test_path: Path or list of paths to test xyz files + seed: Random seed for train/validation split + keep_isolated_atoms: Whether to keep isolated atoms in the dataset + head_name: Name of the head for multi-head models + + Returns: + Tuple containing: + - SubsetCollection with train, valid, and test configurations + - Dictionary of atomic energies (or None if not available) + """ + # Convert input paths to lists if they're not already + train_paths = [train_path] if isinstance(train_path, str) else train_path + valid_paths = ( + [valid_path] + if isinstance(valid_path, str) and valid_path is not None + else valid_path + ) + test_paths = ( + [test_path] + if isinstance(test_path, str) and test_path is not None + else test_path + ) + + # Initialize collections and atomic energies tracking + all_train_configs = [] + all_valid_configs = [] + all_test_configs = [] + + # For tracking atomic energies across files + atomic_energies_values = {} # Element Z -> list of energy values + atomic_energies_counts = {} # Element Z -> count of files with this element + + # Process training files + for i, path in enumerate(train_paths): + logging.debug(f"Loading training file: {path}") + ae_dict, train_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=True, # Extract from all files to average + keep_isolated_atoms=keep_isolated_atoms, + head_name=head_name, + ) + all_train_configs.extend(train_configs) + + # Track atomic energies from each file for averaging + if ae_dict: + for element, energy in ae_dict.items(): + if element not in atomic_energies_values: + atomic_energies_values[element] = [] + atomic_energies_counts[element] = 0 + + atomic_energies_values[element].append(energy) + atomic_energies_counts[element] += 1 + + log_dataset_contents(train_configs, f"Training set {i+1}/{len(train_paths)}") + + # Log total training set info + log_dataset_contents(all_train_configs, "Total Training set") + + # Process validation files if provided + if valid_paths: + for i, path in enumerate(valid_paths): + _, 
valid_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=False, + head_name=head_name, + ) + all_valid_configs.extend(valid_configs) + log_dataset_contents( + valid_configs, f"Validation set {i+1}/{len(valid_paths)}" + ) + + # Log total validation set info + log_dataset_contents(all_valid_configs, "Total Validation set") + train_configs = all_train_configs + valid_configs = all_valid_configs + else: + # Split training data if no validation files are provided + logging.info("No validation set provided, splitting training data instead.") + train_configs, valid_configs = data.random_train_valid_split( + all_train_configs, valid_fraction, seed, work_dir + ) + log_dataset_contents(train_configs, "Random Split Training set") + log_dataset_contents(valid_configs, "Random Split Validation set") + + test_configs_by_type = [] + if test_paths: + for i, path in enumerate(test_paths): + _, test_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=False, + head_name=head_name, + ) + all_test_configs.extend(test_configs) + + log_dataset_contents(test_configs, f"Test set {i+1}/{len(test_paths)}") + + # Create list of tuples (config_type, list(Atoms)) + test_configs_by_type = data.test_config_types(all_test_configs) + log_dataset_contents(all_test_configs, "Total Test set") + + atomic_energies_dict = {} + for element, values in atomic_energies_values.items(): + if atomic_energies_counts[element] > 1: + atomic_energies_dict[element] = sum(values) / len(values) + logging.debug( + f"Element {element} found in {atomic_energies_counts[element]} files. Using average E0: {atomic_energies_dict[element]:.6f} eV" + ) + else: + atomic_energies_dict[element] = values[0] + logging.debug( + f"Element {element} found in 1 file. 
Using E0: {atomic_energies_dict[element]:.6f} eV" + ) + + return ( + SubsetCollection( + train=train_configs, valid=valid_configs, tests=test_configs_by_type + ), + atomic_energies_dict if atomic_energies_dict else None, + ) + + +def get_config_type_weights(ct_weights): + """ + Get config type weights from command line argument + """ + try: + config_type_weights = ast.literal_eval(ct_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + return config_type_weights + + +def print_git_commit(): + try: + import git + + repo = git.Repo(search_parent_directories=True) + commit = repo.head.commit.hexsha + logging.debug(f"Current Git commit: {commit}") + return commit + except Exception as e: # pylint: disable=W0703 + logging.debug(f"Error accessing Git repository: {e}") + return "None" + + +def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: + if model.__class__.__name__ != "ScaleShiftMACE": + return {"error": "Model is not a ScaleShiftMACE model"} + + def radial_to_name(radial_type): + if radial_type == "BesselBasis": + return "bessel" + if radial_type == "GaussianBasis": + return "gaussian" + if radial_type == "ChebychevBasis": + return "chebyshev" + return radial_type + + def radial_to_transform(radial): + if not hasattr(radial, "distance_transform"): + return None + if radial.distance_transform.__class__.__name__ == "AgnesiTransform": + return "Agnesi" + if radial.distance_transform.__class__.__name__ == "SoftTransform": + return "Soft" + return radial.distance_transform.__class__.__name__ + + scale = model.scale_shift.scale + shift = model.scale_shift.shift + heads = model.heads if hasattr(model, "heads") else ["default"] + model_mlp_irreps = ( + o3.Irreps(str(model.readouts[-1].hidden_irreps)) + if model.num_interactions.item() > 1 + else 1 + ) + mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e") + try: + correlation = ( + len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 + ) + except AttributeError: + correlation = model.products[0].symmetric_contractions.contraction_degree + config = { + "r_max": model.r_max.item(), + "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), + "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), + "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access + "interaction_cls": model.interactions[-1].__class__, + "interaction_cls_first": model.interactions[0].__class__, + "num_interactions": model.num_interactions.item(), + "num_elements": len(model.atomic_numbers), + "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), + "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1), + "gate": ( + model.readouts[-1] # pylint: disable=protected-access + .non_linearity._modules["acts"][0] + .f + if model.num_interactions.item() > 1 + else None + ), + "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), + "avg_num_neighbors": model.interactions[0].avg_num_neighbors, + "atomic_numbers": model.atomic_numbers, + "correlation": correlation, + "radial_type": radial_to_name( + model.radial_embedding.bessel_fn.__class__.__name__ + ), + "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], + "pair_repulsion": hasattr(model, "pair_repulsion_fn"), + "distance_transform": radial_to_transform(model.radial_embedding), + 
"atomic_inter_scale": scale.cpu().numpy(), + "atomic_inter_shift": shift.cpu().numpy(), + "heads": heads, + } + return config + + +def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: + return extract_model( + torch.load(f=f, map_location=map_location), map_location=map_location + ) + + +def remove_pt_head( + model: torch.nn.Module, head_to_keep: Optional[str] = None +) -> torch.nn.Module: + """Converts a multihead MACE model to a single head model by removing the pretraining head. + + Args: + model (ScaleShiftMACE): The multihead MACE model to convert + head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. + + Returns: + ScaleShiftMACE: A new MACE model with only the specified head + + Raises: + ValueError: If the model is not a multihead model or if the specified head is not found + """ + if not hasattr(model, "heads") or len(model.heads) <= 1: + raise ValueError("Model must be a multihead model with more than one head") + + # Get index of head to keep + if head_to_keep is None: + # Find first non-PT head + try: + head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") + except StopIteration as e: + raise ValueError("No non-PT head found in model") from e + else: + try: + head_idx = model.heads.index(head_to_keep) + except ValueError as e: + raise ValueError(f"Head {head_to_keep} not found in model") from e + + # Extract config and modify for single head + model_config = extract_config_mace_model(model) + model_config["heads"] = [model.heads[head_idx]] + model_config["atomic_energies"] = ( + model.atomic_energies_fn.atomic_energies[head_idx] + .unsqueeze(0) + .detach() + .cpu() + .numpy() + ) + model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() + model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() + mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) + # model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") + + new_model = model.__class__(**model_config) + state_dict = model.state_dict() + new_state_dict = {} + + for name, param in state_dict.items(): + if "atomic_energies" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "scale" in name or "shift" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "readouts" in name: + channels_per_head = param.shape[0] // len(model.heads) + start_idx = head_idx * channels_per_head + end_idx = start_idx + channels_per_head + if "linear_2.weight" in name: + end_idx = start_idx + channels_per_head // 2 + # if ( + # "readouts.0.linear.weight" in name + # or "readouts.1.linear_2.weight" in name + # ): + # new_state_dict[name] = param[start_idx:end_idx] / ( + # len(model.heads) ** 0.5 + # ) + if "readouts.0.linear.weight" in name: + new_state_dict[name] = param.reshape(-1, len(model.heads))[ + :, head_idx + ].flatten() + elif "readouts.1.linear_1.weight" in name: + new_state_dict[name] = param.reshape( + -1, len(model.heads), mlp_count_irreps + )[:, head_idx, :].flatten() + elif "readouts.1.linear_2.weight" in name: + new_state_dict[name] = param.reshape( + len(model.heads), -1, len(model.heads) + )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) + else: + new_state_dict[name] = param[start_idx:end_idx] + + else: + new_state_dict[name] = param + + # Load state dict into new model + new_model.load_state_dict(new_state_dict) + + return new_model + + +def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: + model_copy = 
model.__class__(**extract_config_mace_model(model))
+    model_copy.load_state_dict(model.state_dict())
+    return model_copy.to(map_location)
+
+
+def convert_to_json_format(dict_input):
+    for key, value in dict_input.items():
+        if isinstance(value, (np.ndarray, torch.Tensor)):
+            dict_input[key] = value.tolist()
+        # check if the value is a class and convert it to a string
+        elif hasattr(value, "__class__"):
+            dict_input[key] = str(value)
+    return dict_input
+
+
+def convert_from_json_format(dict_input):
+    dict_output = dict_input.copy()
+    if (
+        dict_input["interaction_cls"]
+        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+    ):
+        dict_output["interaction_cls"] = (
+            modules.blocks.RealAgnosticResidualInteractionBlock
+        )
+    if (
+        dict_input["interaction_cls"]
+        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+    ):
+        dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
+    if (
+        dict_input["interaction_cls_first"]
+        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+    ):
+        dict_output["interaction_cls_first"] = (
+            modules.blocks.RealAgnosticResidualInteractionBlock
+        )
+    if (
+        dict_input["interaction_cls_first"]
+        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+    ):
+        dict_output["interaction_cls_first"] = (
+            modules.blocks.RealAgnosticInteractionBlock
+        )
+    dict_output["r_max"] = float(dict_input["r_max"])
+    dict_output["num_bessel"] = int(dict_input["num_bessel"])
+    dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
+    dict_output["max_ell"] = int(dict_input["max_ell"])
+    dict_output["num_interactions"] = int(dict_input["num_interactions"])
+    dict_output["num_elements"] = int(dict_input["num_elements"])
+    dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
+    dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
+    dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
+    dict_output["gate"] = torch.nn.functional.silu
+    dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
+    dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
+    dict_output["correlation"] = int(dict_input["correlation"])
+    dict_output["radial_type"] = dict_input["radial_type"]
+    dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
+    dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
+    dict_output["distance_transform"] = dict_input["distance_transform"]
+    dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
+    dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
+
+    return dict_output
+
+
+def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
+    extra_files_extract = {"commit.txt": None, "config.json": None}
+    model_jit_load = torch.jit.load(
+        f, _extra_files=extra_files_extract, map_location=map_location
+    )
+    model_load_yaml = modules.ScaleShiftMACE(
+        **convert_from_json_format(json.loads(extra_files_extract["config.json"]))
+    )
+    model_load_yaml.load_state_dict(model_jit_load.state_dict())
+    return model_load_yaml.to(map_location)
+
+
+def get_atomic_energies(E0s, train_collection, z_table) -> dict:
+    if E0s is not None:
+        logging.info(
+            "Isolated Atomic Energies (E0s) not in training file, using command line argument"
+        )
+        if E0s.lower() == "average":
+            logging.info(
+                "Computing average Atomic Energies using least squares regression"
+            )
+            # catch the case where collections.train was not defined above
+            try:
+                assert train_collection is not None
+                atomic_energies_dict = data.compute_average_E0s(
+                    train_collection, z_table
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    f"Could not compute average E0s if no training xyz given, error {e} occurred"
+                ) from e
+
else: + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) + atomic_energies_dict = { + int(key): value for key, value in atomic_energies_dict.items() + } + else: + try: + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = atomic_energies_eval + else: + atomic_energies_dict = atomic_energies_eval + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e + else: + raise RuntimeError( + "E0s not found in training file and not specified in command line" + ) + return atomic_energies_dict + + +def get_avg_num_neighbors(head_configs, args, train_loader, device): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): + logging.info("Computing average number of neighbors") + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + avg_num_neighbors_out = (num_neighbors / num_graphs).item() + else: + avg_num_neighbors_out = avg_num_neighbors + else: + assert any( + head_config.avg_num_neighbors is not None for head_config in head_configs + ), "Average number of neighbors must be provided in the configuration" + avg_num_neighbors_out = max( + head_config.avg_num_neighbors + for head_config in head_configs + if head_config.avg_num_neighbors is not None + ) + if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: + logging.warning( + f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") + return avg_num_neighbors_out + + +def get_loss_fn( + args: argparse.Namespace, + dipole_only: bool, + compute_dipole: bool, +) -> torch.nn.Module: + if args.loss == "weighted": + loss_fn = modules.WeightedEnergyForcesLoss( + energy_weight=args.energy_weight, forces_weight=args.forces_weight + ) + elif args.loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) + elif args.loss == "virials": + loss_fn = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + virials_weight=args.virials_weight, + ) + elif args.loss == "stress": + loss_fn = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + ) + elif args.loss == "huber": + loss_fn = modules.WeightedHuberEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "universal": + loss_fn = modules.UniversalLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "l1l2energyforces": + loss_fn = modules.WeightedEnergyForcesL1L2Loss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + ) + elif args.loss == "dipole": + assert ( + dipole_only is True 
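+            # Editorial note (added in this write-up, not in upstream MACE):
+            # a pure dipole loss presumes the model predicts only dipoles, so
+            # anything other than dipole_only=True fails this assertion.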
+ ), "dipole loss can only be used with AtomicDipolesMACE model" + loss_fn = modules.DipoleSingleLoss( + dipole_weight=args.dipole_weight, + ) + elif args.loss == "energy_forces_dipole": + assert dipole_only is False and compute_dipole is True + loss_fn = modules.WeightedEnergyForcesDipoleLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + dipole_weight=args.dipole_weight, + ) + else: + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + return loss_fn + + +def get_swa( + args: argparse.Namespace, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + swas: List[bool], + dipole_only: bool = False, +): + assert dipole_only is False, "Stage Two for dipole fitting not implemented" + swas.append(True) + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa >= args.max_num_epochs: + logging.warning( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + swas[-1] = False + if args.loss == "forces_only": + raise ValueError("Can not select Stage Two with forces only loss.") + if args.loss == "virials": + loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + virials_weight=args.swa_virials_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "stress": + loss_fn_energy = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "energy_forces_dipole": + loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( + args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + dipole_weight=args.swa_dipole_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + else: + loss_fn_energy = modules.WeightedEnergyForcesLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + ) + swa = SWAContainer( + model=AveragedModel(model), + scheduler=SWALR( + 
optimizer=optimizer, + swa_lr=args.swa_lr, + anneal_epochs=1, + anneal_strategy="linear", + ), + start=args.start_swa, + loss_fn=loss_fn_energy, + ) + return swa, swas + + +def get_params_options( + args: argparse.Namespace, model: torch.nn.Module +) -> Dict[str, Any]: + decay_interactions = {} + no_decay_interactions = {} + for name, param in model.interactions.named_parameters(): + if "linear.weight" in name or "skip_tp_full.weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": model.node_embedding.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": args.weight_decay, + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "products", + "params": model.products.parameters(), + "weight_decay": args.weight_decay, + }, + { + "name": "readouts", + "params": model.readouts.parameters(), + "weight_decay": 0.0, + }, + ], + lr=args.lr, + amsgrad=args.amsgrad, + betas=(args.beta, 0.999), + ) + return param_options + + +def get_optimizer( + args: argparse.Namespace, param_options: Dict[str, Any] +) -> torch.optim.Optimizer: + if args.optimizer == "adamw": + optimizer = torch.optim.AdamW(**param_options) + elif args.optimizer == "schedulefree": + try: + from schedulefree import adamw_schedulefree + except ImportError as exc: + raise ImportError( + "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" + ) from exc + _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} + optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + else: + optimizer = torch.optim.Adam(**param_options) + return optimizer + + +def setup_wandb(args: argparse.Namespace): + logging.info("Using Weights and Biases for logging") + import wandb + + wandb_config = {} + args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + + class CustomEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, KeySpecification): + return o.__dict__ + return super().default(o) + + args_dict_json = json.dumps(args_dict, cls=CustomEncoder) + for key in args.wandb_log_hypers: + wandb_config[key] = args_dict[key] + tools.init_wandb( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=wandb_config, + directory=args.wandb_dir, + ) + wandb.run.summary["params"] = args_dict_json + + +def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: + return [ + os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) + ] + + +def dict_to_array(input_data, heads): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array([[input_data[head]] for head in heads]) + unique_keys = set() + for inner_dict in input_data.values(): + unique_keys.update(inner_dict.keys()) + unique_keys = list(unique_keys) + sorted_keys = sorted([int(key) for key in unique_keys]) + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): + for key, value in inner_dict.items(): + key_index = 
sorted_keys.index(int(key)) + head_index = heads.index(head_name) + result_array[head_index][key_index] = value + return result_array + + +class LRScheduler: + def __init__(self, optimizer, args) -> None: + self.scheduler = args.scheduler + self._optimizer_type = ( + args.optimizer + ) # Schedulefree does not need an optimizer but checkpoint handler does. + if args.scheduler == "ExponentialLR": + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=args.lr_scheduler_gamma + ) + elif args.scheduler == "ReduceLROnPlateau": + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + factor=args.lr_factor, + patience=args.scheduler_patience, + ) + else: + raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") + + def step(self, metrics=None, epoch=None): # pylint: disable=E1123 + if self._optimizer_type == "schedulefree": + return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary + if self.scheduler == "ExponentialLR": + self.lr_scheduler.step(epoch=epoch) + elif self.scheduler == "ReduceLROnPlateau": + self.lr_scheduler.step( # pylint: disable=E1123 + metrics=metrics, epoch=epoch + ) + + def __getattr__(self, name): + if name == "step": + return self.step + return getattr(self.lr_scheduler, name) + + +def check_folder_subfolder(folder_path): + entries = os.listdir(folder_path) + for entry in entries: + full_path = os.path.join(folder_path, entry) + if os.path.isdir(full_path): + return True + return False + + +def check_path_ase_read(filename: Optional[str]) -> bool: + if filename is None: + return False + filepath = Path(filename) + if filepath.is_dir(): + num_h5_files = len(list(filepath.glob("*.h5"))) + num_hdf5_files = len(list(filepath.glob("*.hdf5"))) + num_ldb_files = len(list(filepath.glob("*.lmdb"))) + num_aselmbd_files = len(list(filepath.glob("*.aselmdb"))) + num_mdb_files = len(list(filepath.glob("*.mdb"))) + if ( + num_h5_files + + num_hdf5_files + + num_ldb_files + + num_aselmbd_files + + num_mdb_files + == 0 + ): + # print all the files in the directory extension in the directory for debugging + for file in os.listdir(filepath): + print(file) + raise RuntimeError(f"No supported files found in directory '{filename}'") + return False + if filepath.suffix in (".h5", ".hdf5", ".lmdb", ".aselmdb", ".mdb"): + return False + return True + + +def dict_to_namespace(dictionary): + # Convert the dictionary into an argparse.Namespace + namespace = argparse.Namespace() + for key, value in dictionary.items(): + setattr(namespace, key, value) + return namespace diff --git a/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py b/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py index 578f5a01ea2cc69dff97dcba645879005ec2e640..9b7c77b3b2ef354fee3a896a00f1eceed1520dd0 100644 --- a/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py +++ b/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py @@ -1,40 +1,40 @@ -########################################################################################### -# Slurm environment setup for distributed training. 
-# This code is refactored from rsarm's contribution at: -# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import os - -import hostlist - - -class DistributedEnvironment: - def __init__(self): - self._setup_distr_env() - self.master_addr = os.environ["MASTER_ADDR"] - self.master_port = os.environ["MASTER_PORT"] - self.world_size = int(os.environ["WORLD_SIZE"]) - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.rank = int(os.environ["RANK"]) - - def _setup_distr_env(self): - hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] - os.environ["MASTER_ADDR"] = hostname - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") - os.environ["WORLD_SIZE"] = os.environ.get( - "SLURM_NTASKS", - str( - int(os.environ["SLURM_NTASKS_PER_NODE"]) - * int(os.environ["SLURM_NNODES"]) - ), - ) - os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] - os.environ["RANK"] = os.environ["SLURM_PROCID"] - - def __repr__(self): - return ( - f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, " - f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})" - ) +########################################################################################### +# Slurm environment setup for distributed training. +# This code is refactored from rsarm's contribution at: +# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import os + +import hostlist + + +class DistributedEnvironment: + def __init__(self): + self._setup_distr_env() + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.world_size = int(os.environ["WORLD_SIZE"]) + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.rank = int(os.environ["RANK"]) + + def _setup_distr_env(self): + hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] + os.environ["MASTER_ADDR"] = hostname + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") + os.environ["WORLD_SIZE"] = os.environ.get( + "SLURM_NTASKS", + str( + int(os.environ["SLURM_NTASKS_PER_NODE"]) + * int(os.environ["SLURM_NNODES"]) + ), + ) + os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] + os.environ["RANK"] = os.environ["SLURM_PROCID"] + + def __repr__(self): + return ( + f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, " + f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})" + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/tables_utils.py b/mace-bench/3rdparty/mace/mace/tools/tables_utils.py index bbd581bcc740932354f35ec54714d382b0ca4899..04ff64014f4a0d0505a66372a61c23a50df75815 100644 --- a/mace-bench/3rdparty/mace/mace/tools/tables_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/tables_utils.py @@ -1,246 +1,246 @@ -import logging -from typing import Dict, List, Optional - -import torch -from prettytable import PrettyTable - -from mace.tools import evaluate - - -def custom_key(key): - """ - Helper function to sort the keys of the data loader dictionary - to ensure that the training set, and validation set 
- are evaluated first - """ - if key == "train": - return (0, key) - if key == "valid": - return (1, key) - return (2, key) - - -def create_error_table( - table_type: str, - all_data_loaders: dict, - model: torch.nn.Module, - loss_fn: torch.nn.Module, - output_args: Dict[str, bool], - log_wandb: bool, - device: str, - distributed: bool = False, - skip_heads: Optional[List[str]] = None, -) -> PrettyTable: - if log_wandb: - import wandb - skip_heads = skip_heads or [] - table = PrettyTable() - if table_type == "TotalRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSEstressvirials": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - "RMSE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "PerAtomMAEstressvirials": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - "MAE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "TotalMAE": - table.field_names = [ - "config_type", - "MAE E / meV", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "PerAtomMAE": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "DipoleRMSE": - table.field_names = [ - "config_type", - "RMSE MU / mDebye / atom", - "relative MU RMSE %", - ] - elif table_type == "DipoleMAE": - table.field_names = [ - "config_type", - "MAE MU / mDebye / atom", - "relative MU MAE %", - ] - elif table_type == "EnergyDipoleRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "rel F RMSE %", - "RMSE MU / mDebye / atom", - "rel MU RMSE %", - ] - - for name in sorted(all_data_loaders, key=custom_key): - if any(skip_head in name for skip_head in skip_heads): - logging.info(f"Skipping evaluation of {name} (in skip_heads list)") - continue - data_loader = all_data_loaders[name] - logging.info(f"Evaluating {name} ...") - _, metrics = evaluate( - model, - loss_fn=loss_fn, - data_loader=data_loader, - output_args=output_args, - device=device, - ) - if distributed: - torch.distributed.barrier() - - del data_loader - torch.cuda.empty_cache() - if log_wandb: - wandb_log_dict = { - name - + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] - * 1e3, # meV / atom - name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A - name + "_final_rel_rmse_f": metrics["rel_rmse_f"], - } - wandb.log(wandb_log_dict) - if table_type == "TotalRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif table_type == "PerAtomRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 
1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_virials'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_virials'] * 1000:8.1f}", - ] - ) - elif table_type == "TotalMAE": - table.add_row( - [ - name, - f"{metrics['mae_e'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "PerAtomMAE": - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "DipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - elif table_type == "DipoleMAE": - table.add_row( - [ - name, - f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_mae_mu']:8.1f}", - ] - ) - elif table_type == "EnergyDipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - return table +import logging +from typing import Dict, List, Optional + +import torch +from prettytable import PrettyTable + +from mace.tools import evaluate + + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == "train": + return (0, key) + if key == "valid": + return (1, key) + return (2, key) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, + skip_heads: Optional[List[str]] = None, +) -> PrettyTable: + if log_wandb: + import wandb + skip_heads = skip_heads or [] + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + 
"RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + if any(skip_head in name for skip_head in skip_heads): + logging.info(f"Skipping evaluation of {name} (in skip_heads list)") + continue + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + 
f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py index 329f8dda3f2511bdaa26fe8c1e267f1b692032ff..486f0d09d41acfdffba0c1d6e29828bc4fe9ba75 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py @@ -1,7 +1,7 @@ -from .batch import Batch -from .data import Data -from .dataloader import DataLoader -from .dataset import Dataset -from .seed import seed_everything - -__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] +from .batch import Batch +from .data import Data +from .dataloader import DataLoader +from .dataset import Dataset +from .seed import seed_everything + +__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 735bf63e361ff3a6ac40b3983fa5225113bd74ef..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index b98d9d71e3b8a1c87147e2c555a75e1c6fe9cbd1..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-310.pyc deleted file mode 100644 index 3a5c79e765070028f8db7ebd0eb9c6292dff25a4..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-313.pyc deleted file mode 100644 index 80351e8e3b9f1df6c8fc4f456b2c589e8dbf2351..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-310.pyc deleted file mode 100644 index 3ea13e80bc0509875c06d354c2ee883da1f2fccd..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc deleted file mode 100644 index a0ef969ca973a60fa900aefae0179abf01228386..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc and /dev/null differ diff --git 
a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-310.pyc deleted file mode 100644 index 8ff3e9f617cedb4d5f31935e99d2cb53ae55dabf..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc deleted file mode 100644 index b26fde472f4457c90cbe798e07f29bc14519d5f7..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index 466c5e1e552f982edc5717131a772edba00120af..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-313.pyc deleted file mode 100644 index 0b629ef4c3d69d9afd66b4d5dd9c41b5a5026dc1..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc deleted file mode 100644 index e81ebced3f665c683e5d8030ba76fe8291990903..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc deleted file mode 100644 index 3bbd4145a0f0b10e6d2c204c5d12eb3b00eaf83d..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 7174e35a1bbad93078d43d8ff4302b8a4e3bc459..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index 617b8af0098e232c09c39ef68f2aa170f9c3618f..0000000000000000000000000000000000000000 Binary files a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-313.pyc and /dev/null differ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py index 
93ff7e2500ac4f7b1a3551ad8845f26143b34436..be5ec9d0cf418c9edc79fc503f143065e6450434 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py @@ -1,257 +1,257 @@ -from collections.abc import Sequence -from typing import List - -import numpy as np -import torch -from torch import Tensor - -from .data import Data -from .dataset import IndexType - - -class Batch(Data): - r"""A plain old python object modeling a batch of graphs as one big - (disconnected) graph. With :class:`torch_geometric.data.Data` being the - base class, all its methods can also be used here. - In addition, single graphs can be reconstructed via the assignment vector - :obj:`batch`, which maps each node to its respective graph identifier. - """ - - def __init__(self, batch=None, ptr=None, **kwargs): - super(Batch, self).__init__(**kwargs) - - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - self.batch = batch - self.ptr = ptr - self.__data_class__ = Data - self.__slices__ = None - self.__cumsum__ = None - self.__cat_dims__ = None - self.__num_nodes_list__ = None - self.__num_graphs__ = None - - @classmethod - def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): - r"""Constructs a batch object from a python list holding - :class:`torch_geometric.data.Data` objects. - The assignment vector :obj:`batch` is created on the fly. - Additionally, creates assignment batch vectors for each key in - :obj:`follow_batch`. - Will exclude any keys given in :obj:`exclude_keys`.""" - - keys = list(set(data_list[0].keys) - set(exclude_keys)) - assert "batch" not in keys and "ptr" not in keys - - batch = cls() - for key in data_list[0].__dict__.keys(): - if key[:2] != "__" and key[-2:] != "__": - batch[key] = None - - batch.__num_graphs__ = len(data_list) - batch.__data_class__ = data_list[0].__class__ - for key in keys + ["batch"]: - batch[key] = [] - batch["ptr"] = [0] - - device = None - slices = {key: [0] for key in keys} - cumsum = {key: [0] for key in keys} - cat_dims = {} - num_nodes_list = [] - for i, data in enumerate(data_list): - for key in keys: - item = data[key] - - # Increase values by `cumsum` value. - cum = cumsum[key][-1] - if isinstance(item, Tensor) and item.dtype != torch.bool: - if not isinstance(cum, int) or cum != 0: - item = item + cum - elif isinstance(item, (int, float)): - item = item + cum - - # Gather the size of the `cat` dimension. - size = 1 - cat_dim = data.__cat_dim__(key, data[key]) - # 0-dimensional tensors have no dimension along which to - # concatenate, so we set `cat_dim` to `None`. - if isinstance(item, Tensor) and item.dim() == 0: - cat_dim = None - cat_dims[key] = cat_dim - - # Add a batch dimension to items whose `cat_dim` is `None`: - if isinstance(item, Tensor) and cat_dim is None: - cat_dim = 0 # Concatenate along this new batch dimension. - item = item.unsqueeze(0) - device = item.device - elif isinstance(item, Tensor): - size = item.size(cat_dim) - device = item.device - - batch[key].append(item) # Append item to the attribute list. 
- - slices[key].append(size + slices[key][-1]) - inc = data.__inc__(key, item) - if isinstance(inc, (tuple, list)): - inc = torch.tensor(inc) - cumsum[key].append(inc + cumsum[key][-1]) - - if key in follow_batch: - if isinstance(size, Tensor): - for j, size in enumerate(size.tolist()): - tmp = f"{key}_{j}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - else: - tmp = f"{key}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - - if hasattr(data, "__num_nodes__"): - num_nodes_list.append(data.__num_nodes__) - else: - num_nodes_list.append(None) - - num_nodes = data.num_nodes - if num_nodes is not None: - item = torch.full((num_nodes,), i, dtype=torch.long, device=device) - batch.batch.append(item) - batch.ptr.append(batch.ptr[-1] + num_nodes) - - batch.batch = None if len(batch.batch) == 0 else batch.batch - batch.ptr = None if len(batch.ptr) == 1 else batch.ptr - batch.__slices__ = slices - batch.__cumsum__ = cumsum - batch.__cat_dims__ = cat_dims - batch.__num_nodes_list__ = num_nodes_list - - ref_data = data_list[0] - for key in batch.keys: - items = batch[key] - item = items[0] - cat_dim = ref_data.__cat_dim__(key, item) - cat_dim = 0 if cat_dim is None else cat_dim - if isinstance(item, Tensor): - batch[key] = torch.cat(items, cat_dim) - elif isinstance(item, (int, float)): - batch[key] = torch.tensor(items) - - # if torch_geometric.is_debug_enabled(): - # batch.debug() - - return batch.contiguous() - - def get_example(self, idx: int) -> Data: - r"""Reconstructs the :class:`torch_geometric.data.Data` object at index - :obj:`idx` from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - - if self.__slices__ is None: - raise RuntimeError( - ( - "Cannot reconstruct data list from batch because the batch " - "object was not created using `Batch.from_data_list()`." - ) - ) - - data = self.__data_class__() - idx = self.num_graphs + idx if idx < 0 else idx - - for key in self.__slices__.keys(): - item = self[key] - if self.__cat_dims__[key] is None: - # The item was concatenated along a new batch dimension, - # so just index in that dimension: - item = item[idx] - else: - # Narrow the item based on the values in `__slices__`. 
- if isinstance(item, Tensor): - dim = self.__cat_dims__[key] - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item.narrow(dim, start, end - start) - else: - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item[start:end] - item = item[0] if len(item) == 1 else item - - # Decrease its value by `cumsum` value: - cum = self.__cumsum__[key][idx] - if isinstance(item, Tensor): - if not isinstance(cum, int) or cum != 0: - item = item - cum - elif isinstance(item, (int, float)): - item = item - cum - - data[key] = item - - if self.__num_nodes_list__[idx] is not None: - data.num_nodes = self.__num_nodes_list__[idx] - - return data - - def index_select(self, idx: IndexType) -> List[Data]: - if isinstance(idx, slice): - idx = list(range(self.num_graphs)[idx]) - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - idx = idx.flatten().tolist() - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - idx = idx.flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0].flatten().tolist() - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - pass - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - return [self.get_example(i) for i in idx] - - def __getitem__(self, idx): - if isinstance(idx, str): - return super(Batch, self).__getitem__(idx) - elif isinstance(idx, (int, np.integer)): - return self.get_example(idx) - else: - return self.index_select(idx) - - def to_data_list(self) -> List[Data]: - r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects - from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - return [self.get_example(i) for i in range(self.num_graphs)] - - @property - def num_graphs(self) -> int: - """Returns the number of graphs in the batch.""" - if self.__num_graphs__ is not None: - return self.__num_graphs__ - elif self.ptr is not None: - return self.ptr.numel() - 1 - elif self.batch is not None: - return int(self.batch.max()) + 1 - else: - raise ValueError +from collections.abc import Sequence +from typing import List + +import numpy as np +import torch +from torch import Tensor + +from .data import Data +from .dataset import IndexType + + +class Batch(Data): + r"""A plain old python object modeling a batch of graphs as one big + (disconnected) graph. With :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + In addition, single graphs can be reconstructed via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. 
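+
+    Example (illustrative usage; :obj:`data_1` and :obj:`data_2` are assumed
+    to be existing :class:`Data` objects)::
+
+        batch = Batch.from_data_list([data_1, data_2])
+        data_1_again = batch.get_example(0)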
+ """ + + def __init__(self, batch=None, ptr=None, **kwargs): + super(Batch, self).__init__(**kwargs) + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + self.batch = batch + self.ptr = ptr + self.__data_class__ = Data + self.__slices__ = None + self.__cumsum__ = None + self.__cat_dims__ = None + self.__num_nodes_list__ = None + self.__num_graphs__ = None + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert "batch" not in keys and "ptr" not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != "__" and key[-2:] != "__": + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ["batch"]: + batch[key] = [] + batch["ptr"] = [0] + + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + cat_dims = {} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. + item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + + batch[key].append(item) # Append item to the attribute list. 
+ + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f"{key}_{j}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + else: + tmp = f"{key}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + + if hasattr(data, "__num_nodes__"): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long, device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + # if torch_geometric.is_debug_enabled(): + # batch.debug() + + return batch.contiguous() + + def get_example(self, idx: int) -> Data: + r"""Reconstructs the :class:`torch_geometric.data.Data` object at index + :obj:`idx` from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + + if self.__slices__ is None: + raise RuntimeError( + ( + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." + ) + ) + + data = self.__data_class__() + idx = self.num_graphs + idx if idx < 0 else idx + + for key in self.__slices__.keys(): + item = self[key] + if self.__cat_dims__[key] is None: + # The item was concatenated along a new batch dimension, + # so just index in that dimension: + item = item[idx] + else: + # Narrow the item based on the values in `__slices__`. 
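+                # Tensors are narrowed along the stored `cat_dim` (a view, no
+                # copy); other sequence types fall back to a plain
+                # `[start:end]` slice.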
+                if isinstance(item, Tensor):
+                    dim = self.__cat_dims__[key]
+                    start = self.__slices__[key][idx]
+                    end = self.__slices__[key][idx + 1]
+                    item = item.narrow(dim, start, end - start)
+                else:
+                    start = self.__slices__[key][idx]
+                    end = self.__slices__[key][idx + 1]
+                    item = item[start:end]
+                    item = item[0] if len(item) == 1 else item
+
+            # Decrease its value by `cumsum` value:
+            cum = self.__cumsum__[key][idx]
+            if isinstance(item, Tensor):
+                if not isinstance(cum, int) or cum != 0:
+                    item = item - cum
+            elif isinstance(item, (int, float)):
+                item = item - cum
+
+            data[key] = item
+
+        if self.__num_nodes_list__[idx] is not None:
+            data.num_nodes = self.__num_nodes_list__[idx]
+
+        return data
+
+    def index_select(self, idx: IndexType) -> List[Data]:
+        if isinstance(idx, slice):
+            idx = list(range(self.num_graphs)[idx])
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
+            idx = idx.flatten().tolist()
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
+            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
+            idx = idx.flatten().tolist()
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
+            idx = idx.flatten().nonzero()[0].flatten().tolist()
+
+        elif isinstance(idx, Sequence) and not isinstance(idx, str):
+            pass
+
+        else:
+            raise IndexError(
+                f"Only integers, slices (':'), lists, tuples, torch.Tensor and "
+                f"np.ndarray of dtype long or bool are valid indices (got "
+                f"'{type(idx).__name__}')"
+            )
+
+        return [self.get_example(i) for i in idx]
+
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return super(Batch, self).__getitem__(idx)
+        elif isinstance(idx, (int, np.integer)):
+            return self.get_example(idx)
+        else:
+            return self.index_select(idx)
+
+    def to_data_list(self) -> List[Data]:
+        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
+        from the batch object.
+        The batch object must have been created via :meth:`from_data_list` in
+        order to be able to reconstruct the initial objects."""
+        return [self.get_example(i) for i in range(self.num_graphs)]
+
+    @property
+    def num_graphs(self) -> int:
+        """Returns the number of graphs in the batch."""
+        if self.__num_graphs__ is not None:
+            return self.__num_graphs__
+        elif self.ptr is not None:
+            return self.ptr.numel() - 1
+        elif self.batch is not None:
+            return int(self.batch.max()) + 1
+        else:
+            raise ValueError("Cannot infer the number of graphs from an empty batch")
diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py
index 0c41726ad39bd303295a8ac26cdab061169ece8e..4e1ab3084d584dcf5af7289e3159c353327ec539 100644
--- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py
+++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py
@@ -1,441 +1,441 @@
-import collections
-import copy
-import re
-
-import torch
-
-# from ..utils.num_nodes import maybe_num_nodes
-
-__num_nodes_warn_msg__ = (
-    "The number of nodes in your data object can only be inferred by its {} "
-    "indices, and hence may result in unexpected batch-wise behavior, e.g., "
-    "in case there exists isolated nodes. Please consider explicitly setting "
-    "the number of nodes for this data object by assigning it to "
-    "data.num_nodes."
-) - - -def size_repr(key, item, indent=0): - indent_str = " " * indent - if torch.is_tensor(item) and item.dim() == 0: - out = item.item() - elif torch.is_tensor(item): - out = str(list(item.size())) - elif isinstance(item, list) or isinstance(item, tuple): - out = str([len(item)]) - elif isinstance(item, dict): - lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] - out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" - elif isinstance(item, str): - out = f'"{item}"' - else: - out = str(item) - - return f"{indent_str}{key}={out}" - - -class Data(object): - r"""A plain old python object modeling a single graph with various - (optional) attributes: - - Args: - x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, - num_node_features]`. (default: :obj:`None`) - edge_index (LongTensor, optional): Graph connectivity in COO format - with shape :obj:`[2, num_edges]`. (default: :obj:`None`) - edge_attr (Tensor, optional): Edge feature matrix with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - y (Tensor, optional): Graph or node targets with arbitrary shape. - (default: :obj:`None`) - pos (Tensor, optional): Node position matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - normal (Tensor, optional): Normal vector matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - face (LongTensor, optional): Face adjacency matrix with shape - :obj:`[3, num_faces]`. (default: :obj:`None`) - - The data object is not restricted to these attributes and can be extended - by any other additional data. - - Example:: - - data = Data(x=x, edge_index=edge_index) - data.train_idx = torch.tensor([...], dtype=torch.long) - data.test_mask = torch.tensor([...], dtype=torch.bool) - """ - - def __init__( - self, - x=None, - edge_index=None, - edge_attr=None, - y=None, - pos=None, - normal=None, - face=None, - **kwargs, - ): - self.x = x - self.edge_index = edge_index - self.edge_attr = edge_attr - self.y = y - self.pos = pos - self.normal = normal - self.face = face - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - if edge_index is not None and edge_index.dtype != torch.long: - raise ValueError( - ( - f"Argument `edge_index` needs to be of type `torch.long` but " - f"found type `{edge_index.dtype}`." - ) - ) - - if face is not None and face.dtype != torch.long: - raise ValueError( - ( - f"Argument `face` needs to be of type `torch.long` but found " - f"type `{face.dtype}`." 
- ) - ) - - @classmethod - def from_dict(cls, dictionary): - r"""Creates a data object from a python dictionary.""" - data = cls() - - for key, item in dictionary.items(): - data[key] = item - - return data - - def to_dict(self): - return {key: item for key, item in self} - - def to_namedtuple(self): - keys = self.keys - DataTuple = collections.namedtuple("DataTuple", keys) - return DataTuple(*[self[key] for key in keys]) - - def __getitem__(self, key): - r"""Gets the data of the attribute :obj:`key`.""" - return getattr(self, key, None) - - def __setitem__(self, key, value): - """Sets the attribute :obj:`key` to :obj:`value`.""" - setattr(self, key, value) - - def __delitem__(self, key): - r"""Delete the data of the attribute :obj:`key`.""" - return delattr(self, key) - - @property - def keys(self): - r"""Returns all names of graph attributes.""" - keys = [key for key in self.__dict__.keys() if self[key] is not None] - keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] - return keys - - def __len__(self): - r"""Returns the number of all present attributes.""" - return len(self.keys) - - def __contains__(self, key): - r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the - data.""" - return key in self.keys - - def __iter__(self): - r"""Iterates over all present attributes in the data, yielding their - attribute names and content.""" - for key in sorted(self.keys): - yield key, self[key] - - def __call__(self, *keys): - r"""Iterates over all attributes :obj:`*keys` in the data, yielding - their attribute names and content. - If :obj:`*keys` is not given this method will iterative over all - present attributes.""" - for key in sorted(self.keys) if not keys else keys: - if key in self: - yield key, self[key] - - def __cat_dim__(self, key, value): - r"""Returns the dimension for which :obj:`value` of attribute - :obj:`key` will get concatenated when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - if bool(re.search("(index|face)", key)): - return -1 - return 0 - - def __inc__(self, key, value): - r"""Returns the incremental count to cumulatively increase the value - of the next attribute of :obj:`key` when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - # Only `*index*` and `*face*` attributes should be cumulatively summed - # up when creating batches. - return self.num_nodes if bool(re.search("(index|face)", key)) else 0 - - @property - def num_nodes(self): - r"""Returns or sets the number of nodes in the graph. - - .. note:: - The number of nodes in your data object is typically automatically - inferred, *e.g.*, when node features :obj:`x` are present. - In some cases however, a graph may only be given by its edge - indices :obj:`edge_index`. - PyTorch Geometric then *guesses* the number of nodes - according to :obj:`edge_index.max().item() + 1`, but in case there - exists isolated nodes, this number has not to be correct and can - therefore result in unexpected batch-wise behavior. - Thus, we recommend to set the number of nodes in your data object - explicitly via :obj:`data.num_nodes = ...`. - You will be given a warning that requests you to do so. 
- """ - if hasattr(self, "__num_nodes__"): - return self.__num_nodes__ - for key, item in self("x", "pos", "normal", "batch"): - return item.size(self.__cat_dim__(key, item)) - if hasattr(self, "adj"): - return self.adj.size(0) - if hasattr(self, "adj_t"): - return self.adj_t.size(1) - # if self.face is not None: - # logging.warning(__num_nodes_warn_msg__.format("face")) - # return maybe_num_nodes(self.face) - # if self.edge_index is not None: - # logging.warning(__num_nodes_warn_msg__.format("edge")) - # return maybe_num_nodes(self.edge_index) - return None - - @num_nodes.setter - def num_nodes(self, num_nodes): - self.__num_nodes__ = num_nodes - - @property - def num_edges(self): - """ - Returns the number of edges in the graph. - For undirected graphs, this will return the number of bi-directional - edges, which is double the amount of unique edges. - """ - for key, item in self("edge_index", "edge_attr"): - return item.size(self.__cat_dim__(key, item)) - for key, item in self("adj", "adj_t"): - return item.nnz() - return None - - @property - def num_faces(self): - r"""Returns the number of faces in the mesh.""" - if self.face is not None: - return self.face.size(self.__cat_dim__("face", self.face)) - return None - - @property - def num_node_features(self): - r"""Returns the number of features per node in the graph.""" - if self.x is None: - return 0 - return 1 if self.x.dim() == 1 else self.x.size(1) - - @property - def num_features(self): - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self): - r"""Returns the number of features per edge in the graph.""" - if self.edge_attr is None: - return 0 - return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) - - def __apply__(self, item, func): - if torch.is_tensor(item): - return func(item) - elif isinstance(item, (tuple, list)): - return [self.__apply__(v, func) for v in item] - elif isinstance(item, dict): - return {k: self.__apply__(v, func) for k, v in item.items()} - else: - return item - - def apply(self, func, *keys): - r"""Applies the function :obj:`func` to all tensor attributes - :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to - all present attributes. - """ - for key, item in self(*keys): - self[key] = self.__apply__(item, func) - return self - - def contiguous(self, *keys): - r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. - If :obj:`*keys` is not given, all present attributes are ensured to - have a contiguous memory layout.""" - return self.apply(lambda x: x.contiguous(), *keys) - - def to(self, device, *keys, **kwargs): - r"""Performs tensor dtype and/or device conversion to all attributes - :obj:`*keys`. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.to(device, **kwargs), *keys) - - def cpu(self, *keys): - r"""Copies all attributes :obj:`*keys` to CPU memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.cpu(), *keys) - - def cuda(self, device=None, non_blocking=False, *keys): - r"""Copies all attributes :obj:`*keys` to CUDA memory. 
- If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply( - lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys - ) - - def clone(self): - r"""Performs a deep-copy of the data object.""" - return self.__class__.from_dict( - { - k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) - for k, v in self.__dict__.items() - } - ) - - def pin_memory(self, *keys): - r"""Copies all attributes :obj:`*keys` to pinned memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.pin_memory(), *keys) - - def debug(self): - if self.edge_index is not None: - if self.edge_index.dtype != torch.long: - raise RuntimeError( - ( - "Expected edge indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.edge_index.dtype) - ) - - if self.face is not None: - if self.face.dtype != torch.long: - raise RuntimeError( - ( - "Expected face indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.face.dtype) - ) - - if self.edge_index is not None: - if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: - raise RuntimeError( - ( - "Edge indices should have shape [2, num_edges] but found" - " shape {}" - ).format(self.edge_index.size()) - ) - - if self.edge_index is not None and self.num_nodes is not None: - if self.edge_index.numel() > 0: - min_index = self.edge_index.min() - max_index = self.edge_index.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Edge indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.face is not None: - if self.face.dim() != 2 or self.face.size(0) != 3: - raise RuntimeError( - ( - "Face indices should have shape [3, num_faces] but found" - " shape {}" - ).format(self.face.size()) - ) - - if self.face is not None and self.num_nodes is not None: - if self.face.numel() > 0: - min_index = self.face.min() - max_index = self.face.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Face indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.edge_index is not None and self.edge_attr is not None: - if self.edge_index.size(1) != self.edge_attr.size(0): - raise RuntimeError( - ( - "Edge indices and edge attributes hold a differing " - "number of edges, found {} and {}" - ).format(self.edge_index.size(), self.edge_attr.size()) - ) - - if self.x is not None and self.num_nodes is not None: - if self.x.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node features should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.x.size(0)) - ) - - if self.pos is not None and self.num_nodes is not None: - if self.pos.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node positions should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.pos.size(0)) - ) - - if self.normal is not None and self.num_nodes is not None: - if self.normal.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node normals should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.normal.size(0)) - ) - - def __repr__(self): - cls = str(self.__class__.__name__) - has_dict = 
any([isinstance(item, dict) for _, item in self]) - - if not has_dict: - info = [size_repr(key, item) for key, item in self] - return "{}({})".format(cls, ", ".join(info)) - else: - info = [size_repr(key, item, indent=2) for key, item in self] - return "{}(\n{}\n)".format(cls, ",\n".join(info)) +import collections +import copy +import re + +import torch + +# from ..utils.num_nodes import maybe_num_nodes + +__num_nodes_warn_msg__ = ( + "The number of nodes in your data object can only be inferred by its {} " + "indices, and hence may result in unexpected batch-wise behavior, e.g., " + "in case there exists isolated nodes. Please consider explicitly setting " + "the number of nodes for this data object by assigning it to " + "data.num_nodes." +) + + +def size_repr(key, item, indent=0): + indent_str = " " * indent + if torch.is_tensor(item) and item.dim() == 0: + out = item.item() + elif torch.is_tensor(item): + out = str(list(item.size())) + elif isinstance(item, list) or isinstance(item, tuple): + out = str([len(item)]) + elif isinstance(item, dict): + lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] + out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" + elif isinstance(item, str): + out = f'"{item}"' + else: + out = str(item) + + return f"{indent_str}{key}={out}" + + +class Data(object): + r"""A plain old python object modeling a single graph with various + (optional) attributes: + + Args: + x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, + num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph or node targets with arbitrary shape. + (default: :obj:`None`) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + normal (Tensor, optional): Normal vector matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + face (LongTensor, optional): Face adjacency matrix with shape + :obj:`[3, num_faces]`. (default: :obj:`None`) + + The data object is not restricted to these attributes and can be extended + by any other additional data. + + Example:: + + data = Data(x=x, edge_index=edge_index) + data.train_idx = torch.tensor([...], dtype=torch.long) + data.test_mask = torch.tensor([...], dtype=torch.bool) + """ + + def __init__( + self, + x=None, + edge_index=None, + edge_attr=None, + y=None, + pos=None, + normal=None, + face=None, + **kwargs, + ): + self.x = x + self.edge_index = edge_index + self.edge_attr = edge_attr + self.y = y + self.pos = pos + self.normal = normal + self.face = face + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + if edge_index is not None and edge_index.dtype != torch.long: + raise ValueError( + ( + f"Argument `edge_index` needs to be of type `torch.long` but " + f"found type `{edge_index.dtype}`." + ) + ) + + if face is not None and face.dtype != torch.long: + raise ValueError( + ( + f"Argument `face` needs to be of type `torch.long` but found " + f"type `{face.dtype}`." 
+            )
+        )
+
+    @classmethod
+    def from_dict(cls, dictionary):
+        r"""Creates a data object from a python dictionary."""
+        data = cls()
+
+        for key, item in dictionary.items():
+            data[key] = item
+
+        return data
+
+    def to_dict(self):
+        return {key: item for key, item in self}
+
+    def to_namedtuple(self):
+        keys = self.keys
+        DataTuple = collections.namedtuple("DataTuple", keys)
+        return DataTuple(*[self[key] for key in keys])
+
+    def __getitem__(self, key):
+        r"""Gets the data of the attribute :obj:`key`."""
+        return getattr(self, key, None)
+
+    def __setitem__(self, key, value):
+        """Sets the attribute :obj:`key` to :obj:`value`."""
+        setattr(self, key, value)
+
+    def __delitem__(self, key):
+        r"""Deletes the data of the attribute :obj:`key`."""
+        return delattr(self, key)
+
+    @property
+    def keys(self):
+        r"""Returns all names of graph attributes."""
+        keys = [key for key in self.__dict__.keys() if self[key] is not None]
+        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
+        return keys
+
+    def __len__(self):
+        r"""Returns the number of all present attributes."""
+        return len(self.keys)
+
+    def __contains__(self, key):
+        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
+        data."""
+        return key in self.keys
+
+    def __iter__(self):
+        r"""Iterates over all present attributes in the data, yielding their
+        attribute names and content."""
+        for key in sorted(self.keys):
+            yield key, self[key]
+
+    def __call__(self, *keys):
+        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
+        their attribute names and content.
+        If :obj:`*keys` is not given, this method will iterate over all
+        present attributes."""
+        for key in sorted(self.keys) if not keys else keys:
+            if key in self:
+                yield key, self[key]
+
+    def __cat_dim__(self, key, value):
+        r"""Returns the dimension for which :obj:`value` of attribute
+        :obj:`key` will get concatenated when creating batches.
+
+        .. note::
+
+            This method is for internal use only, and should only be overridden
+            if the batch concatenation process is corrupted for a specific data
+            attribute.
+        """
+        if bool(re.search("(index|face)", key)):
+            return -1
+        return 0
+
+    def __inc__(self, key, value):
+        r"""Returns the incremental count to cumulatively increase the value
+        of the next attribute of :obj:`key` when creating batches.
+
+        .. note::
+
+            This method is for internal use only, and should only be overridden
+            if the batch concatenation process is corrupted for a specific data
+            attribute.
+        """
+        # Only `*index*` and `*face*` attributes should be cumulatively summed
+        # up when creating batches.
+        return self.num_nodes if bool(re.search("(index|face)", key)) else 0
+
+    @property
+    def num_nodes(self):
+        r"""Returns or sets the number of nodes in the graph.
+
+        .. note::
+            The number of nodes in your data object is typically automatically
+            inferred, *e.g.*, when node features :obj:`x` are present.
+            In some cases, however, a graph may only be given by its edge
+            indices :obj:`edge_index`.
+            PyTorch Geometric then *guesses* the number of nodes
+            according to :obj:`edge_index.max().item() + 1`, but in case there
+            exist isolated nodes, this number may not be correct and can
+            therefore result in unexpected batch-wise behavior.
+            Thus, we recommend setting the number of nodes in your data object
+            explicitly via :obj:`data.num_nodes = ...`.
+            You will be given a warning that requests you to do so.
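The caveat above matters for generated crystal graphs in which some atoms have no neighbours within the cutoff. A minimal sketch of the recommended fix, assuming the vendored import path `mace.tools.torch_geometric.data` implied by the file paths in this diff:

```python
import torch

from mace.tools.torch_geometric.data import Data  # path assumed from this diff

# Three nodes, but node 2 is isolated: it never appears in edge_index.
edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
data = Data(edge_index=edge_index)

# In this vendored copy the maybe_num_nodes fallback is commented out, so
# without x/pos the property returns None rather than guessing from edge_index.
assert data.num_nodes is None

data.num_nodes = 3  # set the count explicitly, as the docstring recommends
assert data.num_nodes == 3
```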
+ """ + if hasattr(self, "__num_nodes__"): + return self.__num_nodes__ + for key, item in self("x", "pos", "normal", "batch"): + return item.size(self.__cat_dim__(key, item)) + if hasattr(self, "adj"): + return self.adj.size(0) + if hasattr(self, "adj_t"): + return self.adj_t.size(1) + # if self.face is not None: + # logging.warning(__num_nodes_warn_msg__.format("face")) + # return maybe_num_nodes(self.face) + # if self.edge_index is not None: + # logging.warning(__num_nodes_warn_msg__.format("edge")) + # return maybe_num_nodes(self.edge_index) + return None + + @num_nodes.setter + def num_nodes(self, num_nodes): + self.__num_nodes__ = num_nodes + + @property + def num_edges(self): + """ + Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + for key, item in self("edge_index", "edge_attr"): + return item.size(self.__cat_dim__(key, item)) + for key, item in self("adj", "adj_t"): + return item.nnz() + return None + + @property + def num_faces(self): + r"""Returns the number of faces in the mesh.""" + if self.face is not None: + return self.face.size(self.__cat_dim__("face", self.face)) + return None + + @property + def num_node_features(self): + r"""Returns the number of features per node in the graph.""" + if self.x is None: + return 0 + return 1 if self.x.dim() == 1 else self.x.size(1) + + @property + def num_features(self): + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self): + r"""Returns the number of features per edge in the graph.""" + if self.edge_attr is None: + return 0 + return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + + def __apply__(self, item, func): + if torch.is_tensor(item): + return func(item) + elif isinstance(item, (tuple, list)): + return [self.__apply__(v, func) for v in item] + elif isinstance(item, dict): + return {k: self.__apply__(v, func) for k, v in item.items()} + else: + return item + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + self[key] = self.__apply__(item, func) + return self + + def contiguous(self, *keys): + r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. + If :obj:`*keys` is not given, all present attributes are ensured to + have a contiguous memory layout.""" + return self.apply(lambda x: x.contiguous(), *keys) + + def to(self, device, *keys, **kwargs): + r"""Performs tensor dtype and/or device conversion to all attributes + :obj:`*keys`. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.to(device, **kwargs), *keys) + + def cpu(self, *keys): + r"""Copies all attributes :obj:`*keys` to CPU memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.cpu(), *keys) + + def cuda(self, device=None, non_blocking=False, *keys): + r"""Copies all attributes :obj:`*keys` to CUDA memory. 
+        If :obj:`*keys` is not given, the conversion is applied to all present
+        attributes."""
+        return self.apply(
+            lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys
+        )
+
+    def clone(self):
+        r"""Performs a deep-copy of the data object."""
+        return self.__class__.from_dict(
+            {
+                k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v)
+                for k, v in self.__dict__.items()
+            }
+        )
+
+    def pin_memory(self, *keys):
+        r"""Copies all attributes :obj:`*keys` to pinned memory.
+        If :obj:`*keys` is not given, the conversion is applied to all present
+        attributes."""
+        return self.apply(lambda x: x.pin_memory(), *keys)
+
+    def debug(self):
+        if self.edge_index is not None:
+            if self.edge_index.dtype != torch.long:
+                raise RuntimeError(
+                    (
+                        "Expected edge indices of dtype {}, but found dtype {}"
+                    ).format(torch.long, self.edge_index.dtype)
+                )
+
+        if self.face is not None:
+            if self.face.dtype != torch.long:
+                raise RuntimeError(
+                    (
+                        "Expected face indices of dtype {}, but found dtype {}"
+                    ).format(torch.long, self.face.dtype)
+                )
+
+        if self.edge_index is not None:
+            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:
+                raise RuntimeError(
+                    (
+                        "Edge indices should have shape [2, num_edges] but found"
+                        " shape {}"
+                    ).format(self.edge_index.size())
+                )
+
+        if self.edge_index is not None and self.num_nodes is not None:
+            if self.edge_index.numel() > 0:
+                min_index = self.edge_index.min()
+                max_index = self.edge_index.max()
+            else:
+                min_index = max_index = 0
+            if min_index < 0 or max_index > self.num_nodes - 1:
+                raise RuntimeError(
+                    (
+                        "Edge indices must lie in the interval [0, {}]"
+                        " but found them in the interval [{}, {}]"
+                    ).format(self.num_nodes - 1, min_index, max_index)
+                )
+
+        if self.face is not None:
+            if self.face.dim() != 2 or self.face.size(0) != 3:
+                raise RuntimeError(
+                    (
+                        "Face indices should have shape [3, num_faces] but found"
+                        " shape {}"
+                    ).format(self.face.size())
+                )
+
+        if self.face is not None and self.num_nodes is not None:
+            if self.face.numel() > 0:
+                min_index = self.face.min()
+                max_index = self.face.max()
+            else:
+                min_index = max_index = 0
+            if min_index < 0 or max_index > self.num_nodes - 1:
+                raise RuntimeError(
+                    (
+                        "Face indices must lie in the interval [0, {}]"
+                        " but found them in the interval [{}, {}]"
+                    ).format(self.num_nodes - 1, min_index, max_index)
+                )
+
+        if self.edge_index is not None and self.edge_attr is not None:
+            if self.edge_index.size(1) != self.edge_attr.size(0):
+                raise RuntimeError(
+                    (
+                        "Edge indices and edge attributes hold a differing "
+                        "number of edges, found {} and {}"
+                    ).format(self.edge_index.size(), self.edge_attr.size())
+                )
+
+        if self.x is not None and self.num_nodes is not None:
+            if self.x.size(0) != self.num_nodes:
+                raise RuntimeError(
+                    (
+                        "Node features should hold {} elements in the first "
+                        "dimension but found {}"
+                    ).format(self.num_nodes, self.x.size(0))
+                )
+
+        if self.pos is not None and self.num_nodes is not None:
+            if self.pos.size(0) != self.num_nodes:
+                raise RuntimeError(
+                    (
+                        "Node positions should hold {} elements in the first "
+                        "dimension but found {}"
+                    ).format(self.num_nodes, self.pos.size(0))
+                )
+
+        if self.normal is not None and self.num_nodes is not None:
+            if self.normal.size(0) != self.num_nodes:
+                raise RuntimeError(
+                    (
+                        "Node normals should hold {} elements in the first "
+                        "dimension but found {}"
+                    ).format(self.num_nodes, self.normal.size(0))
+                )
+
+    def __repr__(self):
+        cls = str(self.__class__.__name__)
+        has_dict = 
any([isinstance(item, dict) for _, item in self]) + + if not has_dict: + info = [size_repr(key, item) for key, item in self] + return "{}({})".format(cls, ", ".join(info)) + else: + info = [size_repr(key, item, indent=2) for key, item in self] + return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py index 9953c1421593a780702c2631fdf28e926d8617bf..396b7e7285ac192cb8d6d5e26f686321734d132b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py @@ -1,87 +1,87 @@ -from collections.abc import Mapping, Sequence -from typing import List, Optional, Union - -import torch.utils.data -from torch.utils.data.dataloader import default_collate - -from .batch import Batch -from .data import Data -from .dataset import Dataset - - -class Collater: - def __init__(self, follow_batch, exclude_keys): - self.follow_batch = follow_batch - self.exclude_keys = exclude_keys - - def __call__(self, batch): - elem = batch[0] - if isinstance(elem, Data): - return Batch.from_data_list( - batch, - follow_batch=self.follow_batch, - exclude_keys=self.exclude_keys, - ) - elif isinstance(elem, torch.Tensor): - return default_collate(batch) - elif isinstance(elem, float): - return torch.tensor(batch, dtype=torch.float) - elif isinstance(elem, int): - return torch.tensor(batch) - elif isinstance(elem, str): - return batch - elif isinstance(elem, Mapping): - return {key: self([data[key] for data in batch]) for key in elem} - elif isinstance(elem, tuple) and hasattr(elem, "_fields"): - return type(elem)(*(self(s) for s in zip(*batch))) - elif isinstance(elem, Sequence) and not isinstance(elem, str): - return [self(s) for s in zip(*batch)] - - raise TypeError(f"DataLoader found invalid type: {type(elem)}") - - def collate(self, batch): # Deprecated... - return self(batch) - - -class DataLoader(torch.utils.data.DataLoader): - r"""A data loader which merges data objects from a - :class:`torch_geometric.data.Dataset` to a mini-batch. - Data objects can be either of type :class:`~torch_geometric.data.Data` or - :class:`~torch_geometric.data.HeteroData`. - Args: - dataset (Dataset): The dataset from which to load the data. - batch_size (int, optional): How many samples per batch to load. - (default: :obj:`1`) - shuffle (bool, optional): If set to :obj:`True`, the data will be - reshuffled at every epoch. (default: :obj:`False`) - follow_batch (List[str], optional): Creates assignment batch - vectors for each key in the list. (default: :obj:`None`) - exclude_keys (List[str], optional): Will exclude each key in the - list. (default: :obj:`None`) - **kwargs (optional): Additional arguments of - :class:`torch.utils.data.DataLoader`. 
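For orientation, a short sketch of how the `Data` container defined above is exercised: dict-style attribute access, `apply`-based device movement, deep copies via `clone`, and the `debug()` consistency checks. The import path is assumed from the file paths in this diff:

```python
import torch

from mace.tools.torch_geometric.data import Data  # path assumed from this diff

x = torch.randn(4, 3)  # 4 nodes with 3 features each
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)

print(data.num_nodes, data.num_edges)  # 4 4 (inferred from x and edge_index)
print(sorted(data.keys))               # ['edge_index', 'x']

data = data.to("cpu")    # thin wrapper around apply(lambda t: t.to(...))
dup = data.clone()       # deep copy: tensors are cloned, the rest deep-copied
dup.x += 1.0
assert not torch.equal(dup.x, data.x)

data.debug()             # raises RuntimeError on shape/dtype inconsistencies
```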
- """ - - def __init__( - self, - dataset: Dataset, - batch_size: int = 1, - shuffle: bool = False, - follow_batch: Optional[List[str]] = [None], - exclude_keys: Optional[List[str]] = [None], - **kwargs, - ): - if "collate_fn" in kwargs: - del kwargs["collate_fn"] - - # Save for PyTorch Lightning < 1.6: - self.follow_batch = follow_batch - self.exclude_keys = exclude_keys - - super().__init__( - dataset, - batch_size, - shuffle, - collate_fn=Collater(follow_batch, exclude_keys), - **kwargs, - ) +from collections.abc import Mapping, Sequence +from typing import List, Optional, Union + +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from .batch import Batch +from .data import Data +from .dataset import Dataset + + +class Collater: + def __init__(self, follow_batch, exclude_keys): + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + def __call__(self, batch): + elem = batch[0] + if isinstance(elem, Data): + return Batch.from_data_list( + batch, + follow_batch=self.follow_batch, + exclude_keys=self.exclude_keys, + ) + elif isinstance(elem, torch.Tensor): + return default_collate(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, Mapping): + return {key: self([data[key] for data in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(self(s) for s in zip(*batch))) + elif isinstance(elem, Sequence) and not isinstance(elem, str): + return [self(s) for s in zip(*batch)] + + raise TypeError(f"DataLoader found invalid type: {type(elem)}") + + def collate(self, batch): # Deprecated... + return self(batch) + + +class DataLoader(torch.utils.data.DataLoader): + r"""A data loader which merges data objects from a + :class:`torch_geometric.data.Dataset` to a mini-batch. + Data objects can be either of type :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData`. + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (List[str], optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`None`) + exclude_keys (List[str], optional): Will exclude each key in the + list. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`. 
+ """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + follow_batch: Optional[List[str]] = [None], + exclude_keys: Optional[List[str]] = [None], + **kwargs, + ): + if "collate_fn" in kwargs: + del kwargs["collate_fn"] + + # Save for PyTorch Lightning < 1.6: + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=Collater(follow_batch, exclude_keys), + **kwargs, + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py index 7b4db34bac5cc7904c2ee551f5704427db315213..b4aeb2be9149ed68be67cc9009531ac809ac6787 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py @@ -1,280 +1,280 @@ -import copy -import os.path as osp -import re -import warnings -from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np -import torch.utils.data -from torch import Tensor - -from .data import Data -from .utils import makedirs - -IndexType = Union[slice, Tensor, np.ndarray, Sequence] - - -class Dataset(torch.utils.data.Dataset): - r"""Dataset base class for creating graph datasets. - See `here `__ for the accompanying tutorial. - - Args: - root (string, optional): Root directory where the dataset should be - saved. (optional: :obj:`None`) - transform (callable, optional): A function/transform that takes in an - :obj:`torch_geometric.data.Data` object and returns a transformed - version. The data object will be transformed before every access. - (default: :obj:`None`) - pre_transform (callable, optional): A function/transform that takes in - an :obj:`torch_geometric.data.Data` object and returns a - transformed version. The data object will be transformed before - being saved to disk. (default: :obj:`None`) - pre_filter (callable, optional): A function that takes in an - :obj:`torch_geometric.data.Data` object and returns a boolean - value, indicating whether the data object should be included in the - final dataset. 
(default: :obj:`None`) - """ - - @property - def raw_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.raw_dir` folder in - order to skip the download.""" - raise NotImplementedError - - @property - def processed_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.processed_dir` - folder in order to skip the processing.""" - raise NotImplementedError - - def download(self): - r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" - raise NotImplementedError - - def process(self): - r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" - raise NotImplementedError - - def len(self) -> int: - raise NotImplementedError - - def get(self, idx: int) -> Data: - r"""Gets the data object at index :obj:`idx`.""" - raise NotImplementedError - - def __init__( - self, - root: Optional[str] = None, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - ): - super().__init__() - - if isinstance(root, str): - root = osp.expanduser(osp.normpath(root)) - - self.root = root - self.transform = transform - self.pre_transform = pre_transform - self.pre_filter = pre_filter - self._indices: Optional[Sequence] = None - - if "download" in self.__class__.__dict__.keys(): - self._download() - - if "process" in self.__class__.__dict__.keys(): - self._process() - - def indices(self) -> Sequence: - return range(self.len()) if self._indices is None else self._indices - - @property - def raw_dir(self) -> str: - return osp.join(self.root, "raw") - - @property - def processed_dir(self) -> str: - return osp.join(self.root, "processed") - - @property - def num_node_features(self) -> int: - r"""Returns the number of features per node in the dataset.""" - data = self[0] - if hasattr(data, "num_node_features"): - return data.num_node_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_node_features'" - ) - - @property - def num_features(self) -> int: - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self) -> int: - r"""Returns the number of features per edge in the dataset.""" - data = self[0] - if hasattr(data, "num_edge_features"): - return data.num_edge_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_edge_features'" - ) - - @property - def raw_paths(self) -> List[str]: - r"""The filepaths to find in order to skip the download.""" - files = to_list(self.raw_file_names) - return [osp.join(self.raw_dir, f) for f in files] - - @property - def processed_paths(self) -> List[str]: - r"""The filepaths to find in the :obj:`self.processed_dir` - folder in order to skip the processing.""" - files = to_list(self.processed_file_names) - return [osp.join(self.processed_dir, f) for f in files] - - def _download(self): - if files_exist(self.raw_paths): # pragma: no cover - return - - makedirs(self.raw_dir) - self.download() - - def _process(self): - f = osp.join(self.processed_dir, "pre_transform.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): - warnings.warn( - f"The `pre_transform` argument differs from the one used in " - f"the pre-processed version of this dataset. 
If you want to " - f"make use of another pre-processing technique, make sure to " - f"sure to delete '{self.processed_dir}' first" - ) - - f = osp.join(self.processed_dir, "pre_filter.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): - warnings.warn( - "The `pre_filter` argument differs from the one used in the " - "pre-processed version of this dataset. If you want to make " - "use of another pre-fitering technique, make sure to delete " - "'{self.processed_dir}' first" - ) - - if files_exist(self.processed_paths): # pragma: no cover - return - - print("Processing...") - - makedirs(self.processed_dir) - self.process() - - path = osp.join(self.processed_dir, "pre_transform.pt") - torch.save(_repr(self.pre_transform), path) - path = osp.join(self.processed_dir, "pre_filter.pt") - torch.save(_repr(self.pre_filter), path) - - print("Done!") - - def __len__(self) -> int: - r"""The number of examples in the dataset.""" - return len(self.indices()) - - def __getitem__( - self, - idx: Union[int, np.integer, IndexType], - ) -> Union["Dataset", Data]: - r"""In case :obj:`idx` is of type integer, will return the data object - at index :obj:`idx` (and transforms it in case :obj:`transform` is - present). - In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a - tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy - :obj:`np.array`, will return a subset of the dataset at the specified - indices.""" - if ( - isinstance(idx, (int, np.integer)) - or (isinstance(idx, Tensor) and idx.dim() == 0) - or (isinstance(idx, np.ndarray) and np.isscalar(idx)) - ): - data = self.get(self.indices()[idx]) - data = data if self.transform is None else self.transform(data) - return data - - else: - return self.index_select(idx) - - def index_select(self, idx: IndexType) -> "Dataset": - indices = self.indices() - - if isinstance(idx, slice): - indices = indices[idx] - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False) - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0] - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - indices = [indices[i] for i in idx] - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - dataset = copy.copy(self) - dataset._indices = indices - return dataset - - def shuffle( - self, - return_perm: bool = False, - ) -> Union["Dataset", Tuple["Dataset", Tensor]]: - r"""Randomly shuffles the examples in the dataset. - - Args: - return_perm (bool, optional): If set to :obj:`True`, will return - the random permutation used to shuffle the dataset in addition. 
- (default: :obj:`False`) - """ - perm = torch.randperm(len(self)) - dataset = self.index_select(perm) - return (dataset, perm) if return_perm is True else dataset - - def __repr__(self) -> str: - arg_repr = str(len(self)) if len(self) > 1 else "" - return f"{self.__class__.__name__}({arg_repr})" - - -def to_list(value: Any) -> Sequence: - if isinstance(value, Sequence) and not isinstance(value, str): - return value - else: - return [value] - - -def files_exist(files: List[str]) -> bool: - # NOTE: We return `False` in case `files` is empty, leading to a - # re-processing of files on every instantiation. - return len(files) != 0 and all([osp.exists(f) for f in files]) - - -def _repr(obj: Any) -> str: - if obj is None: - return "None" - return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) +import copy +import os.path as osp +import re +import warnings +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.utils.data +from torch import Tensor + +from .data import Data +from .utils import makedirs + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class Dataset(torch.utils.data.Dataset): + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (string, optional): Root directory where the dataset should be + saved. (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. 
(default: :obj:`None`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.raw_dir` folder in + order to skip the download.""" + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + raise NotImplementedError + + def download(self): + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self): + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + raise NotImplementedError + + def get(self, idx: int) -> Data: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + super().__init__() + + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self._indices: Optional[Sequence] = None + + if "download" in self.__class__.__dict__.keys(): + self._download() + + if "process" in self.__class__.__dict__.keys(): + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return osp.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + @property + def raw_paths(self) -> List[str]: + r"""The filepaths to find in order to skip the download.""" + files = to_list(self.raw_file_names) + return [osp.join(self.raw_dir, f) for f in files] + + @property + def processed_paths(self) -> List[str]: + r"""The filepaths to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + files = to_list(self.processed_file_names) + return [osp.join(self.processed_dir, f) for f in files] + + def _download(self): + if files_exist(self.raw_paths): # pragma: no cover + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + f = osp.join(self.processed_dir, "pre_transform.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): + warnings.warn( + f"The `pre_transform` argument differs from the one used in " + f"the pre-processed version of this dataset. 
If you want to "
+                f"make use of another pre-processing technique, make sure to "
+                f"delete '{self.processed_dir}' first"
+            )
+
+        f = osp.join(self.processed_dir, "pre_filter.pt")
+        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
+            warnings.warn(
+                f"The `pre_filter` argument differs from the one used in the "
+                f"pre-processed version of this dataset. If you want to make "
+                f"use of another pre-filtering technique, make sure to delete "
+                f"'{self.processed_dir}' first"
+            )
+
+        if files_exist(self.processed_paths):  # pragma: no cover
+            return
+
+        print("Processing...")
+
+        makedirs(self.processed_dir)
+        self.process()
+
+        path = osp.join(self.processed_dir, "pre_transform.pt")
+        torch.save(_repr(self.pre_transform), path)
+        path = osp.join(self.processed_dir, "pre_filter.pt")
+        torch.save(_repr(self.pre_filter), path)
+
+        print("Done!")
+
+    def __len__(self) -> int:
+        r"""The number of examples in the dataset."""
+        return len(self.indices())
+
+    def __getitem__(
+        self,
+        idx: Union[int, np.integer, IndexType],
+    ) -> Union["Dataset", Data]:
+        r"""In case :obj:`idx` is of type integer, will return the data object
+        at index :obj:`idx` (and transforms it in case :obj:`transform` is
+        present).
+        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
+        tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
+        :obj:`np.array`, will return a subset of the dataset at the specified
+        indices."""
+        if (
+            isinstance(idx, (int, np.integer))
+            or (isinstance(idx, Tensor) and idx.dim() == 0)
+            or (isinstance(idx, np.ndarray) and np.isscalar(idx))
+        ):
+            data = self.get(self.indices()[idx])
+            data = data if self.transform is None else self.transform(data)
+            return data
+
+        else:
+            return self.index_select(idx)
+
+    def index_select(self, idx: IndexType) -> "Dataset":
+        indices = self.indices()
+
+        if isinstance(idx, slice):
+            indices = indices[idx]
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
+            idx = idx.flatten().nonzero(as_tuple=False)
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
+            idx = idx.flatten().nonzero()[0]
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, Sequence) and not isinstance(idx, str):
+            indices = [indices[i] for i in idx]
+
+        else:
+            raise IndexError(
+                f"Only integers, slices (':'), lists, tuples, torch.tensor and "
+                f"np.ndarray of dtype long or bool are valid indices (got "
+                f"'{type(idx).__name__}')"
+            )
+
+        dataset = copy.copy(self)
+        dataset._indices = indices
+        return dataset
+
+    def shuffle(
+        self,
+        return_perm: bool = False,
+    ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
+        r"""Randomly shuffles the examples in the dataset.
+
+        Args:
+            return_perm (bool, optional): If set to :obj:`True`, will return
+                the random permutation used to shuffle the dataset in addition.
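`Dataset` itself is abstract: `len()` and `get()` are the only methods a subclass must implement, and `__getitem__` / `index_select` then provide integer indexing, slicing, boolean masks, and shuffling for free. A toy in-memory subclass, purely illustrative:

```python
import torch

# module paths assumed from the file paths in this diff
from mace.tools.torch_geometric.data import Data
from mace.tools.torch_geometric.dataset import Dataset


class InMemoryList(Dataset):
    """Minimal subclass: no download/process step, just a list of Data objects."""

    def __init__(self, data_list):
        super().__init__(root=None)
        self.data_list = data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]


ds = InMemoryList([Data(x=torch.randn(i + 1, 3)) for i in range(10)])
print(len(ds))                     # 10

sub = ds[torch.tensor([0, 2, 4])]  # long tensor -> index_select
print(len(sub))                    # 3

shuffled, perm = ds.shuffle(return_perm=True)
print(perm[:3])                    # the permutation actually used
```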
+ (default: :obj:`False`) + """ + perm = torch.randperm(len(self)) + dataset = self.index_select(perm) + return (dataset, perm) if return_perm is True else dataset + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"{self.__class__.__name__}({arg_repr})" + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + +def files_exist(files: List[str]) -> bool: + # NOTE: We return `False` in case `files` is empty, leading to a + # re-processing of files on every instantiation. + return len(files) != 0 and all([osp.exists(f) for f in files]) + + +def _repr(obj: Any) -> str: + if obj is None: + return "None" + return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py index 6819fda8c7feaeb8bc799051ab0dbd2670c1d8b4..be27fcaa1636632a9c95022990b4d9a9ac21744d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py @@ -1,17 +1,17 @@ -import random - -import numpy as np -import torch - - -def seed_everything(seed: int): - r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, - :obj:`numpy` and Python. - - Args: - seed (int): The desired seed. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) +import random + +import numpy as np +import torch + + +def seed_everything(seed: int): + r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, + :obj:`numpy` and Python. + + Args: + seed (int): The desired seed. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py index 3efc3540d3be5d154062a3c3ee483fbba0754fab..f53b8f8098a1efcd39669b9dbe94dc0399a2190f 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py @@ -1,54 +1,54 @@ -import os -import os.path as osp -import ssl -import urllib -import zipfile - - -def makedirs(dir): - os.makedirs(dir, exist_ok=True) - - -def download_url(url, folder, log=True): - r"""Downloads the content of an URL to a specific folder. - - Args: - url (string): The url. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - - filename = url.rpartition("/")[2].split("?")[0] - path = osp.join(folder, filename) - - if osp.exists(path): # pragma: no cover - if log: - print("Using exist file", filename) - return path - - if log: - print("Downloading", url) - - makedirs(folder) - - context = ssl._create_unverified_context() - data = urllib.request.urlopen(url, context=context) - - with open(path, "wb") as f: - f.write(data.read()) - - return path - - -def extract_zip(path, folder, log=True): - r"""Extracts a zip archive to a specific folder. - - Args: - path (string): The path to the tar archive. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. 
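`seed_everything` seeds Python's `random` module, NumPy, and all CUDA devices in one call. Note that it does not disable nondeterministic cuDNN kernels, so GPU runs can still diverge at the op level. A quick check, with the import path assumed from this diff:

```python
import torch

from mace.tools.torch_geometric.seed import seed_everything  # path assumed

seed_everything(0)
a = torch.randn(3)
seed_everything(0)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, same RNG stream

# For strict GPU reproducibility you may additionally want
# torch.use_deterministic_algorithms(True).
```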
(default: :obj:`True`) - """ - with zipfile.ZipFile(path, "r") as f: - f.extractall(folder) +import os +import os.path as osp +import ssl +import urllib +import zipfile + + +def makedirs(dir): + os.makedirs(dir, exist_ok=True) + + +def download_url(url, folder, log=True): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (string): The url. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + filename = url.rpartition("/")[2].split("?")[0] + path = osp.join(folder, filename) + + if osp.exists(path): # pragma: no cover + if log: + print("Using exist file", filename) + return path + + if log: + print("Downloading", url) + + makedirs(folder) + + context = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=context) + + with open(path, "wb") as f: + f.write(data.read()) + + return path + + +def extract_zip(path, folder, log=True): + r"""Extracts a zip archive to a specific folder. + + Args: + path (string): The path to the tar archive. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + with zipfile.ZipFile(path, "r") as f: + f.extractall(folder) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_tools.py b/mace-bench/3rdparty/mace/mace/tools/torch_tools.py index 380a1da63d223a432d7693ee71e57a1fd35c4743..2ab339ef81808a44c6b26b8daa59ccc204eee57d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_tools.py @@ -1,153 +1,153 @@ -########################################################################################### -# Tools for torch -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from contextlib import contextmanager -from typing import Dict, Union - -import numpy as np -import torch -from e3nn.io import CartesianTensor - -TensorDict = Dict[str, torch.Tensor] - - -def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: - """ - Generates one-hot encoding with classes from - :param indices: (N x 1) tensor - :param num_classes: number of classes - :param device: torch device - :return: (N x num_classes) tensor - """ - shape = indices.shape[:-1] + (num_classes,) - oh = torch.zeros(shape, device=indices.device).view(shape) - - # scatter_ is the in-place version of scatter - oh.scatter_(dim=-1, index=indices, value=1) - - return oh.view(*shape) - - -def count_parameters(module: torch.nn.Module) -> int: - return int(sum(np.prod(p.shape) for p in module.parameters())) - - -def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: - return {k: v.to(device) if v is not None else None for k, v in td.items()} - - -def set_seeds(seed: int) -> None: - np.random.seed(seed) - torch.manual_seed(seed) - - -def to_numpy(t: torch.Tensor) -> np.ndarray: - return t.cpu().detach().numpy() - - -def init_device(device_str: str) -> torch.device: - if "cuda" in device_str: - assert torch.cuda.is_available(), "No CUDA device available!" 
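`download_url` skips the network round-trip when the target file already exists, and deliberately uses an unverified SSL context, which is worth knowing before pointing it at anything sensitive; `extract_zip` is a thin `zipfile` wrapper. A usage sketch with a purely hypothetical URL and folder:

```python
# module path assumed from the file paths in this diff
from mace.tools.torch_geometric.utils import download_url, extract_zip

URL = "https://example.com/dataset.zip"  # hypothetical, for illustration only

path = download_url(URL, "raw")  # returns the local path; reuses an existing file
extract_zip(path, "raw", log=True)
```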
- if ":" in device_str: - # Check if the desired device is available - assert int(device_str.split(":")[-1]) < torch.cuda.device_count() - logging.info( - f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" - ) - torch.cuda.init() - return torch.device(device_str) - if device_str == "mps": - assert torch.backends.mps.is_available(), "No MPS backend is available!" - logging.info("Using MPS GPU acceleration") - return torch.device("mps") - if device_str == "xpu": - torch.xpu.is_available() - return torch.device("xpu") - - logging.info("Using CPU") - return torch.device("cpu") - - -dtype_dict = {"float32": torch.float32, "float64": torch.float64} - - -def set_default_dtype(dtype: str) -> None: - torch.set_default_dtype(dtype_dict[dtype]) - - -def spherical_to_cartesian(t: torch.Tensor): - """ - Convert spherical notation to cartesian notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -def cartesian_to_spherical(t: torch.Tensor): - """ - Convert cartesian notation to spherical notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -def voigt_to_matrix(t: torch.Tensor): - """ - Convert voigt notation to matrix notation - :param t: (6,) tensor or (3, 3) tensor or (9,) tensor - :return: (3, 3) tensor - """ - if t.shape == (3, 3): - return t - if t.shape == (6,): - return torch.tensor( - [ - [t[0], t[5], t[4]], - [t[5], t[1], t[3]], - [t[4], t[3], t[2]], - ], - dtype=t.dtype, - ) - if t.shape == (9,): - return t.view(3, 3) - - raise ValueError( - f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" - ) - - -def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): - import wandb - - wandb.init( - project=project, - entity=entity, - name=name, - config=config, - dir=directory, - resume="allow", - ) - - -@contextmanager -def default_dtype(dtype: Union[torch.dtype, str]): - """Context manager for configuring the default_dtype used by torch - - Args: - dtype (torch.dtype|str): the default dtype to use within this context manager - """ - init = torch.get_default_dtype() - if isinstance(dtype, str): - set_default_dtype(dtype) - else: - torch.set_default_dtype(dtype) - - yield - - torch.set_default_dtype(init) +########################################################################################### +# Tools for torch +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from contextlib import contextmanager +from typing import Dict, Union + +import numpy as np +import torch +from e3nn.io import CartesianTensor + +TensorDict = Dict[str, torch.Tensor] + + +def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Generates one-hot encoding with classes from + :param indices: (N x 1) tensor + :param num_classes: number of classes + :param device: torch device + :return: (N x num_classes) tensor + """ + shape = indices.shape[:-1] + (num_classes,) + oh = torch.zeros(shape, device=indices.device).view(shape) + + # scatter_ is the in-place version of scatter + oh.scatter_(dim=-1, index=indices, value=1) + + return oh.view(*shape) + + +def count_parameters(module: torch.nn.Module) 
-> int: + return int(sum(np.prod(p.shape) for p in module.parameters())) + + +def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: + return {k: v.to(device) if v is not None else None for k, v in td.items()} + + +def set_seeds(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + + +def to_numpy(t: torch.Tensor) -> np.ndarray: + return t.cpu().detach().numpy() + + +def init_device(device_str: str) -> torch.device: + if "cuda" in device_str: + assert torch.cuda.is_available(), "No CUDA device available!" + if ":" in device_str: + # Check if the desired device is available + assert int(device_str.split(":")[-1]) < torch.cuda.device_count() + logging.info( + f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" + ) + torch.cuda.init() + return torch.device(device_str) + if device_str == "mps": + assert torch.backends.mps.is_available(), "No MPS backend is available!" + logging.info("Using MPS GPU acceleration") + return torch.device("mps") + if device_str == "xpu": + torch.xpu.is_available() + return torch.device("xpu") + + logging.info("Using CPU") + return torch.device("cpu") + + +dtype_dict = {"float32": torch.float32, "float64": torch.float64} + + +def set_default_dtype(dtype: str) -> None: + torch.set_default_dtype(dtype_dict[dtype]) + + +def spherical_to_cartesian(t: torch.Tensor): + """ + Convert spherical notation to cartesian notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def cartesian_to_spherical(t: torch.Tensor): + """ + Convert cartesian notation to spherical notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def voigt_to_matrix(t: torch.Tensor): + """ + Convert voigt notation to matrix notation + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor + :return: (3, 3) tensor + """ + if t.shape == (3, 3): + return t + if t.shape == (6,): + return torch.tensor( + [ + [t[0], t[5], t[4]], + [t[5], t[1], t[3]], + [t[4], t[3], t[2]], + ], + dtype=t.dtype, + ) + if t.shape == (9,): + return t.view(3, 3) + + raise ValueError( + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" + ) + + +def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): + import wandb + + wandb.init( + project=project, + entity=entity, + name=name, + config=config, + dir=directory, + resume="allow", + ) + + +@contextmanager +def default_dtype(dtype: Union[torch.dtype, str]): + """Context manager for configuring the default_dtype used by torch + + Args: + dtype (torch.dtype|str): the default dtype to use within this context manager + """ + init = torch.get_default_dtype() + if isinstance(dtype, str): + set_default_dtype(dtype) + else: + torch.set_default_dtype(dtype) + + yield + + torch.set_default_dtype(init) diff --git a/mace-bench/3rdparty/mace/mace/tools/train.py b/mace-bench/3rdparty/mace/mace/tools/train.py index 0c3916b358e0966a58b1a53999a5f9932d63c2aa..c7c17e136952d3408ac9914d6bb6498b7b048ae0 100644 --- a/mace-bench/3rdparty/mace/mace/tools/train.py +++ b/mace-bench/3rdparty/mace/mace/tools/train.py @@ -1,669 +1,669 @@ -########################################################################################### -# Training script -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is 
distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import time -from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed -from torch.nn.parallel import DistributedDataParallel -from torch.optim import LBFGS -from torch.optim.swa_utils import SWALR, AveragedModel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from torch_ema import ExponentialMovingAverage -from torchmetrics import Metric - -from mace.cli.visualise_train import TrainingPlotter - -from . import torch_geometric -from .checkpoint import CheckpointHandler, CheckpointState -from .torch_tools import to_numpy -from .utils import ( - MetricsLogger, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, -) - - -@dataclasses.dataclass -class SWAContainer: - model: AveragedModel - scheduler: SWALR - start: int - loss_fn: torch.nn.Module - - -def valid_err_log( - valid_loss, - eval_metrics, - logger, - log_errors, - epoch=None, - valid_loader_name="Default", -): - eval_metrics["mode"] = "eval" - eval_metrics["epoch"] = epoch - eval_metrics["head"] = valid_loader_name - logger.log(eval_metrics) - if epoch is None: - inintial_phrase = "Initial" - else: - inintial_phrase = f"Epoch {epoch}" - if log_errors == "PerAtomRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_virials_per_atom"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_stress_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_stress = eval_metrics["mae_stress"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_virials_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_virials = eval_metrics["mae_virials"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" - ) - elif log_errors == 
"TotalRMSE": - error_e = eval_metrics["rmse_e"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "PerAtomMAE": - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "TotalMAE": - error_e = eval_metrics["mae_e"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "DipoleRMSE": - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", - ) - elif log_errors == "EnergyDipoleRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", - ) - - -def train( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - train_loader: DataLoader, - valid_loaders: Dict[str, DataLoader], - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, - start_epoch: int, - max_num_epochs: int, - patience: int, - checkpoint_handler: CheckpointHandler, - logger: MetricsLogger, - eval_interval: int, - output_args: Dict[str, bool], - device: torch.device, - log_errors: str, - swa: Optional[SWAContainer] = None, - ema: Optional[ExponentialMovingAverage] = None, - max_grad_norm: Optional[float] = 10.0, - log_wandb: bool = False, - distributed: bool = False, - save_all_checkpoints: bool = False, - plotter: TrainingPlotter = None, - distributed_model: Optional[DistributedDataParallel] = None, - train_sampler: Optional[DistributedSampler] = None, - rank: Optional[int] = 0, -): - lowest_loss = np.inf - valid_loss = np.inf - patience_counter = 0 - swa_start = True - keep_last = False - if log_wandb: - import wandb - - if max_grad_norm is not None: - logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") - - logging.info("") - logging.info("===========TRAINING===========") - logging.info("Started training, reporting errors on validation set") - logging.info("Loss metrics on validation set") - epoch = start_epoch - - # log validation loss before _any_ training - for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_head, eval_metrics = evaluate( - model=model, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - valid_err_log( - valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name - ) - valid_loss = valid_loss_head # consider only the last head for the checkpoint - - while epoch < max_num_epochs: - # LR scheduler and SWA update - if swa is None or epoch < swa.start: - if epoch > start_epoch: - lr_scheduler.step( - metrics=valid_loss - ) # Can break if exponential LR, TODO fix that! 
- else: - if swa_start: - logging.info("Changing loss based on Stage Two Weights") - lowest_loss = np.inf - swa_start = False - keep_last = True - loss_fn = swa.loss_fn - swa.model.update_parameters(model) - if epoch > start_epoch: - swa.scheduler.step() - - # Train - if distributed: - train_sampler.set_epoch(epoch) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.train() - train_one_epoch( - model=model, - loss_fn=loss_fn, - data_loader=train_loader, - optimizer=optimizer, - epoch=epoch, - output_args=output_args, - max_grad_norm=max_grad_norm, - ema=ema, - logger=logger, - device=device, - distributed=distributed, - distributed_model=distributed_model, - rank=rank, - ) - if distributed: - torch.distributed.barrier() - - # Validate - if epoch % eval_interval == 0: - model_to_evaluate = ( - model if distributed_model is None else distributed_model - ) - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.eval() - with param_context: - wandb_log_dict = {} - for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_head, eval_metrics = evaluate( - model=model_to_evaluate, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - if rank == 0: - valid_err_log( - valid_loss_head, - eval_metrics, - logger, - log_errors, - epoch, - valid_loader_name, - ) - if log_wandb: - wandb_log_dict[valid_loader_name] = { - "epoch": epoch, - "valid_loss": valid_loss_head, - "valid_rmse_e_per_atom": eval_metrics[ - "rmse_e_per_atom" - ], - "valid_rmse_f": eval_metrics["rmse_f"], - } - if plotter and epoch % plotter.plot_frequency == 0: - try: - plotter.plot(epoch, model_to_evaluate, rank) - except Exception as e: # pylint: disable=broad-except - logging.debug(f"Plotting failed: {e}") - valid_loss = ( - valid_loss_head # consider only the last head for the checkpoint - ) - if log_wandb: - wandb.log(wandb_log_dict) - if rank == 0: - if valid_loss >= lowest_loss: - patience_counter += 1 - if patience_counter >= patience: - if swa is not None and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - else: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break - if save_all_checkpoints: - param_context = ( - ema.average_parameters() - if ema is not None - else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=True, - ) - else: - lowest_loss = valid_loss - patience_counter = 0 - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=keep_last, - ) - keep_last = False or save_all_checkpoints - if distributed: - torch.distributed.barrier() - epoch += 1 - - logging.info("Training complete") - - -def train_one_epoch( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - optimizer: torch.optim.Optimizer, - epoch: int, - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - ema: Optional[ExponentialMovingAverage], - logger: MetricsLogger, - device: torch.device, - distributed: bool, - distributed_model: Optional[DistributedDataParallel] = None, - rank: Optional[int] = 0, -) -> None: - model_to_train 
= model if distributed_model is None else distributed_model - - if isinstance(optimizer, LBFGS): - _, opt_metrics = take_step_lbfgs( - model=model_to_train, - loss_fn=loss_fn, - data_loader=data_loader, - optimizer=optimizer, - ema=ema, - output_args=output_args, - max_grad_norm=max_grad_norm, - device=device, - distributed=distributed, - rank=rank, - ) - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - if rank == 0: - logger.log(opt_metrics) - else: - for batch in data_loader: - _, opt_metrics = take_step( - model=model_to_train, - loss_fn=loss_fn, - batch=batch, - optimizer=optimizer, - ema=ema, - output_args=output_args, - max_grad_norm=max_grad_norm, - device=device, - ) - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - if rank == 0: - logger.log(opt_metrics) - - -def take_step( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - batch: torch_geometric.batch.Batch, - optimizer: torch.optim.Optimizer, - ema: Optional[ExponentialMovingAverage], - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - start_time = time.time() - batch = batch.to(device) - batch_dict = batch.to_dict() - - def closure(): - optimizer.zero_grad(set_to_none=True) - output = model( - batch_dict, - training=True, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - loss = loss_fn(pred=output, ref=batch) - loss.backward() - if max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) - - return loss - - loss = closure() - optimizer.step() - - if ema is not None: - ema.update() - - loss_dict = { - "loss": to_numpy(loss), - "time": time.time() - start_time, - } - - return loss, loss_dict - - -def take_step_lbfgs( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - optimizer: torch.optim.Optimizer, - ema: Optional[ExponentialMovingAverage], - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - device: torch.device, - distributed: bool, - rank: int, -) -> Tuple[float, Dict[str, Any]]: - start_time = time.time() - logging.debug( - f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB" - ) - - total_sample_count = 0 - for batch in data_loader: - total_sample_count += batch.num_graphs - - if distributed: - global_sample_count = torch.tensor(total_sample_count, device=device) - torch.distributed.all_reduce( - global_sample_count, op=torch.distributed.ReduceOp.SUM - ) - total_sample_count = global_sample_count.item() - - signal = torch.zeros(1, device=device) if distributed else None - - def closure(): - if distributed: - if rank == 0: - signal.fill_(1) - torch.distributed.broadcast(signal, src=0) - - for param in model.parameters(): - torch.distributed.broadcast(param.data, src=0) - - optimizer.zero_grad(set_to_none=True) - total_loss = torch.tensor(0.0, device=device) - - # Process each batch and then collect the results we pass to the optimizer - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=True, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - batch_loss = loss_fn(pred=output, ref=batch) - batch_loss = batch_loss * (batch.num_graphs / total_sample_count) - - batch_loss.backward() - total_loss += batch_loss - - if max_grad_norm is not None: - 
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) - - if distributed: - torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM) - return total_loss - - if distributed: - if rank == 0: - loss = optimizer.step(closure) - signal.fill_(0) - torch.distributed.broadcast(signal, src=0) - else: - while True: - # Other ranks wait for signals from rank 0 - torch.distributed.broadcast(signal, src=0) - if signal.item() == 0: - break - if signal.item() == 1: - loss = closure() - - for param in model.parameters(): - torch.distributed.broadcast(param.data, src=0) - else: - loss = optimizer.step(closure) - - if ema is not None: - ema.update() - - loss_dict = { - "loss": to_numpy(loss), - "time": time.time() - start_time, - } - - return loss, loss_dict - - -def evaluate( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - output_args: Dict[str, bool], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - for param in model.parameters(): - param.requires_grad = False - - metrics = MACELoss(loss_fn=loss_fn).to(device) - - start_time = time.time() - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - avg_loss, aux = metrics(batch, output) - - avg_loss, aux = metrics.compute() - aux["time"] = time.time() - start_time - metrics.reset() - - for param in model.parameters(): - param.requires_grad = True - - return avg_loss, aux - - -class MACELoss(Metric): - def __init__(self, loss_fn: torch.nn.Module): - super().__init__() - self.loss_fn = loss_fn - self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("delta_es", default=[], dist_reduce_fx="cat") - self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("fs", default=[], dist_reduce_fx="cat") - self.add_state("delta_fs", default=[], dist_reduce_fx="cat") - self.add_state( - "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_stress", default=[], dist_reduce_fx="cat") - self.add_state( - "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_virials", default=[], dist_reduce_fx="cat") - self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") - - def update(self, batch, output): # pylint: disable=arguments-differ - loss = self.loss_fn(pred=output, ref=batch) - self.total_loss += loss - self.num_data += batch.num_graphs - - if output.get("energy") is not None and batch.energy is not None: - self.E_computed += 1.0 - self.delta_es.append(batch.energy - output["energy"]) - self.delta_es_per_atom.append( - (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) - ) - if output.get("forces") is not None and batch.forces is not None: - self.Fs_computed += 1.0 - self.fs.append(batch.forces) - 
self.delta_fs.append(batch.forces - output["forces"]) - if output.get("stress") is not None and batch.stress is not None: - self.stress_computed += 1.0 - self.delta_stress.append(batch.stress - output["stress"]) - if output.get("virials") is not None and batch.virials is not None: - self.virials_computed += 1.0 - self.delta_virials.append(batch.virials - output["virials"]) - self.delta_virials_per_atom.append( - (batch.virials - output["virials"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) - if output.get("dipole") is not None and batch.dipole is not None: - self.Mus_computed += 1.0 - self.mus.append(batch.dipole) - self.delta_mus.append(batch.dipole - output["dipole"]) - self.delta_mus_per_atom.append( - (batch.dipole - output["dipole"]) - / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) - ) - - def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: - if isinstance(delta, list): - delta = torch.cat(delta) - return to_numpy(delta) - - def compute(self): - aux = {} - aux["loss"] = to_numpy(self.total_loss / self.num_data).item() - if self.E_computed: - delta_es = self.convert(self.delta_es) - delta_es_per_atom = self.convert(self.delta_es_per_atom) - aux["mae_e"] = compute_mae(delta_es) - aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) - aux["rmse_e"] = compute_rmse(delta_es) - aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) - aux["q95_e"] = compute_q95(delta_es) - if self.Fs_computed: - fs = self.convert(self.fs) - delta_fs = self.convert(self.delta_fs) - aux["mae_f"] = compute_mae(delta_fs) - aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) - aux["rmse_f"] = compute_rmse(delta_fs) - aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) - aux["q95_f"] = compute_q95(delta_fs) - if self.stress_computed: - delta_stress = self.convert(self.delta_stress) - aux["mae_stress"] = compute_mae(delta_stress) - aux["rmse_stress"] = compute_rmse(delta_stress) - aux["q95_stress"] = compute_q95(delta_stress) - if self.virials_computed: - delta_virials = self.convert(self.delta_virials) - delta_virials_per_atom = self.convert(self.delta_virials_per_atom) - aux["mae_virials"] = compute_mae(delta_virials) - aux["rmse_virials"] = compute_rmse(delta_virials) - aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) - aux["q95_virials"] = compute_q95(delta_virials) - if self.Mus_computed: - mus = self.convert(self.mus) - delta_mus = self.convert(self.delta_mus) - delta_mus_per_atom = self.convert(self.delta_mus_per_atom) - aux["mae_mu"] = compute_mae(delta_mus) - aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) - aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) - aux["rmse_mu"] = compute_rmse(delta_mus) - aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) - aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) - aux["q95_mu"] = compute_q95(delta_mus) - - return aux["loss"], aux +########################################################################################### +# Training script +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import time +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from torch.nn.parallel import DistributedDataParallel +from torch.optim import LBFGS +from torch.optim.swa_utils import SWALR, 
AveragedModel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch_ema import ExponentialMovingAverage +from torchmetrics import Metric + +from mace.cli.visualise_train import TrainingPlotter + +from . import torch_geometric +from .checkpoint import CheckpointHandler, CheckpointState +from .torch_tools import to_numpy +from .utils import ( + MetricsLogger, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, +) + + +@dataclasses.dataclass +class SWAContainer: + model: AveragedModel + scheduler: SWALR + start: int + loss_fn: torch.nn.Module + + +def valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch=None, + valid_loader_name="Default", +): + eval_metrics["mode"] = "eval" + eval_metrics["epoch"] = epoch + eval_metrics["head"] = valid_loader_name + logger.log(eval_metrics) + if epoch is None: + inintial_phrase = "Initial" + else: + inintial_phrase = f"Epoch {epoch}" + if log_errors == "PerAtomRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_stress"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_stress = eval_metrics["rmse_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_virials_per_atom"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_stress_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_stress = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_virials_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_virials = eval_metrics["mae_virials"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" + ) + elif log_errors == "TotalRMSE": + error_e = eval_metrics["rmse_e"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "PerAtomMAE": + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, 
MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "TotalMAE": + error_e = eval_metrics["mae_e"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "DipoleRMSE": + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + ) + elif log_errors == "EnergyDipoleRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + ) + + +def train( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + train_loader: DataLoader, + valid_loaders: Dict[str, DataLoader], + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, + start_epoch: int, + max_num_epochs: int, + patience: int, + checkpoint_handler: CheckpointHandler, + logger: MetricsLogger, + eval_interval: int, + output_args: Dict[str, bool], + device: torch.device, + log_errors: str, + swa: Optional[SWAContainer] = None, + ema: Optional[ExponentialMovingAverage] = None, + max_grad_norm: Optional[float] = 10.0, + log_wandb: bool = False, + distributed: bool = False, + save_all_checkpoints: bool = False, + plotter: TrainingPlotter = None, + distributed_model: Optional[DistributedDataParallel] = None, + train_sampler: Optional[DistributedSampler] = None, + rank: Optional[int] = 0, +): + lowest_loss = np.inf + valid_loss = np.inf + patience_counter = 0 + swa_start = True + keep_last = False + if log_wandb: + import wandb + + if max_grad_norm is not None: + logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") + + logging.info("") + logging.info("===========TRAINING===========") + logging.info("Started training, reporting errors on validation set") + logging.info("Loss metrics on validation set") + epoch = start_epoch + + # log validation loss before _any_ training + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + valid_err_log( + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name + ) + valid_loss = valid_loss_head # consider only the last head for the checkpoint + + while epoch < max_num_epochs: + # LR scheduler and SWA update + if swa is None or epoch < swa.start: + if epoch > start_epoch: + lr_scheduler.step( + metrics=valid_loss + ) # Can break if exponential LR, TODO fix that! 
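+                # Editorial note: ReduceLROnPlateau.step() accepts a `metrics`
+                # argument, but ExponentialLR.step() (the scheduler type in the
+                # annotation above) does not, so this call raises a TypeError
+                # for purely exponential schedules; that is what the TODO
+                # comment above refers to.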
+ else: + if swa_start: + logging.info("Changing loss based on Stage Two Weights") + lowest_loss = np.inf + swa_start = False + keep_last = True + loss_fn = swa.loss_fn + swa.model.update_parameters(model) + if epoch > start_epoch: + swa.scheduler.step() + + # Train + if distributed: + train_sampler.set_epoch(epoch) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.train() + train_one_epoch( + model=model, + loss_fn=loss_fn, + data_loader=train_loader, + optimizer=optimizer, + epoch=epoch, + output_args=output_args, + max_grad_norm=max_grad_norm, + ema=ema, + logger=logger, + device=device, + distributed=distributed, + distributed_model=distributed_model, + rank=rank, + ) + if distributed: + torch.distributed.barrier() + + # Validate + if epoch % eval_interval == 0: + model_to_evaluate = ( + model if distributed_model is None else distributed_model + ) + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.eval() + with param_context: + wandb_log_dict = {} + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss_head, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, + ) + if log_wandb: + wandb_log_dict[valid_loader_name] = { + "epoch": epoch, + "valid_loss": valid_loss_head, + "valid_rmse_e_per_atom": eval_metrics[ + "rmse_e_per_atom" + ], + "valid_rmse_f": eval_metrics["rmse_f"], + } + if plotter and epoch % plotter.plot_frequency == 0: + try: + plotter.plot(epoch, model_to_evaluate, rank) + except Exception as e: # pylint: disable=broad-except + logging.debug(f"Plotting failed: {e}") + valid_loss = ( + valid_loss_head # consider only the last head for the checkpoint + ) + if log_wandb: + wandb.log(wandb_log_dict) + if rank == 0: + if valid_loss >= lowest_loss: + patience_counter += 1 + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break + if save_all_checkpoints: + param_context = ( + ema.average_parameters() + if ema is not None + else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, + ) + else: + lowest_loss = valid_loss + patience_counter = 0 + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=keep_last, + ) + keep_last = False or save_all_checkpoints + if distributed: + torch.distributed.barrier() + epoch += 1 + + logging.info("Training complete") + + +def train_one_epoch( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + epoch: int, + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + ema: Optional[ExponentialMovingAverage], + logger: MetricsLogger, + device: torch.device, + distributed: bool, + distributed_model: Optional[DistributedDataParallel] = None, + rank: Optional[int] = 0, +) -> None: + model_to_train 
= model if distributed_model is None else distributed_model + + if isinstance(optimizer, LBFGS): + _, opt_metrics = take_step_lbfgs( + model=model_to_train, + loss_fn=loss_fn, + data_loader=data_loader, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + distributed=distributed, + rank=rank, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + else: + for batch in data_loader: + _, opt_metrics = take_step( + model=model_to_train, + loss_fn=loss_fn, + batch=batch, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + + +def take_step( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + batch: torch_geometric.batch.Batch, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + batch = batch.to(device) + batch_dict = batch.to_dict() + + def closure(): + optimizer.zero_grad(set_to_none=True) + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + loss = loss_fn(pred=output, ref=batch) + loss.backward() + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + + return loss + + loss = closure() + optimizer.step() + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def take_step_lbfgs( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, + distributed: bool, + rank: int, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + logging.debug( + f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB" + ) + + total_sample_count = 0 + for batch in data_loader: + total_sample_count += batch.num_graphs + + if distributed: + global_sample_count = torch.tensor(total_sample_count, device=device) + torch.distributed.all_reduce( + global_sample_count, op=torch.distributed.ReduceOp.SUM + ) + total_sample_count = global_sample_count.item() + + signal = torch.zeros(1, device=device) if distributed else None + + def closure(): + if distributed: + if rank == 0: + signal.fill_(1) + torch.distributed.broadcast(signal, src=0) + + for param in model.parameters(): + torch.distributed.broadcast(param.data, src=0) + + optimizer.zero_grad(set_to_none=True) + total_loss = torch.tensor(0.0, device=device) + + # Process each batch and then collect the results we pass to the optimizer + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + batch_loss = loss_fn(pred=output, ref=batch) + batch_loss = batch_loss * (batch.num_graphs / total_sample_count) + + batch_loss.backward() + total_loss += batch_loss + + if max_grad_norm is not None: + 
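+            # Gradients accumulated over the whole loader are clipped once
+            # here, inside the closure, since LBFGS may invoke the closure
+            # several times per optimizer.step() call.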
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + + if distributed: + torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM) + return total_loss + + if distributed: + if rank == 0: + loss = optimizer.step(closure) + signal.fill_(0) + torch.distributed.broadcast(signal, src=0) + else: + while True: + # Other ranks wait for signals from rank 0 + torch.distributed.broadcast(signal, src=0) + if signal.item() == 0: + break + if signal.item() == 1: + loss = closure() + + for param in model.parameters(): + torch.distributed.broadcast(param.data, src=0) + else: + loss = optimizer.step(closure) + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def evaluate( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + output_args: Dict[str, bool], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + for param in model.parameters(): + param.requires_grad = False + + metrics = MACELoss(loss_fn=loss_fn).to(device) + + start_time = time.time() + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + avg_loss, aux = metrics(batch, output) + + avg_loss, aux = metrics.compute() + aux["time"] = time.time() - start_time + metrics.reset() + + for param in model.parameters(): + param.requires_grad = True + + return avg_loss, aux + + +class MACELoss(Metric): + def __init__(self, loss_fn: torch.nn.Module): + super().__init__() + self.loss_fn = loss_fn + self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("delta_es", default=[], dist_reduce_fx="cat") + self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("fs", default=[], dist_reduce_fx="cat") + self.add_state("delta_fs", default=[], dist_reduce_fx="cat") + self.add_state( + "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_stress", default=[], dist_reduce_fx="cat") + self.add_state( + "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_virials", default=[], dist_reduce_fx="cat") + self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + + def update(self, batch, output): # pylint: disable=arguments-differ + loss = self.loss_fn(pred=output, ref=batch) + self.total_loss += loss + self.num_data += batch.num_graphs + + if output.get("energy") is not None and batch.energy is not None: + self.E_computed += 1.0 + self.delta_es.append(batch.energy - output["energy"]) + self.delta_es_per_atom.append( + (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) + ) + if output.get("forces") is not None and batch.forces is not None: + self.Fs_computed += 1.0 + self.fs.append(batch.forces) + 
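+            # Reference forces are stored alongside the residuals so that
+            # compute() can report relative force errors (rel_mae_f, rel_rmse_f).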
self.delta_fs.append(batch.forces - output["forces"]) + if output.get("stress") is not None and batch.stress is not None: + self.stress_computed += 1.0 + self.delta_stress.append(batch.stress - output["stress"]) + if output.get("virials") is not None and batch.virials is not None: + self.virials_computed += 1.0 + self.delta_virials.append(batch.virials - output["virials"]) + self.delta_virials_per_atom.append( + (batch.virials - output["virials"]) + / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) + ) + if output.get("dipole") is not None and batch.dipole is not None: + self.Mus_computed += 1.0 + self.mus.append(batch.dipole) + self.delta_mus.append(batch.dipole - output["dipole"]) + self.delta_mus_per_atom.append( + (batch.dipole - output["dipole"]) + / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) + ) + + def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: + if isinstance(delta, list): + delta = torch.cat(delta) + return to_numpy(delta) + + def compute(self): + aux = {} + aux["loss"] = to_numpy(self.total_loss / self.num_data).item() + if self.E_computed: + delta_es = self.convert(self.delta_es) + delta_es_per_atom = self.convert(self.delta_es_per_atom) + aux["mae_e"] = compute_mae(delta_es) + aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) + aux["rmse_e"] = compute_rmse(delta_es) + aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) + aux["q95_e"] = compute_q95(delta_es) + if self.Fs_computed: + fs = self.convert(self.fs) + delta_fs = self.convert(self.delta_fs) + aux["mae_f"] = compute_mae(delta_fs) + aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) + aux["rmse_f"] = compute_rmse(delta_fs) + aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) + aux["q95_f"] = compute_q95(delta_fs) + if self.stress_computed: + delta_stress = self.convert(self.delta_stress) + aux["mae_stress"] = compute_mae(delta_stress) + aux["rmse_stress"] = compute_rmse(delta_stress) + aux["q95_stress"] = compute_q95(delta_stress) + if self.virials_computed: + delta_virials = self.convert(self.delta_virials) + delta_virials_per_atom = self.convert(self.delta_virials_per_atom) + aux["mae_virials"] = compute_mae(delta_virials) + aux["rmse_virials"] = compute_rmse(delta_virials) + aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) + aux["q95_virials"] = compute_q95(delta_virials) + if self.Mus_computed: + mus = self.convert(self.mus) + delta_mus = self.convert(self.delta_mus) + delta_mus_per_atom = self.convert(self.delta_mus_per_atom) + aux["mae_mu"] = compute_mae(delta_mus) + aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) + aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) + aux["rmse_mu"] = compute_rmse(delta_mus) + aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) + aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) + aux["q95_mu"] = compute_q95(delta_mus) + + return aux["loss"], aux diff --git a/mace-bench/3rdparty/mace/mace/tools/utils.py b/mace-bench/3rdparty/mace/mace/tools/utils.py index 12878e1d36ad5b47e65a787ebed0489a310842cf..1b1b55b1862687f782af27edae278054af6f372a 100644 --- a/mace-bench/3rdparty/mace/mace/tools/utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/utils.py @@ -1,166 +1,166 @@ -########################################################################################### -# Statistics utilities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import 
json -import logging -import os -import sys -from typing import Any, Dict, Iterable, Optional, Sequence, Union - -import numpy as np -import torch - -from .torch_tools import to_numpy - - -def compute_mae(delta: np.ndarray) -> float: - return np.mean(np.abs(delta)).item() - - -def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.mean(np.abs(target_val)) - return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 - - -def compute_rmse(delta: np.ndarray) -> float: - return np.sqrt(np.mean(np.square(delta))).item() - - -def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.sqrt(np.mean(np.square(target_val))).item() - return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 - - -def compute_q95(delta: np.ndarray) -> float: - return np.percentile(np.abs(delta), q=95) - - -def compute_c(delta: np.ndarray, eta: float) -> float: - return np.mean(np.abs(delta) < eta).item() - - -def get_tag(name: str, seed: int) -> str: - return f"{name}_run-{seed}" - - -def setup_logger( - level: Union[int, str] = logging.INFO, - tag: Optional[str] = None, - directory: Optional[str] = None, - rank: Optional[int] = 0, -): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels - - # Create formatters - formatter = logging.Formatter( - "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Add filter for rank - logger.addFilter(lambda _: rank == 0) - - # Create console handler - ch = logging.StreamHandler(stream=sys.stdout) - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) - - if directory is not None and tag is not None: - os.makedirs(name=directory, exist_ok=True) - - # Create file handler for non-debug logs - main_log_path = os.path.join(directory, f"{tag}.log") - fh_main = logging.FileHandler(main_log_path) - fh_main.setLevel(level) - fh_main.setFormatter(formatter) - logger.addHandler(fh_main) - - # Create file handler for debug logs - debug_log_path = os.path.join(directory, f"{tag}_debug.log") - fh_debug = logging.FileHandler(debug_log_path) - fh_debug.setLevel(logging.DEBUG) - fh_debug.setFormatter(formatter) - fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) - logger.addHandler(fh_debug) - - -class AtomicNumberTable: - def __init__(self, zs: Sequence[int]): - self.zs = zs - - def __len__(self) -> int: - return len(self.zs) - - def __str__(self): - return f"AtomicNumberTable: {tuple(s for s in self.zs)}" - - def index_to_z(self, index: int) -> int: - return self.zs[index] - - def z_to_index(self, atomic_number: str) -> int: - return self.zs.index(atomic_number) - - -def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: - z_set = set() - for z in zs: - z_set.add(z) - return AtomicNumberTable(sorted(list(z_set))) - - -def atomic_numbers_to_indices( - atomic_numbers: np.ndarray, z_table: AtomicNumberTable -) -> np.ndarray: - to_index_fn = np.vectorize(z_table.z_to_index) - return to_index_fn(atomic_numbers) - - -class UniversalEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, np.integer): - return int(o) - if isinstance(o, np.floating): - return float(o) - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, torch.Tensor): - return to_numpy(o) - return json.JSONEncoder.default(self, o) - - -class MetricsLogger: - def __init__(self, directory: str, tag: str) -> None: - self.directory = directory - 
self.filename = tag + ".txt" - self.path = os.path.join(self.directory, self.filename) - - def log(self, d: Dict[str, Any]) -> None: - os.makedirs(name=self.directory, exist_ok=True) - with open(self.path, mode="a", encoding="utf-8") as f: - f.write(json.dumps(d, cls=UniversalEncoder)) - f.write("\n") - - -# pylint: disable=abstract-method, arguments-differ -class LAMMPS_MP(torch.autograd.Function): - @staticmethod - def forward(ctx, *args): - feats, data = args # unpack - ctx.vec_len = feats.shape[-1] - ctx.data = data - out = torch.empty_like(feats) - data.forward_exchange(feats, out, ctx.vec_len) - return out - - @staticmethod - def backward(ctx, *grad_outputs): - (grad,) = grad_outputs # unpack - gout = torch.empty_like(grad) - ctx.data.reverse_exchange(grad, gout, ctx.vec_len) - return gout, None +########################################################################################### +# Statistics utilities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import json +import logging +import os +import sys +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import numpy as np +import torch + +from .torch_tools import to_numpy + + +def compute_mae(delta: np.ndarray) -> float: + return np.mean(np.abs(delta)).item() + + +def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.mean(np.abs(target_val)) + return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 + + +def compute_rmse(delta: np.ndarray) -> float: + return np.sqrt(np.mean(np.square(delta))).item() + + +def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.sqrt(np.mean(np.square(target_val))).item() + return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 + + +def compute_q95(delta: np.ndarray) -> float: + return np.percentile(np.abs(delta), q=95) + + +def compute_c(delta: np.ndarray, eta: float) -> float: + return np.mean(np.abs(delta) < eta).item() + + +def get_tag(name: str, seed: int) -> str: + return f"{name}_run-{seed}" + + +def setup_logger( + level: Union[int, str] = logging.INFO, + tag: Optional[str] = None, + directory: Optional[str] = None, + rank: Optional[int] = 0, +): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels + + # Create formatters + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Add filter for rank + logger.addFilter(lambda _: rank == 0) + + # Create console handler + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + if directory is not None and tag is not None: + os.makedirs(name=directory, exist_ok=True) + + # Create file handler for non-debug logs + main_log_path = os.path.join(directory, f"{tag}.log") + fh_main = logging.FileHandler(main_log_path) + fh_main.setLevel(level) + fh_main.setFormatter(formatter) + logger.addHandler(fh_main) + + # Create file handler for debug logs + debug_log_path = os.path.join(directory, f"{tag}_debug.log") + fh_debug = logging.FileHandler(debug_log_path) + fh_debug.setLevel(logging.DEBUG) + fh_debug.setFormatter(formatter) + fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) + logger.addHandler(fh_debug) + + +class AtomicNumberTable: + def 
__init__(self, zs: Sequence[int]): + self.zs = zs + + def __len__(self) -> int: + return len(self.zs) + + def __str__(self): + return f"AtomicNumberTable: {tuple(s for s in self.zs)}" + + def index_to_z(self, index: int) -> int: + return self.zs[index] + + def z_to_index(self, atomic_number: str) -> int: + return self.zs.index(atomic_number) + + +def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: + z_set = set() + for z in zs: + z_set.add(z) + return AtomicNumberTable(sorted(list(z_set))) + + +def atomic_numbers_to_indices( + atomic_numbers: np.ndarray, z_table: AtomicNumberTable +) -> np.ndarray: + to_index_fn = np.vectorize(z_table.z_to_index) + return to_index_fn(atomic_numbers) + + +class UniversalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, torch.Tensor): + return to_numpy(o) + return json.JSONEncoder.default(self, o) + + +class MetricsLogger: + def __init__(self, directory: str, tag: str) -> None: + self.directory = directory + self.filename = tag + ".txt" + self.path = os.path.join(self.directory, self.filename) + + def log(self, d: Dict[str, Any]) -> None: + os.makedirs(name=self.directory, exist_ok=True) + with open(self.path, mode="a", encoding="utf-8") as f: + f.write(json.dumps(d, cls=UniversalEncoder)) + f.write("\n") + + +# pylint: disable=abstract-method, arguments-differ +class LAMMPS_MP(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + feats, data = args # unpack + ctx.vec_len = feats.shape[-1] + ctx.data = data + out = torch.empty_like(feats) + data.forward_exchange(feats, out, ctx.vec_len) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + (grad,) = grad_outputs # unpack + gout = torch.empty_like(grad) + ctx.data.reverse_exchange(grad, gout, ctx.vec_len) + return gout, None diff --git a/mace-bench/3rdparty/mace/scripts/eval_configs.py b/mace-bench/3rdparty/mace/scripts/eval_configs.py index 350d804363823cf38f5920d99999fd054eafff89..d2f4e217bc50c0eb411f7809c539d6a79f88a80b 100644 --- a/mace-bench/3rdparty/mace/scripts/eval_configs.py +++ b/mace-bench/3rdparty/mace/scripts/eval_configs.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.eval_configs.main ## - -from mace.cli.eval_configs import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.eval_configs.main ## + +from mace.cli.eval_configs import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/scripts/preprocess_data.py b/mace-bench/3rdparty/mace/scripts/preprocess_data.py index be11345d1fd0ed5f0e4df3370fdb208c4091d4f6..3c2c288c78424f2ca198a46e785a9769ce7fab52 100644 --- a/mace-bench/3rdparty/mace/scripts/preprocess_data.py +++ b/mace-bench/3rdparty/mace/scripts/preprocess_data.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.run_train.main ## - -from mace.cli.preprocess_data import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.preprocess_data import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/scripts/run_checks.sh b/mace-bench/3rdparty/mace/scripts/run_checks.sh index e2e073bd98c56ba2f99caaccb3fbccee6ca34f52..bd1214a403cbc6f9af62d108dd4e16f8c83b5107 100644 --- a/mace-bench/3rdparty/mace/scripts/run_checks.sh +++ b/mace-bench/3rdparty/mace/scripts/run_checks.sh @@ -1,9 +1,9 @@ -# Format -python -m black . -python -m isort . 
- -# Check -python -m pylint --rcfile=pyproject.toml mace tests scripts - -# Tests -python -m pytest tests +# Format +python -m black . +python -m isort . + +# Check +python -m pylint --rcfile=pyproject.toml mace tests scripts + +# Tests +python -m pytest tests diff --git a/mace-bench/3rdparty/mace/scripts/run_train.py b/mace-bench/3rdparty/mace/scripts/run_train.py index 77d53b0d3d256763b323f69450fafa82285768a7..d14952db45aa264ae640e2dde5323e48159c752d 100644 --- a/mace-bench/3rdparty/mace/scripts/run_train.py +++ b/mace-bench/3rdparty/mace/scripts/run_train.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.run_train.main ## - -from mace.cli.run_train import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.run_train import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/tests/__init__.py b/mace-bench/3rdparty/mace/tests/__init__.py index 9ff3a03c70a38a8e4cab498ee316095ce84f5d4f..96ae7770bf014e5be3ba779f49eda5e9b8304e3d 100644 --- a/mace-bench/3rdparty/mace/tests/__init__.py +++ b/mace-bench/3rdparty/mace/tests/__init__.py @@ -1,3 +1,3 @@ -import os - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" +import os + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" diff --git a/mace-bench/3rdparty/mace/tests/modules/test_radial.py b/mace-bench/3rdparty/mace/tests/modules/test_radial.py index 402dc46e867b813d3791883310630f1aa19773c3..3aef254bcdfa3ffa1242da3a94d94c864f8fe50c 100644 --- a/mace-bench/3rdparty/mace/tests/modules/test_radial.py +++ b/mace-bench/3rdparty/mace/tests/modules/test_radial.py @@ -1,95 +1,95 @@ -import pytest -import torch - -from mace.modules.radial import AgnesiTransform, ZBLBasis - - -@pytest.fixture -def zbl_basis(): - return ZBLBasis(p=6, trainable=False) - - -def test_zbl_basis_initialization(zbl_basis): - assert zbl_basis.p == torch.tensor(6.0) - assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) - - assert zbl_basis.a_exp == torch.tensor(0.300) - assert zbl_basis.a_prefactor == torch.tensor(0.4543) - assert not zbl_basis.a_exp.requires_grad - assert not zbl_basis.a_prefactor.requires_grad - - -def test_trainable_zbl_basis_initialization(zbl_basis): - zbl_basis = ZBLBasis(p=6, trainable=True) - assert zbl_basis.p == torch.tensor(6.0) - assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) - - assert zbl_basis.a_exp == torch.tensor(0.300) - assert zbl_basis.a_prefactor == torch.tensor(0.4543) - assert zbl_basis.a_exp.requires_grad - assert zbl_basis.a_prefactor.requires_grad - - -def test_forward(zbl_basis): - x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] - node_attrs = torch.tensor( - [[1, 0], [0, 1]] - ) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers - edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] - atomic_numbers = torch.tensor([1, 6]) # [n_nodes] - output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) - - assert output.shape == torch.Size([node_attrs.shape[0]]) - assert torch.is_tensor(output) - assert torch.allclose( - output, - torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), - rtol=1e-2, - ) - - -@pytest.fixture -def agnesi(): - return AgnesiTransform(trainable=False) - - -def test_agnesi_transform_initialization(agnesi: AgnesiTransform): - assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) - assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) - assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) - assert not 
agnesi.a.requires_grad - assert not agnesi.q.requires_grad - assert not agnesi.p.requires_grad - - -def test_trainable_agnesi_transform_initialization(): - agnesi = AgnesiTransform(trainable=True) - - assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) - assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) - assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) - assert agnesi.a.requires_grad - assert agnesi.q.requires_grad - assert agnesi.p.requires_grad - - -def test_agnesi_transform_forward(): - agnesi = AgnesiTransform() - x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) - node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) - edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) - atomic_numbers = torch.tensor([1, 6, 8]) - output = agnesi(x, node_attrs, edge_index, atomic_numbers) - assert output.shape == x.shape - assert torch.is_tensor(output) - assert torch.allclose( - output, - torch.tensor( - [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() - ).unsqueeze(-1), - rtol=1e-2, - ) - - -if __name__ == "__main__": - pytest.main([__file__]) +import pytest +import torch + +from mace.modules.radial import AgnesiTransform, ZBLBasis + + +@pytest.fixture +def zbl_basis(): + return ZBLBasis(p=6, trainable=False) + + +def test_zbl_basis_initialization(zbl_basis): + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert not zbl_basis.a_exp.requires_grad + assert not zbl_basis.a_prefactor.requires_grad + + +def test_trainable_zbl_basis_initialization(zbl_basis): + zbl_basis = ZBLBasis(p=6, trainable=True) + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert zbl_basis.a_exp.requires_grad + assert zbl_basis.a_prefactor.requires_grad + + +def test_forward(zbl_basis): + x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] + node_attrs = torch.tensor( + [[1, 0], [0, 1]] + ) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers + edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] + atomic_numbers = torch.tensor([1, 6]) # [n_nodes] + output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) + + assert output.shape == torch.Size([node_attrs.shape[0]]) + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), + rtol=1e-2, + ) + + +@pytest.fixture +def agnesi(): + return AgnesiTransform(trainable=False) + + +def test_agnesi_transform_initialization(agnesi: AgnesiTransform): + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert not agnesi.a.requires_grad + assert not agnesi.q.requires_grad + assert not agnesi.p.requires_grad + + +def test_trainable_agnesi_transform_initialization(): + agnesi = AgnesiTransform(trainable=True) + + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert agnesi.a.requires_grad + assert agnesi.q.requires_grad + assert agnesi.p.requires_grad + + +def 
test_agnesi_transform_forward(): + agnesi = AgnesiTransform() + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) + node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) + edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + atomic_numbers = torch.tensor([1, 6, 8]) + output = agnesi(x, node_attrs, edge_index, atomic_numbers) + assert output.shape == x.shape + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor( + [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() + ).unsqueeze(-1), + rtol=1e-2, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/mace-bench/3rdparty/mace/tests/test_benchmark.py b/mace-bench/3rdparty/mace/tests/test_benchmark.py index 6e5f11c3f2b22d09a3de71ae0229f32656515691..fae104b41c2fae7f105a1cb1940939f53b240662 100644 --- a/mace-bench/3rdparty/mace/tests/test_benchmark.py +++ b/mace-bench/3rdparty/mace/tests/test_benchmark.py @@ -1,121 +1,121 @@ -import json -import os -from pathlib import Path -from typing import List, Optional - -import pandas as pd -import pytest -import torch -from ase import build - -from mace import data as mace_data -from mace.calculators.foundations_models import mace_mp -from mace.tools import AtomicNumberTable, torch_geometric, torch_tools - - -def is_mace_full_bench(): - return os.environ.get("MACE_FULL_BENCH", "0") == "1" - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8) -@pytest.mark.parametrize("size", (3, 5, 7, 9)) -@pytest.mark.parametrize("dtype", ["float32", "float64"]) -@pytest.mark.parametrize("compile_mode", [None, "default"]) -def test_inference( - benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda" -): - if not is_mace_full_bench() and compile_mode is not None: - pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute") - - with torch_tools.default_dtype(dtype): - model = load_mace_mp_medium(dtype, compile_mode, device) - batch = create_batch(size, model, device) - log_bench_info(benchmark, dtype, compile_mode, batch) - - def func(): - torch.cuda.synchronize() - model(batch, training=compile_mode is not None, compute_force=True) - - torch.cuda.empty_cache() - benchmark(func) - - -def load_mace_mp_medium(dtype, compile_mode, device): - calc = mace_mp( - model="medium", - default_dtype=dtype, - device=device, - compile_mode=compile_mode, - fullgraph=False, - ) - model = calc.models[0].to(device) - return model - - -def create_batch(size: int, model: torch.nn.Module, device: str) -> dict: - cutoff = model.r_max.item() - z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms = atoms.repeat((size, size, size)) - config = mace_data.config_from_atoms(atoms) - dataset = [mace_data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=dataset, - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch.to(device) - return batch.to_dict() - - -def log_bench_info(benchmark, dtype, compile_mode, batch): - benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0]) - benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) - benchmark.extra_info["dtype"] = dtype - benchmark.extra_info["is_compiled"] = compile_mode is not None - 
benchmark.extra_info["device_name"] = torch.cuda.get_device_name() - - -def process_benchmark_file(bench_file: Path) -> pd.DataFrame: - with open(bench_file, "r", encoding="utf-8") as f: - bench_data = json.load(f) - - records = [] - for bench in bench_data["benchmarks"]: - record = {**bench["extra_info"], **bench["stats"]} - records.append(record) - - result_df = pd.DataFrame(records) - result_df["ns/day (1 fs/step)"] = 0.086400 / result_df["median"] - result_df["Steps per day"] = result_df["ops"] * 86400 - columns = [ - "num_atoms", - "num_edges", - "dtype", - "is_compiled", - "device_name", - "median", - "Steps per day", - "ns/day (1 fs/step)", - ] - return result_df[columns] - - -def read_bench_results(result_files: List[str]) -> pd.DataFrame: - return pd.concat([process_benchmark_file(Path(f)) for f in result_files]) - - -if __name__ == "__main__": - # Print to stdout a csv of the benchmark metrics - import subprocess - - result = subprocess.run( - ["pytest-benchmark", "list"], capture_output=True, text=True, check=True - ) - - bench_files = result.stdout.strip().split("\n") - bench_results = read_bench_results(bench_files) - print(bench_results.to_csv(index=False)) +import json +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd +import pytest +import torch +from ase import build + +from mace import data as mace_data +from mace.calculators.foundations_models import mace_mp +from mace.tools import AtomicNumberTable, torch_geometric, torch_tools + + +def is_mace_full_bench(): + return os.environ.get("MACE_FULL_BENCH", "0") == "1" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8) +@pytest.mark.parametrize("size", (3, 5, 7, 9)) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize("compile_mode", [None, "default"]) +def test_inference( + benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda" +): + if not is_mace_full_bench() and compile_mode is not None: + pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute") + + with torch_tools.default_dtype(dtype): + model = load_mace_mp_medium(dtype, compile_mode, device) + batch = create_batch(size, model, device) + log_bench_info(benchmark, dtype, compile_mode, batch) + + def func(): + torch.cuda.synchronize() + model(batch, training=compile_mode is not None, compute_force=True) + + torch.cuda.empty_cache() + benchmark(func) + + +def load_mace_mp_medium(dtype, compile_mode, device): + calc = mace_mp( + model="medium", + default_dtype=dtype, + device=device, + compile_mode=compile_mode, + fullgraph=False, + ) + model = calc.models[0].to(device) + return model + + +def create_batch(size: int, model: torch.nn.Module, device: str) -> dict: + cutoff = model.r_max.item() + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms = atoms.repeat((size, size, size)) + config = mace_data.config_from_atoms(atoms) + dataset = [mace_data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch.to(device) + return batch.to_dict() + + +def log_bench_info(benchmark, dtype, compile_mode, batch): + benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0]) + 
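+    # edge_index follows the PyG convention of shape [2, n_edges], so its
+    # second dimension gives the number of edges in the batch.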
benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) + benchmark.extra_info["dtype"] = dtype + benchmark.extra_info["is_compiled"] = compile_mode is not None + benchmark.extra_info["device_name"] = torch.cuda.get_device_name() + + +def process_benchmark_file(bench_file: Path) -> pd.DataFrame: + with open(bench_file, "r", encoding="utf-8") as f: + bench_data = json.load(f) + + records = [] + for bench in bench_data["benchmarks"]: + record = {**bench["extra_info"], **bench["stats"]} + records.append(record) + + result_df = pd.DataFrame(records) + result_df["ns/day (1 fs/step)"] = 0.086400 / result_df["median"] + result_df["Steps per day"] = result_df["ops"] * 86400 + columns = [ + "num_atoms", + "num_edges", + "dtype", + "is_compiled", + "device_name", + "median", + "Steps per day", + "ns/day (1 fs/step)", + ] + return result_df[columns] + + +def read_bench_results(result_files: List[str]) -> pd.DataFrame: + return pd.concat([process_benchmark_file(Path(f)) for f in result_files]) + + +if __name__ == "__main__": + # Print to stdout a csv of the benchmark metrics + import subprocess + + result = subprocess.run( + ["pytest-benchmark", "list"], capture_output=True, text=True, check=True + ) + + bench_files = result.stdout.strip().split("\n") + bench_results = read_bench_results(bench_files) + print(bench_results.to_csv(index=False)) diff --git a/mace-bench/3rdparty/mace/tests/test_calculator.py b/mace-bench/3rdparty/mace/tests/test_calculator.py index 8491538ac4cf296fa99d5298fc7dbc2589731d20..e9a654613010e3681d23e1b70797ff2c90d38ae6 100644 --- a/mace-bench/3rdparty/mace/tests/test_calculator.py +++ b/mace-bench/3rdparty/mace/tests/test_calculator.py @@ -1,689 +1,689 @@ -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import torch -from ase import build -from ase.atoms import Atoms -from ase.calculators.test import gradient_test -from ase.constraints import ExpCellFilter - -from mace.calculators import mace_mp, mace_off -from mace.calculators.mace import MACECalculator -from mace.modules.models import ScaleShiftMACE - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -pytest_mace_dir = Path(__file__).parent.parent -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -@pytest.fixture(scope="module", name="fitting_configs") -def fitting_configs_fixture(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - fit_configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - fit_configs[0].info["REF_energy"] = 1.0 - fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = -0.5 - fit_configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(20): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - c.info["REF_dipole"] = np.random.normal(0.1, size=3) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.new_array("Qs", np.random.normal(0.1, size=c.positions.shape[0])) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - fit_configs.append(c) - - return fit_configs - - -@pytest.fixture(scope="module", name="trained_model") -def trained_model_fixture(tmp_path_factory, fitting_configs): - 
_mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - -@pytest.fixture(scope="module", name="trained_equivariant_model") -def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e+16x1o", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - -@pytest.fixture(scope="module", name="trained_equivariant_model_cueq") -def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e+16x1o", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": 
"stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True - ) - - -@pytest.fixture(scope="module", name="trained_dipole_model") -def trained_dipole_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "AtomicDipolesMACE", - "num_channels": 8, - "max_L": 2, - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "dipole", - "energy_key": "", - "forces_key": "", - "stress_key": "", - "dipole_key": "REF_dipole", - "error_table": "DipoleRMSE", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" - ) - - -@pytest.fixture(scope="module", name="trained_energy_dipole_model") -def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "EnergyDipolesMACE", - "num_channels": 32, - "max_L": 1, - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "energy_forces_dipole", - "energy_key": "REF_energy", - "forces_key": "", - "stress_key": "", - "dipole_key": "REF_dipole", - "error_table": "EnergyDipoleRMSE", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = 
str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" - ) - - -@pytest.fixture(scope="module", name="trained_committee") -def trained_committee_fixture(tmp_path_factory, fitting_configs): - _seeds = [5, 6, 7] - _model_paths = [] - for seed in _seeds: - _mace_params = { - "name": f"MACE{seed}", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": seed, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp(f"run{seed}_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - _model_paths.append(tmp_path / f"MACE{seed}.model") - - return MACECalculator(model_paths=_model_paths, device="cpu") - - -def test_calculator_node_energy(fitting_configs, trained_model): - for at in fitting_configs: - trained_model.calculate(at) - node_energies = trained_model.results["node_energy"] - batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access - node_heads = batch["head"][batch["batch"]] - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_e0 = ( - trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() - ) - node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() - energy_via_nodes = np.sum(node_energies + node_e0) - energy = trained_model.results["energy"] - np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) - - -def test_calculator_forces(fitting_configs, trained_model): - at = fitting_configs[2].copy() - at.calc = trained_model - - # test just forces - grads = gradient_test(at) - - assert np.allclose(grads[0], grads[1]) - - -def test_calculator_stress(fitting_configs, trained_model): - at = fitting_configs[2].copy() - at.calc = trained_model - - # test forces and stress - at_wrapped = ExpCellFilter(at) - grads = 
gradient_test(at_wrapped) - - assert np.allclose(grads[0], grads[1]) - - -def test_calculator_committee(fitting_configs, trained_committee): - at = fitting_configs[2].copy() - at.calc = trained_committee - - # test just forces - grads = gradient_test(at) - - assert np.allclose(grads[0], grads[1]) - - E = at.get_potential_energy() - energies = at.calc.results["energies"] - energies_var = at.calc.results["energy_var"] - forces_var = np.var(at.calc.results["forces_comm"], axis=0) - assert np.allclose(E, np.mean(energies)) - assert np.allclose(energies_var, np.var(energies)) - assert forces_var.shape == at.calc.results["forces"].shape - - -def test_calculator_from_model(fitting_configs, trained_committee): - # test single model - test_calculator_forces( - fitting_configs, - trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), - ) - - # test committee model - test_calculator_committee( - fitting_configs, - trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), - ) - - -def test_calculator_dipole(fitting_configs, trained_dipole_model): - at = fitting_configs[2].copy() - at.calc = trained_dipole_model - - dip = at.get_dipole_moment() - - assert len(dip) == 3 - - -def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model): - at = fitting_configs[2].copy() - at.calc = trained_energy_dipole_model - - grads = gradient_test(at) - dip = at.get_dipole_moment() - - assert np.allclose(grads[0], grads[1]) - assert len(dip) == 3 - - -def test_calculator_descriptor(fitting_configs, trained_equivariant_model): - at = fitting_configs[2].copy() - at_rotated = fitting_configs[2].copy() - at_rotated.rotate(90, "x") - calc = trained_equivariant_model - - desc_invariant = calc.get_descriptors(at, invariants_only=True) - desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_invariant_single_layer = calc.get_descriptors( - at, invariants_only=True, num_layers=1 - ) - desc_invariant_single_layer_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, num_layers=1 - ) - desc = calc.get_descriptors(at, invariants_only=False) - desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) - desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) - desc_rotated_single_layer = calc.get_descriptors( - at_rotated, invariants_only=False, num_layers=1 - ) - - assert desc_invariant.shape[0] == 3 - assert desc_invariant.shape[1] == 32 - assert desc_invariant_single_layer.shape[0] == 3 - assert desc_invariant_single_layer.shape[1] == 16 - assert desc.shape[0] == 3 - assert desc.shape[1] == 80 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 * 4 - assert desc_rotated_single_layer.shape[0] == 3 - assert desc_rotated_single_layer.shape[1] == 16 * 4 - - np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose( - desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 - ) - assert not np.allclose( - desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 - ) - assert not np.allclose(desc, desc_rotated, atol=1e-6) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_calculator_descriptor_cueq(fitting_configs, 
trained_equivariant_model_cueq): - at = fitting_configs[2].copy() - at_rotated = fitting_configs[2].copy() - at_rotated.rotate(90, "x") - calc = trained_equivariant_model_cueq - - desc_invariant = calc.get_descriptors(at, invariants_only=True) - desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_invariant_single_layer = calc.get_descriptors( - at, invariants_only=True, num_layers=1 - ) - desc_invariant_single_layer_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, num_layers=1 - ) - desc = calc.get_descriptors(at, invariants_only=False) - desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) - desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) - desc_rotated_single_layer = calc.get_descriptors( - at_rotated, invariants_only=False, num_layers=1 - ) - - assert desc_invariant.shape[0] == 3 - assert desc_invariant.shape[1] == 32 - assert desc_invariant_single_layer.shape[0] == 3 - assert desc_invariant_single_layer.shape[1] == 16 - assert desc.shape[0] == 3 - assert desc.shape[1] == 80 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 * 4 - assert desc_rotated_single_layer.shape[0] == 3 - assert desc_rotated_single_layer.shape[1] == 16 * 4 - - np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose( - desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 - ) - assert not np.allclose( - desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 - ) - assert not np.allclose(desc, desc_rotated, atol=1e-6) - - -def test_mace_mp(capsys: pytest.CaptureFixture): - mp_mace = mace_mp() - assert isinstance(mp_mace, MACECalculator) - assert mp_mace.model_type == "MACE" - assert len(mp_mace.models) == 1 - assert isinstance(mp_mace.models[0], ScaleShiftMACE) - - _, stderr = capsys.readouterr() - assert stderr == "" - - -def test_mace_off(): - mace_off_model = mace_off(model="small", device="cpu") - assert isinstance(mace_off_model, MACECalculator) - assert mace_off_model.model_type == "MACE" - assert len(mace_off_model.models) == 1 - assert isinstance(mace_off_model.models[0], ScaleShiftMACE) - - atoms = build.molecule("H2O") - atoms.calc = mace_off_model - - E = atoms.get_potential_energy() - - assert np.allclose(E, -2081.116128586803, atol=1e-9) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_mace_off_cueq(model="medium", device="cpu"): - mace_off_model = mace_off(model=model, device=device, enable_cueq=True) - assert isinstance(mace_off_model, MACECalculator) - assert mace_off_model.model_type == "MACE" - assert len(mace_off_model.models) == 1 - assert isinstance(mace_off_model.models[0], ScaleShiftMACE) - - atoms = build.molecule("H2O") - atoms.calc = mace_off_model - - E = atoms.get_potential_energy() - - assert np.allclose(E, -2081.116128586803, atol=1e-9) - - -def test_mace_mp_stresses(model="medium", device="cpu"): - atoms = build.bulk("Al", "fcc", a=4.05, cubic=True) - atoms = atoms.repeat((2, 2, 2)) - mace_mp_model = mace_mp(model=model, device=device, compute_atomic_stresses=True) - atoms.set_calculator(mace_mp_model) - stress = atoms.get_stress() - stresses = atoms.get_stresses() - assert stress.shape == (6,) - assert stresses.shape 
== (32, 6) - assert np.allclose(stress, stresses.sum(axis=0), atol=1e-6) +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase import build +from ase.atoms import Atoms +from ase.calculators.test import gradient_test +from ase.constraints import ExpCellFilter + +from mace.calculators import mace_mp, mace_off +from mace.calculators.mace import MACECalculator +from mace.modules.models import ScaleShiftMACE + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +pytest_mace_dir = Path(__file__).parent.parent +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(scope="module", name="fitting_configs") +def fitting_configs_fixture(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 1.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = -0.5 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.info["REF_dipole"] = np.random.normal(0.1, size=3) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.new_array("Qs", np.random.normal(0.1, size=c.positions.shape[0])) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(scope="module", name="trained_model") +def trained_model_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_equivariant_model") +def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + 
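+        # NOTE (reviewer comment): each entry of this _mace_params dict becomes a
+        # run_train.py CLI flag when the command is assembled below; a value of None
+        # emits a bare switch. Illustrative mapping (hypothetical values):
+        #     {"seed": 5, "swa": None}  ->  "--seed=5 --swa"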
"energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_equivariant_model_cueq") +def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True + ) + + +@pytest.fixture(scope="module", name="trained_dipole_model") +def trained_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "AtomicDipolesMACE", + "num_channels": 8, + "max_L": 2, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "dipole", + "energy_key": "", + "forces_key": "", + "stress_key": 
"", + "dipole_key": "REF_dipole", + "error_table": "DipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_energy_dipole_model") +def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "EnergyDipolesMACE", + "num_channels": 32, + "max_L": 1, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "energy_forces_dipole", + "energy_key": "REF_energy", + "forces_key": "", + "stress_key": "", + "dipole_key": "REF_dipole", + "error_table": "EnergyDipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_committee") +def trained_committee_fixture(tmp_path_factory, fitting_configs): + _seeds = [5, 6, 7] + _model_paths = [] + for seed in _seeds: + _mace_params = { + "name": f"MACE{seed}", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": seed, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp(f"run{seed}_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + 
mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + _model_paths.append(tmp_path / f"MACE{seed}.model") + + return MACECalculator(model_paths=_model_paths, device="cpu") + + +def test_calculator_node_energy(fitting_configs, trained_model): + for at in fitting_configs: + trained_model.calculate(at) + node_energies = trained_model.results["node_energy"] + batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = ( + trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() + ) + node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() + energy_via_nodes = np.sum(node_energies + node_e0) + energy = trained_model.results["energy"] + np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) + + +def test_calculator_forces(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_stress(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test forces and stress + at_wrapped = ExpCellFilter(at) + grads = gradient_test(at_wrapped) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_committee(fitting_configs, trained_committee): + at = fitting_configs[2].copy() + at.calc = trained_committee + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + E = at.get_potential_energy() + energies = at.calc.results["energies"] + energies_var = at.calc.results["energy_var"] + forces_var = np.var(at.calc.results["forces_comm"], axis=0) + assert np.allclose(E, np.mean(energies)) + assert np.allclose(energies_var, np.var(energies)) + assert forces_var.shape == at.calc.results["forces"].shape + + +def test_calculator_from_model(fitting_configs, trained_committee): + # test single model + test_calculator_forces( + fitting_configs, + trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), + ) + + # test committee model + test_calculator_committee( + fitting_configs, + trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), + ) + + +def test_calculator_dipole(fitting_configs, trained_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_dipole_model + + dip = at.get_dipole_moment() + + assert len(dip) == 3 + + +def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_energy_dipole_model + + grads = gradient_test(at) + dip = at.get_dipole_moment() + + assert np.allclose(grads[0], grads[1]) + assert len(dip) == 3 + + +def test_calculator_descriptor(fitting_configs, trained_equivariant_model): + at = fitting_configs[2].copy() + at_rotated = 
fitting_configs[2].copy() + at_rotated.rotate(90, "x") + calc = trained_equivariant_model + + desc_invariant = calc.get_descriptors(at, invariants_only=True) + desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, num_layers=1 + ) + desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) + desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 + + np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) + assert not np.allclose(desc, desc_rotated, atol=1e-6) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq): + at = fitting_configs[2].copy() + at_rotated = fitting_configs[2].copy() + at_rotated.rotate(90, "x") + calc = trained_equivariant_model_cueq + + desc_invariant = calc.get_descriptors(at, invariants_only=True) + desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, num_layers=1 + ) + desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) + desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 + + np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_single_layer[:, :16], 
desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) + assert not np.allclose(desc, desc_rotated, atol=1e-6) + + +def test_mace_mp(capsys: pytest.CaptureFixture): + mp_mace = mace_mp() + assert isinstance(mp_mace, MACECalculator) + assert mp_mace.model_type == "MACE" + assert len(mp_mace.models) == 1 + assert isinstance(mp_mace.models[0], ScaleShiftMACE) + + _, stderr = capsys.readouterr() + assert stderr == "" + + +def test_mace_off(): + mace_off_model = mace_off(model="small", device="cpu") + assert isinstance(mace_off_model, MACECalculator) + assert mace_off_model.model_type == "MACE" + assert len(mace_off_model.models) == 1 + assert isinstance(mace_off_model.models[0], ScaleShiftMACE) + + atoms = build.molecule("H2O") + atoms.calc = mace_off_model + + E = atoms.get_potential_energy() + + assert np.allclose(E, -2081.116128586803, atol=1e-9) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_mace_off_cueq(model="medium", device="cpu"): + mace_off_model = mace_off(model=model, device=device, enable_cueq=True) + assert isinstance(mace_off_model, MACECalculator) + assert mace_off_model.model_type == "MACE" + assert len(mace_off_model.models) == 1 + assert isinstance(mace_off_model.models[0], ScaleShiftMACE) + + atoms = build.molecule("H2O") + atoms.calc = mace_off_model + + E = atoms.get_potential_energy() + + assert np.allclose(E, -2081.116128586803, atol=1e-9) + + +def test_mace_mp_stresses(model="medium", device="cpu"): + atoms = build.bulk("Al", "fcc", a=4.05, cubic=True) + atoms = atoms.repeat((2, 2, 2)) + mace_mp_model = mace_mp(model=model, device=device, compute_atomic_stresses=True) + atoms.set_calculator(mace_mp_model) + stress = atoms.get_stress() + stresses = atoms.get_stresses() + assert stress.shape == (6,) + assert stresses.shape == (32, 6) + assert np.allclose(stress, stresses.sum(axis=0), atol=1e-6) diff --git a/mace-bench/3rdparty/mace/tests/test_cg.py b/mace-bench/3rdparty/mace/tests/test_cg.py index c6465fc1691307c2274a22b3f529be7ae2df830a..36b119b9b7c84621f3521a9978bebfa081262db1 100644 --- a/mace-bench/3rdparty/mace/tests/test_cg.py +++ b/mace-bench/3rdparty/mace/tests/test_cg.py @@ -1,12 +1,12 @@ -from e3nn import o3 - -from mace.tools import cg - - -def test_U_matrix(): - irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") - irreps_out = o3.Irreps("1x0e + 1x1o") - u_matrix = cg.U_matrix_real( - irreps_in=irreps_in, irreps_out=irreps_out, correlation=3 - )[-1] - assert u_matrix.shape == (3, 9, 9, 9, 21) +from e3nn import o3 + +from mace.tools import cg + + +def test_U_matrix(): + irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") + irreps_out = o3.Irreps("1x0e + 1x1o") + u_matrix = cg.U_matrix_real( + irreps_in=irreps_in, irreps_out=irreps_out, correlation=3 + )[-1] + assert u_matrix.shape == (3, 9, 9, 9, 21) diff --git a/mace-bench/3rdparty/mace/tests/test_compile.py b/mace-bench/3rdparty/mace/tests/test_compile.py index 986944127fe0805d80e4ac9ebbb41d1d9d956985..d7d585e832c918345a29ea97b1f890fc208ba589 100644 --- a/mace-bench/3rdparty/mace/tests/test_compile.py +++ b/mace-bench/3rdparty/mace/tests/test_compile.py @@ -1,154 +1,154 @@ -import os -from functools import wraps -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 -from torch.testing import assert_close - -from mace import data, modules, tools -from mace.tools import compile as mace_compile -from 
mace.tools import torch_geometric - -table = tools.AtomicNumberTable([6]) -atomic_energies = np.array([1.0], dtype=float) -cutoff = 5.0 - - -def create_mace(device: str, seed: int = 1702): - torch_geometric.seed_everything(seed) - - model_config = { - "r_max": cutoff, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": o3.Irreps("128x0e + 128x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": atomic_energies, - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - "atomic_inter_scale": 1.0, - "atomic_inter_shift": 0.0, - } - model = modules.ScaleShiftMACE(**model_config) - return model.to(device) - - -def create_batch(device: str): - from ase import build - - size = 2 - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms_list = [atoms.repeat((size, size, size))] - print("Number of atoms", len(atoms_list[0])) - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) - for config in configs - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch = batch.to(device) - batch = batch.to_dict() - return batch - - -def time_func(func: Callable): - @wraps(func) - def wrapper(*args, **kwargs): - torch._inductor.cudagraph_mark_step_begin() # pylint: disable=W0212 - outputs = func(*args, **kwargs) - torch.cuda.synchronize() - return outputs - - return wrapper - - -@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"]) -def default_dtype(request): - with tools.torch_tools.default_dtype(request.param): - yield torch.get_default_dtype() - - -# skip if on windows -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_mace(device, default_dtype): # pylint: disable=W0621 - print(f"using default dtype = {default_dtype}") - if device == "cuda" and not torch.cuda.is_available(): - pytest.skip(reason="cuda is not available") - - model_defaults = create_mace(device) - tmp_model = mace_compile.prepare(create_mace)(device) - model_compiled = torch.compile(tmp_model, mode="default") - - batch = create_batch(device) - output1 = model_defaults(batch, training=True) - output2 = model_compiled(batch, training=True) - assert_close(output1["energy"], output2["energy"]) - assert_close(output1["forces"], output2["forces"]) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 - print(f"using default dtype = {default_dtype}") - batch = create_batch("cuda") - model = create_mace("cuda") - model = time_func(model) - benchmark(model, batch, training=True) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) -@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) -def 
test_compile_benchmark(benchmark, compile_mode, enable_amp): - if enable_amp: - pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") - - with tools.torch_tools.default_dtype(torch.float32): - batch = create_batch("cuda") - torch.compiler.reset() - model = mace_compile.prepare(create_mace)("cuda") - model = torch.compile(model, mode=compile_mode) - model = time_func(model) - - with torch.autocast("cuda", enabled=enable_amp): - benchmark(model, batch, training=True) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_graph_breaks(): - import torch._dynamo as dynamo - - batch = create_batch("cuda") - model = mace_compile.prepare(create_mace)("cuda") - explanation = dynamo.explain(model)(batch, training=False) - - # these clutter the output but might be useful for investigating graph breaks - explanation.ops_per_graph = None - explanation.out_guards = None - print(explanation) - assert explanation.graph_break_count == 0 +import os +from functools import wraps +from typing import Callable + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 +from torch.testing import assert_close + +from mace import data, modules, tools +from mace.tools import compile as mace_compile +from mace.tools import torch_geometric + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("128x0e + 128x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, + } + model = modules.ScaleShiftMACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def time_func(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + torch._inductor.cudagraph_mark_step_begin() # pylint: disable=W0212 + outputs = func(*args, **kwargs) + torch.cuda.synchronize() + return outputs + + return wrapper + + +@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"]) +def default_dtype(request): + with tools.torch_tools.default_dtype(request.param): + yield torch.get_default_dtype() + + +# skip if on windows +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.parametrize("device", ["cpu", 
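+# NOTE (reviewer comment): time_func above exists so the benchmarks measure completed
+# GPU work: cudagraph_mark_step_begin() keeps torch.compile's "reduce-overhead" CUDA
+# graphs replay-safe across calls, and torch.cuda.synchronize() blocks until the async
+# kernels finish, so the recorded wall time is not just the kernel-launch overhead.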
"cuda"]) +def test_mace(device, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip(reason="cuda is not available") + + model_defaults = create_mace(device) + tmp_model = mace_compile.prepare(create_mace)(device) + model_compiled = torch.compile(tmp_model, mode="default") + + batch = create_batch(device) + output1 = model_defaults(batch, training=True) + output2 = model_compiled(batch, training=True) + assert_close(output1["energy"], output2["energy"]) + assert_close(output1["forces"], output2["forces"]) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + batch = create_batch("cuda") + model = create_mace("cuda") + model = time_func(model) + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) +@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) +def test_compile_benchmark(benchmark, compile_mode, enable_amp): + if enable_amp: + pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") + + with tools.torch_tools.default_dtype(torch.float32): + batch = create_batch("cuda") + torch.compiler.reset() + model = mace_compile.prepare(create_mace)("cuda") + model = torch.compile(model, mode=compile_mode) + model = time_func(model) + + with torch.autocast("cuda", enabled=enable_amp): + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_graph_breaks(): + import torch._dynamo as dynamo + + batch = create_batch("cuda") + model = mace_compile.prepare(create_mace)("cuda") + explanation = dynamo.explain(model)(batch, training=False) + + # these clutter the output but might be useful for investigating graph breaks + explanation.ops_per_graph = None + explanation.out_guards = None + print(explanation) + assert explanation.graph_break_count == 0 diff --git a/mace-bench/3rdparty/mace/tests/test_cueq.py b/mace-bench/3rdparty/mace/tests/test_cueq.py index d76b25f4e29ed934943e7236baa83ffdafb1983b..480a8170e3c5935fffbe781f600cf197b50a60ef 100644 --- a/mace-bench/3rdparty/mace/tests/test_cueq.py +++ b/mace-bench/3rdparty/mace/tests/test_cueq.py @@ -1,181 +1,181 @@ -# pylint: disable=wrong-import-position -import os -from copy import deepcopy -from typing import Any, Dict - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 - -from mace import data, modules, tools -from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.tools import torch_geometric - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -CUDA_AVAILABLE = torch.cuda.is_available() - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -class TestCueq: - @pytest.fixture - def model_config(self, 
interaction_cls_first, hidden_irreps) -> Dict[str, Any]: - table = tools.AtomicNumberTable([6]) - return { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": interaction_cls_first, - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": hidden_irreps, - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": torch.tensor([1.0]), - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - "atomic_inter_scale": 1.0, - "atomic_inter_shift": 0.0, - } - - @pytest.fixture - def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - from ase import build - - torch.set_default_dtype(default_dtype) - - table = tools.AtomicNumberTable([6]) - - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - import numpy as np - - displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape) - atoms.positions += displacement - atoms_list = [atoms.repeat((2, 2, 2))] - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=5.0) - for config in configs - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - return batch.to(device).to_dict() - - @pytest.mark.parametrize( - "device", - ["cpu"] + (["cuda"] if CUDA_AVAILABLE else []), - ) - @pytest.mark.parametrize( - "interaction_cls_first", - [ - modules.interaction_classes["RealAgnosticResidualInteractionBlock"], - modules.interaction_classes["RealAgnosticInteractionBlock"], - modules.interaction_classes["RealAgnosticDensityInteractionBlock"], - ], - ) - @pytest.mark.parametrize( - "hidden_irreps", - [ - o3.Irreps("32x0e + 32x1o"), - o3.Irreps("32x0e + 32x1o + 32x2e"), - o3.Irreps("32x0e"), - ], - ) - @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) - def test_bidirectional_conversion( - self, - model_config: Dict[str, Any], - batch: Dict[str, torch.Tensor], - device: str, - default_dtype: torch.dtype, - ): - if device == "cuda" and not CUDA_AVAILABLE: - pytest.skip("CUDA not available") - torch.manual_seed(42) - - # Create original E3nn model - model_e3nn = modules.ScaleShiftMACE(**model_config).to(device) - - # Convert E3nn to CuEq - model_cueq = run_e3nn_to_cueq(model_e3nn).to(device) - - # Convert CuEq back to E3nn - model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device) - - # Test forward pass equivalence - out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True) - out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True) - out_e3nn_back = model_e3nn_back( - deepcopy(batch), training=True, compute_stress=True - ) - - # Check outputs match for both conversions - torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) - torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) - torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) - torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) - torch.testing.assert_close(out_e3nn["stress"], out_cueq["stress"]) - torch.testing.assert_close(out_cueq["stress"], out_e3nn_back["stress"]) - - # Test backward pass equivalence - loss_e3nn = out_e3nn["energy"].sum() - loss_cueq = out_cueq["energy"].sum() - loss_e3nn_back = out_e3nn_back["energy"].sum() - - 
loss_e3nn.backward() - loss_cueq.backward() - loss_e3nn_back.backward() - - # Compare gradients for all conversions - tol = 1e-4 if default_dtype == torch.float32 else 1e-7 - - def print_gradient_diff(name1, p1, name2, p2, conv_type): - if p1.grad is not None and p1.grad.shape == p2.grad.shape: - if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: - error = torch.abs(p1.grad - p2.grad) - print( - f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" - ) - torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol) - - # E3nn to CuEq gradients - for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( - model_e3nn.named_parameters(), model_cueq.named_parameters() - ): - print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") - - # CuEq to E3nn gradients - for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( - model_cueq.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" - ) - - # Full circle comparison (E3nn -> E3nn) - for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( - model_e3nn.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" - ) +# pylint: disable=wrong-import-position +import os +from copy import deepcopy +from typing import Any, Dict + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 + +from mace import data, modules, tools +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.tools import torch_geometric + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +CUDA_AVAILABLE = torch.cuda.is_available() + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +class TestCueq: + @pytest.fixture + def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: + table = tools.AtomicNumberTable([6]) + return { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": interaction_cls_first, + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": hidden_irreps, + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": torch.tensor([1.0]), + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, + } + + @pytest.fixture + def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + from ase import build + + torch.set_default_dtype(default_dtype) + + table = tools.AtomicNumberTable([6]) + + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + import numpy as np + + displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape) + atoms.positions += displacement + atoms_list = [atoms.repeat((2, 2, 2))] + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=5.0) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + return 
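+        # NOTE (reviewer comment): the round-trip test below converts e3nn -> cueq ->
+        # e3nn and asserts energies, forces, stresses and parameter gradients all agree,
+        # with a looser tolerance in float32 (1e-4) than in float64 (1e-7) to absorb
+        # rounding differences between the two kernel implementations.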
batch.to(device).to_dict() + + @pytest.mark.parametrize( + "device", + ["cpu"] + (["cuda"] if CUDA_AVAILABLE else []), + ) + @pytest.mark.parametrize( + "interaction_cls_first", + [ + modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + modules.interaction_classes["RealAgnosticInteractionBlock"], + modules.interaction_classes["RealAgnosticDensityInteractionBlock"], + ], + ) + @pytest.mark.parametrize( + "hidden_irreps", + [ + o3.Irreps("32x0e + 32x1o"), + o3.Irreps("32x0e + 32x1o + 32x2e"), + o3.Irreps("32x0e"), + ], + ) + @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) + def test_bidirectional_conversion( + self, + model_config: Dict[str, Any], + batch: Dict[str, torch.Tensor], + device: str, + default_dtype: torch.dtype, + ): + if device == "cuda" and not CUDA_AVAILABLE: + pytest.skip("CUDA not available") + torch.manual_seed(42) + + # Create original E3nn model + model_e3nn = modules.ScaleShiftMACE(**model_config).to(device) + + # Convert E3nn to CuEq + model_cueq = run_e3nn_to_cueq(model_e3nn).to(device) + + # Convert CuEq back to E3nn + model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device) + + # Test forward pass equivalence + out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True) + out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True) + out_e3nn_back = model_e3nn_back( + deepcopy(batch), training=True, compute_stress=True + ) + + # Check outputs match for both conversions + torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) + torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) + torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) + torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) + torch.testing.assert_close(out_e3nn["stress"], out_cueq["stress"]) + torch.testing.assert_close(out_cueq["stress"], out_e3nn_back["stress"]) + + # Test backward pass equivalence + loss_e3nn = out_e3nn["energy"].sum() + loss_cueq = out_cueq["energy"].sum() + loss_e3nn_back = out_e3nn_back["energy"].sum() + + loss_e3nn.backward() + loss_cueq.backward() + loss_e3nn_back.backward() + + # Compare gradients for all conversions + tol = 1e-4 if default_dtype == torch.float32 else 1e-7 + + def print_gradient_diff(name1, p1, name2, p2, conv_type): + if p1.grad is not None and p1.grad.shape == p2.grad.shape: + if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: + error = torch.abs(p1.grad - p2.grad) + print( + f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" + ) + torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol) + + # E3nn to CuEq gradients + for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( + model_e3nn.named_parameters(), model_cueq.named_parameters() + ): + print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") + + # CuEq to E3nn gradients + for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( + model_cueq.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" + ) + + # Full circle comparison (E3nn -> E3nn) + for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( + model_e3nn.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" + ) diff --git a/mace-bench/3rdparty/mace/tests/test_data.py b/mace-bench/3rdparty/mace/tests/test_data.py index 
41180e8607eba8476e0ecd8630694e94c1ea93b5..6710ecddb4d042326b079bde96a888f32d3df429 100644 --- a/mace-bench/3rdparty/mace/tests/test_data.py +++ b/mace-bench/3rdparty/mace/tests/test_data.py @@ -1,213 +1,213 @@ -from copy import deepcopy -from pathlib import Path - -import ase.build -import h5py -import numpy as np -import torch - -from mace.data import ( - AtomicData, - Configuration, - HDF5Dataset, - config_from_atoms, - get_neighborhood, - save_configurations_as_HDF5, -) -from mace.tools import AtomicNumberTable, torch_geometric - -mace_path = Path(__file__).parent.parent - - -class TestAtomicData: - config = Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - ) - config_2 = deepcopy(config) - config_2.positions = config.positions + 0.01 - - table = AtomicNumberTable([1, 8]) - - def test_atomic_data(self): - data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - assert data.edge_index.shape == (2, 4) - assert data.forces.shape == (3, 3) - assert data.node_attrs.shape == (3, 2) - - def test_data_loader(self): - data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data1, data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - - for batch in data_loader: - assert batch.batch.shape == (6,) - assert batch.edge_index.shape == (2, 8) - assert batch.shifts.shape == (8, 3) - assert batch.positions.shape == (6, 3) - assert batch.node_attrs.shape == (6, 2) - assert batch.energy.shape == (2,) - assert batch.forces.shape == (6, 3) - - def test_to_atomic_data_dict(self): - data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data1, data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - for batch in data_loader: - batch_dict = batch.to_dict() - assert batch_dict["batch"].shape == (6,) - assert batch_dict["edge_index"].shape == (2, 8) - assert batch_dict["shifts"].shape == (8, 3) - assert batch_dict["positions"].shape == (6, 3) - assert batch_dict["node_attrs"].shape == (6, 2) - assert batch_dict["energy"].shape == (2,) - assert batch_dict["forces"].shape == (6, 3) - - def test_hdf5_dataloader(self): - datasets = [self.config, self.config_2] * 5 - # get path of the mace package - with h5py.File(str(mace_path) + "test.h5", "w") as f: - save_configurations_as_HDF5(datasets, 0, f) - train_dataset = HDF5Dataset( - str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 - ) - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_dataset, - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch_count = 0 - for batch in train_loader: - batch_count += 1 - assert batch.batch.shape == (6,) - assert batch.edge_index.shape == (2, 8) - assert batch.shifts.shape == (8, 3) - assert batch.positions.shape == (6, 3) - assert batch.node_attrs.shape == (6, 2) - assert batch.energy.shape == (2,) - assert batch.forces.shape == (6, 3) - print(batch_count, len(train_loader), len(train_dataset)) - assert batch_count == len(train_loader) == 
len(train_dataset) / 2 - train_loader_direct = torch_geometric.dataloader.DataLoader( - dataset=[ - AtomicData.from_config(config, z_table=self.table, cutoff=3.0) - for config in datasets - ], - batch_size=2, - shuffle=False, - drop_last=False, - ) - for batch_direct, batch in zip(train_loader_direct, train_loader): - assert torch.all(batch_direct.edge_index == batch.edge_index) - assert torch.all(batch_direct.shifts == batch.shifts) - assert torch.all(batch_direct.positions == batch.positions) - assert torch.all(batch_direct.node_attrs == batch.node_attrs) - assert torch.all(batch_direct.energy == batch.energy) - assert torch.all(batch_direct.forces == batch.forces) - - -class TestNeighborhood: - def test_basic(self): - positions = np.array( - [ - [-1.0, 0.0, 0.0], - [+0.0, 0.0, 0.0], - [+1.0, 0.0, 0.0], - ] - ) - - indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) - assert indices.shape == (2, 4) - assert shifts.shape == (4, 3) - assert unit_shifts.shape == (4, 3) - - def test_signs(self): - positions = np.array( - [ - [+0.5, 0.5, 0.0], - [+1.0, 1.0, 0.0], - ] - ) - - cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - edge_index, shifts, unit_shifts, _ = get_neighborhood( - positions, cutoff=3.5, pbc=(True, False, False), cell=cell - ) - num_edges = 10 - assert edge_index.shape == (2, num_edges) - assert shifts.shape == (num_edges, 3) - assert unit_shifts.shape == (num_edges, 3) - - -# Based on mir-group/nequip -def test_periodic_edge(): - atoms = ase.build.bulk("Cu", "fcc") - dist = np.linalg.norm(atoms.cell[0]).item() - config = config_from_atoms(atoms) - edge_index, shifts, _, _ = get_neighborhood( - config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell - ) - sender, receiver = edge_index - vectors = ( - config.positions[receiver] - config.positions[sender] + shifts - ) # [n_edges, 3] - assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk - assert np.allclose( - np.linalg.norm(vectors, axis=-1), - dist, - ) - - -def test_half_periodic(): - atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) - assert all(atoms.pbc == (True, True, False)) - config = config_from_atoms(atoms) # first shell dist is 2.864A - edge_index, shifts, _, _ = get_neighborhood( - config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell - ) - sender, receiver = edge_index - vectors = ( - config.positions[receiver] - config.positions[sender] + shifts - ) # [n_edges, 3] - # Check number of neighbors: - _, neighbor_count = np.unique(edge_index[0], return_counts=True) - assert (neighbor_count == 6).all() # 6 neighbors - # Check not periodic in z - assert np.allclose( - vectors[:, 2], - np.zeros(vectors.shape[0]), - ) +from copy import deepcopy +from pathlib import Path + +import ase.build +import h5py +import numpy as np +import torch + +from mace.data import ( + AtomicData, + Configuration, + HDF5Dataset, + config_from_atoms, + get_neighborhood, + save_configurations_as_HDF5, +) +from mace.tools import AtomicNumberTable, torch_geometric + +mace_path = Path(__file__).parent.parent + + +class TestAtomicData: + config = Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + ) + config_2 = deepcopy(config) + config_2.positions = 
config.positions + 0.01 + + table = AtomicNumberTable([1, 8]) + + def test_atomic_data(self): + data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + assert data.edge_index.shape == (2, 4) + assert data.forces.shape == (3, 3) + assert data.node_attrs.shape == (3, 2) + + def test_data_loader(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + + for batch in data_loader: + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + + def test_to_atomic_data_dict(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + for batch in data_loader: + batch_dict = batch.to_dict() + assert batch_dict["batch"].shape == (6,) + assert batch_dict["edge_index"].shape == (2, 8) + assert batch_dict["shifts"].shape == (8, 3) + assert batch_dict["positions"].shape == (6, 3) + assert batch_dict["node_attrs"].shape == (6, 2) + assert batch_dict["energy"].shape == (2,) + assert batch_dict["forces"].shape == (6, 3) + + def test_hdf5_dataloader(self): + datasets = [self.config, self.config_2] * 5 + # get path of the mace package + with h5py.File(str(mace_path) + "test.h5", "w") as f: + save_configurations_as_HDF5(datasets, 0, f) + train_dataset = HDF5Dataset( + str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 + ) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch_count = 0 + for batch in train_loader: + batch_count += 1 + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + print(batch_count, len(train_loader), len(train_dataset)) + assert batch_count == len(train_loader) == len(train_dataset) / 2 + train_loader_direct = torch_geometric.dataloader.DataLoader( + dataset=[ + AtomicData.from_config(config, z_table=self.table, cutoff=3.0) + for config in datasets + ], + batch_size=2, + shuffle=False, + drop_last=False, + ) + for batch_direct, batch in zip(train_loader_direct, train_loader): + assert torch.all(batch_direct.edge_index == batch.edge_index) + assert torch.all(batch_direct.shifts == batch.shifts) + assert torch.all(batch_direct.positions == batch.positions) + assert torch.all(batch_direct.node_attrs == batch.node_attrs) + assert torch.all(batch_direct.energy == batch.energy) + assert torch.all(batch_direct.forces == batch.forces) + + +class TestNeighborhood: + def test_basic(self): + positions = np.array( + [ + [-1.0, 0.0, 0.0], + [+0.0, 0.0, 0.0], + [+1.0, 0.0, 0.0], + ] + ) + + indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) + assert indices.shape == (2, 4) + assert shifts.shape == (4, 3) + assert unit_shifts.shape == (4, 3) + + def 
test_signs(self): + positions = np.array( + [ + [+0.5, 0.5, 0.0], + [+1.0, 1.0, 0.0], + ] + ) + + cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + edge_index, shifts, unit_shifts, _ = get_neighborhood( + positions, cutoff=3.5, pbc=(True, False, False), cell=cell + ) + num_edges = 10 + assert edge_index.shape == (2, num_edges) + assert shifts.shape == (num_edges, 3) + assert unit_shifts.shape == (num_edges, 3) + + +# Based on mir-group/nequip +def test_periodic_edge(): + atoms = ase.build.bulk("Cu", "fcc") + dist = np.linalg.norm(atoms.cell[0]).item() + config = config_from_atoms(atoms) + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk + assert np.allclose( + np.linalg.norm(vectors, axis=-1), + dist, + ) + + +def test_half_periodic(): + atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) + assert all(atoms.pbc == (True, True, False)) + config = config_from_atoms(atoms) # first shell dist is 2.864A + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + # Check number of neighbors: + _, neighbor_count = np.unique(edge_index[0], return_counts=True) + assert (neighbor_count == 6).all() # 6 neighbors + # Check not periodic in z + assert np.allclose( + vectors[:, 2], + np.zeros(vectors.shape[0]), + ) diff --git a/mace-bench/3rdparty/mace/tests/test_finetuning_select.py b/mace-bench/3rdparty/mace/tests/test_finetuning_select.py index d8d9701f4b81e3cdb2974092c7138aa943fdb634..a58c8b305307ed4510d6e72972255b1144d07bd4 100644 --- a/mace-bench/3rdparty/mace/tests/test_finetuning_select.py +++ b/mace-bench/3rdparty/mace/tests/test_finetuning_select.py @@ -1,164 +1,164 @@ -import ase.io as aio -import numpy as np -import pytest -from ase import Atoms -from ase.build import molecule - -from mace.cli.fine_tuning_select import ( - FilteringType, - SelectionSettings, - SubselectType, - _filter_pretraining_data, - _load_descriptors, - _maybe_save_descriptors, - filter_atoms, - select_samples, -) - - -@pytest.fixture(name="train_atoms_fixture") -def train_atoms(): - return [ - molecule("H2O"), - molecule("CH4"), - Atoms("Fe2O3"), - Atoms("C"), - Atoms("FeON"), - Atoms("Fe"), - ] - - -@pytest.fixture(name="train_atom_descriptors_fixture") -def train_atom_descriptors(train_atoms_fixture): - return [ - {x: np.zeros(5) + i for x in atoms.symbols} - for i, atoms in enumerate(train_atoms_fixture) - ] - - -@pytest.mark.parametrize( - "filtering_type, passes_filter, element_sublist", - [ - (FilteringType.NONE, [True] * 6, []), - (FilteringType.NONE, [True] * 6, ["C", "U", "Anything really"]), - ( - FilteringType.COMBINATIONS, - [False, False, True, False, False, True], - ["O", "Fe"], - ), - ( - FilteringType.INCLUSIVE, - [False, False, True, False, True, False], - ["O", "Fe"], - ), - ( - FilteringType.EXCLUSIVE, - [False, False, True, False, False, False], - ["O", "Fe"], - ), - ], -) -def test_filter_data( - train_atoms_fixture, filtering_type, passes_filter, element_sublist -): - filtered, _, passes = _filter_pretraining_data( - train_atoms_fixture, filtering_type, element_sublist - ) - assert passes == 
passes_filter - assert len(filtered) == sum(passes_filter) - - -@pytest.mark.parametrize( - "passes_filter", [[True] * 6, [False, True, False, True, False, True]] -) -def test_load_descriptors( - train_atoms_fixture, train_atom_descriptors_fixture, passes_filter, tmp_path -): - for i, atoms in enumerate(train_atoms_fixture): - atoms.info["mace_descriptors"] = train_atom_descriptors_fixture[i] - save_path = tmp_path / "test.xyz" - _maybe_save_descriptors(train_atoms_fixture, save_path.as_posix()) - assert all(not "mace_descriptors" in atoms.info for atoms in train_atoms_fixture) - filtered_atoms = [ - x for x, passes in zip(train_atoms_fixture, passes_filter) if passes - ] - descriptors_path = save_path.as_posix().replace(".xyz", "_descriptors.npy") - - _load_descriptors( - filtered_atoms, - passes_filter, - descriptors_path=descriptors_path, - calc=None, - full_data_length=len(train_atoms_fixture), - ) - expected_descriptors = [ - train_atom_descriptors_fixture[i] - for i, passes in enumerate(passes_filter) - if passes - ] - for i, atoms in enumerate(filtered_atoms): - assert "mace_descriptors" in atoms.info - for key, value in expected_descriptors[i].items(): - assert np.allclose(atoms.info["mace_descriptors"][key], value) - - -def test_select_samples_random(train_atoms_fixture, tmp_path): - input_file_path = tmp_path / "input.xyz" - aio.write(input_file_path, train_atoms_fixture, format="extxyz") - output_file_path = tmp_path / "output.xyz" - - settings = SelectionSettings( - configs_pt=input_file_path.as_posix(), - output=output_file_path.as_posix(), - num_samples=2, - subselect=SubselectType.RANDOM, - filtering_type=FilteringType.NONE, - ) - select_samples(settings) - - # Check if output file is created - assert output_file_path.exists() - combined_output_file_path = tmp_path / "output_combined.xyz" - assert combined_output_file_path.exists() - - output_atoms = aio.read(output_file_path, index=":") - assert isinstance(output_atoms, list) - assert len(output_atoms) == 2 - - combined_output_atoms = aio.read(combined_output_file_path, index=":") - assert isinstance(combined_output_atoms, list) - assert ( - len(combined_output_atoms) == 2 - ) # combined same as output since no FT data provided - - -def test_select_samples_ft_provided(train_atoms_fixture, tmp_path): - input_file_path = tmp_path / "input.xyz" - aio.write(input_file_path, train_atoms_fixture, format="extxyz") - output_file_path = tmp_path / "output.xyz" - ft_file_path = tmp_path / "ft_data.xyz" - ft_data = [Atoms("FeO")] - aio.write(ft_file_path.as_posix(), ft_data, format="extxyz") - - settings = SelectionSettings( - configs_pt=input_file_path.as_posix(), - output=output_file_path.as_posix(), - num_samples=2, - subselect=SubselectType.RANDOM, - configs_ft=ft_file_path.as_posix(), - ) - select_samples(settings) - - # Check if output file is created - assert output_file_path.exists() - combined_output_file_path = tmp_path / "output_combined.xyz" - assert combined_output_file_path.exists() - - output_atoms = aio.read(output_file_path, index=":") - assert isinstance(output_atoms, list) - assert len(output_atoms) == 2 - assert all(filter_atoms(x, ["Fe", "O"]) for x in output_atoms) - - combined_atoms = aio.read(combined_output_file_path, index=":") - assert isinstance(combined_atoms, list) - assert len(combined_atoms) == len(output_atoms) + len(ft_data) +import ase.io as aio +import numpy as np +import pytest +from ase import Atoms +from ase.build import molecule + +from mace.cli.fine_tuning_select import ( + FilteringType, + 
SelectionSettings, + SubselectType, + _filter_pretraining_data, + _load_descriptors, + _maybe_save_descriptors, + filter_atoms, + select_samples, +) + + +@pytest.fixture(name="train_atoms_fixture") +def train_atoms(): + return [ + molecule("H2O"), + molecule("CH4"), + Atoms("Fe2O3"), + Atoms("C"), + Atoms("FeON"), + Atoms("Fe"), + ] + + +@pytest.fixture(name="train_atom_descriptors_fixture") +def train_atom_descriptors(train_atoms_fixture): + return [ + {x: np.zeros(5) + i for x in atoms.symbols} + for i, atoms in enumerate(train_atoms_fixture) + ] + + +@pytest.mark.parametrize( + "filtering_type, passes_filter, element_sublist", + [ + (FilteringType.NONE, [True] * 6, []), + (FilteringType.NONE, [True] * 6, ["C", "U", "Anything really"]), + ( + FilteringType.COMBINATIONS, + [False, False, True, False, False, True], + ["O", "Fe"], + ), + ( + FilteringType.INCLUSIVE, + [False, False, True, False, True, False], + ["O", "Fe"], + ), + ( + FilteringType.EXCLUSIVE, + [False, False, True, False, False, False], + ["O", "Fe"], + ), + ], +) +def test_filter_data( + train_atoms_fixture, filtering_type, passes_filter, element_sublist +): + filtered, _, passes = _filter_pretraining_data( + train_atoms_fixture, filtering_type, element_sublist + ) + assert passes == passes_filter + assert len(filtered) == sum(passes_filter) + + +@pytest.mark.parametrize( + "passes_filter", [[True] * 6, [False, True, False, True, False, True]] +) +def test_load_descriptors( + train_atoms_fixture, train_atom_descriptors_fixture, passes_filter, tmp_path +): + for i, atoms in enumerate(train_atoms_fixture): + atoms.info["mace_descriptors"] = train_atom_descriptors_fixture[i] + save_path = tmp_path / "test.xyz" + _maybe_save_descriptors(train_atoms_fixture, save_path.as_posix()) + assert all(not "mace_descriptors" in atoms.info for atoms in train_atoms_fixture) + filtered_atoms = [ + x for x, passes in zip(train_atoms_fixture, passes_filter) if passes + ] + descriptors_path = save_path.as_posix().replace(".xyz", "_descriptors.npy") + + _load_descriptors( + filtered_atoms, + passes_filter, + descriptors_path=descriptors_path, + calc=None, + full_data_length=len(train_atoms_fixture), + ) + expected_descriptors = [ + train_atom_descriptors_fixture[i] + for i, passes in enumerate(passes_filter) + if passes + ] + for i, atoms in enumerate(filtered_atoms): + assert "mace_descriptors" in atoms.info + for key, value in expected_descriptors[i].items(): + assert np.allclose(atoms.info["mace_descriptors"][key], value) + + +def test_select_samples_random(train_atoms_fixture, tmp_path): + input_file_path = tmp_path / "input.xyz" + aio.write(input_file_path, train_atoms_fixture, format="extxyz") + output_file_path = tmp_path / "output.xyz" + + settings = SelectionSettings( + configs_pt=input_file_path.as_posix(), + output=output_file_path.as_posix(), + num_samples=2, + subselect=SubselectType.RANDOM, + filtering_type=FilteringType.NONE, + ) + select_samples(settings) + + # Check if output file is created + assert output_file_path.exists() + combined_output_file_path = tmp_path / "output_combined.xyz" + assert combined_output_file_path.exists() + + output_atoms = aio.read(output_file_path, index=":") + assert isinstance(output_atoms, list) + assert len(output_atoms) == 2 + + combined_output_atoms = aio.read(combined_output_file_path, index=":") + assert isinstance(combined_output_atoms, list) + assert ( + len(combined_output_atoms) == 2 + ) # combined same as output since no FT data provided + + +def 
test_select_samples_ft_provided(train_atoms_fixture, tmp_path): + input_file_path = tmp_path / "input.xyz" + aio.write(input_file_path, train_atoms_fixture, format="extxyz") + output_file_path = tmp_path / "output.xyz" + ft_file_path = tmp_path / "ft_data.xyz" + ft_data = [Atoms("FeO")] + aio.write(ft_file_path.as_posix(), ft_data, format="extxyz") + + settings = SelectionSettings( + configs_pt=input_file_path.as_posix(), + output=output_file_path.as_posix(), + num_samples=2, + subselect=SubselectType.RANDOM, + configs_ft=ft_file_path.as_posix(), + ) + select_samples(settings) + + # Check if output file is created + assert output_file_path.exists() + combined_output_file_path = tmp_path / "output_combined.xyz" + assert combined_output_file_path.exists() + + output_atoms = aio.read(output_file_path, index=":") + assert isinstance(output_atoms, list) + assert len(output_atoms) == 2 + assert all(filter_atoms(x, ["Fe", "O"]) for x in output_atoms) + + combined_atoms = aio.read(combined_output_file_path, index=":") + assert isinstance(combined_atoms, list) + assert len(combined_atoms) == len(output_atoms) + len(ft_data) diff --git a/mace-bench/3rdparty/mace/tests/test_foundations.py b/mace-bench/3rdparty/mace/tests/test_foundations.py index cb19a3e6a583168682a90fba1c5c671b0544996f..c8641839d2746ab4c0932dd52b1cf1d0f0e0bdac 100644 --- a/mace-bench/3rdparty/mace/tests/test_foundations.py +++ b/mace-bench/3rdparty/mace/tests/test_foundations.py @@ -1,512 +1,512 @@ -from pathlib import Path - -import numpy as np -import pytest -import torch -import torch.nn.functional -from ase.build import molecule -from e3nn import o3 -from e3nn.util import jit -from scipy.spatial.transform import Rotation as R - -from mace import data, modules, tools -from mace.calculators import mace_mp, mace_off -from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head -from mace.tools.utils import AtomicNumberTable - -MODEL_PATH = ( - Path(__file__).parent.parent - / "mace" - / "calculators" - / "foundations_models" - / "2023-12-03-mace-mp.model" -) - -torch.set_default_dtype(torch.float64) - -@pytest.skip("Problem with the float type", allow_module_level=True) -def test_foundations(): - # Create MACE model - config = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - - # Created the rotated environment - rot = R.from_euler("z", 60, degrees=True).as_matrix() - positions_rotated = np.array(rot @ config.positions.T).T - config_rotated = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=positions_rotated, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - table = tools.AtomicNumberTable([1, 6, 8]) - atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) - model_config = dict( - r_max=6, - num_bessel=10, - num_polynomial_cutoff=5, - max_ell=3, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - 
interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=3, - hidden_irreps=o3.Irreps("128x0e + 128x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - atomic_inter_scale=0.1, - atomic_inter_shift=0.0, - ) - model = modules.ScaleShiftMACE(**model_config) - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - model_loaded = load_foundations_elements( - model, - calc_foundation.models[0], - table=table, - load_readout=True, - use_shift=False, - max_L=1, - ) - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=6.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=6.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - forces_loaded = model_loaded(batch.to_dict())["forces"] - forces = model(batch.to_dict())["forces"] - assert torch.allclose(forces, forces_loaded) - - -def test_multi_reference(): - config_multi = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - head="MP2", - ) - table_multi = tools.AtomicNumberTable([1, 6, 8]) - atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) - table = tools.AtomicNumberTable([1, 6, 8]) - - - # Create MACE model - model_config = dict( - r_max=6, - num_bessel=10, - num_polynomial_cutoff=5, - max_ell=3, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=3, - hidden_irreps=o3.Irreps("128x0e + 128x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies_multi, - avg_num_neighbors=61, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - atomic_inter_scale=[1.0, 1.0], - atomic_inter_shift=[0.0, 0.0], - heads=["MP2", "DFT"], - ) - model = modules.ScaleShiftMACE(**model_config) - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - model_loaded = load_foundations_elements( - model, - calc_foundation.models[0], - table=table, - load_readout=True, - use_shift=False, - max_L=1, - ) - atomic_data = data.AtomicData.from_config( - config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - forces_loaded = model_loaded(batch.to_dict())["forces"] - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - atoms = molecule("H2COH") - atoms.info["head"] = "MP2" - atoms.calc = calc_foundation - forces = atoms.get_forces() - assert np.allclose( - forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 - ) - - -@pytest.mark.parametrize( - "calc", - [ - mace_mp(device="cpu", default_dtype="float64"), - 
mace_mp(model="small", device="cpu", default_dtype="float64"), - mace_mp(model="medium", device="cpu", default_dtype="float64"), - mace_mp(model="large", device="cpu", default_dtype="float64"), - mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"), - mace_off(model="small", device="cpu", default_dtype="float64"), - mace_off(model="medium", device="cpu", default_dtype="float64"), - mace_off(model="large", device="cpu", default_dtype="float64"), - mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"), - ], -) -def test_compile_foundation(calc): - model = calc.models[0] - atoms = molecule("CH4") - atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 - batch = calc._atoms_to_batch(atoms) # pylint: disable=protected-access - output_1 = model(batch.to_dict()) - model_compiled = jit.compile(model) - output = model_compiled(batch.to_dict()) - for key in output_1.keys(): - if isinstance(output_1[key], torch.Tensor): - assert torch.allclose(output_1[key], output[key], atol=1e-5) - - -@pytest.mark.parametrize( - "model", - [ - mace_mp(model="small", device="cpu", default_dtype="float64").models[0], - mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], - mace_mp(model="large", device="cpu", default_dtype="float64").models[0], - mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], - mace_off(model="small", device="cpu", default_dtype="float64").models[0], - mace_off(model="medium", device="cpu", default_dtype="float64").models[0], - mace_off(model="large", device="cpu", default_dtype="float64").models[0], - mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], - ], -) -def test_extract_config(model): - assert isinstance(model, modules.ScaleShiftMACE) - config = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model)) - model_copy.load_state_dict(model.state_dict()) - z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) - atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model(batch.to_dict()) - output_copy = model_copy(batch.to_dict()) - # assert all items of the output dicts are equal - for key in output.keys(): - if isinstance(output[key], torch.Tensor): - assert torch.allclose(output[key], output_copy[key], atol=1e-5) - - -def test_remove_pt_head(): - # Set up test data - torch.manual_seed(42) - atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) - z_table = AtomicNumberTable([1, 8]) # H and O - - # Create multihead model - model_config = { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 5, - "max_ell": 2, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": len(z_table), - "hidden_irreps": o3.Irreps("32x0e + 32x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": 
torch.nn.functional.silu, - "atomic_energies": atomic_energies_pt_head, - "avg_num_neighbors": 8, - "atomic_numbers": z_table.zs, - "correlation": 3, - "heads": ["pt_head", "DFT"], - "atomic_inter_scale": [1.0, 1.0], - "atomic_inter_shift": [0.0, 0.1], - } - - model = modules.ScaleShiftMACE(**model_config) - - # Create test molecule - mol = molecule("H2O") - config_pt_head = data.Configuration( - atomic_numbers=mol.numbers, - positions=mol.positions, - properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, - property_weights={"forces": 1.0, "energy": 1.0}, - head="DFT", - ) - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] - ) - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - batch = next(iter(dataloader)) - # Test original mode - output_orig = model(batch.to_dict()) - - # Convert to single head model - new_model = remove_pt_head(model, head_to_keep="DFT") - - # Basic structure tests - assert len(new_model.heads) == 1 - assert new_model.heads[0] == "DFT" - assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 - assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 - assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 - - # Test output consistency - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] - ) - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - batch = next(iter(dataloader)) - output_new = new_model(batch.to_dict()) - torch.testing.assert_close( - output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 - ) - torch.testing.assert_close( - output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 - ) - - -def test_remove_pt_head_multihead(): - # Set up test data - torch.manual_seed(42) - atomic_energies_pt_head = np.array( - [ - [1.0, 2.0], # H energies for each head - [3.0, 4.0], # O energies for each head - ] - * 2 - ) - z_table = AtomicNumberTable([1, 8]) # H and O - - # Create multihead model - model_config = { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 5, - "max_ell": 2, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": len(z_table), - "hidden_irreps": o3.Irreps("32x0e + 32x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": torch.nn.functional.silu, - "atomic_energies": atomic_energies_pt_head, - "avg_num_neighbors": 8, - "atomic_numbers": z_table.zs, - "correlation": 3, - "heads": ["pt_head", "DFT", "MP2", "CCSD"], - "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], - "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], - } - - model = modules.ScaleShiftMACE(**model_config) - - # Create test configurations for each head - mol = molecule("H2O") - configs = {} - atomic_datas = {} - dataloaders = {} - original_outputs = {} - - # First get outputs from original model for each head - for head in model.heads: - config_pt_head = data.Configuration( - atomic_numbers=mol.numbers, - positions=mol.positions, - properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, - property_weights={"forces": 1.0, "energy": 1.0}, - head=head, - ) - configs[head] = config_pt_head - - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads - ) - 
atomic_datas[head] = atomic_data - - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - dataloaders[head] = dataloader - - batch = next(iter(dataloader)) - output = model(batch.to_dict()) - original_outputs[head] = output - - # Now test each head separately - for i, head in enumerate(model.heads): - # Convert to single head model - new_model = remove_pt_head(model, head_to_keep=head) - - # Basic structure tests - assert len(new_model.heads) == 1, f"Failed for head {head}" - assert new_model.heads[0] == head, f"Failed for head {head}" - assert ( - new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 - ), f"Failed for head {head}" - assert ( - len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 - ), f"Failed for head {head}" - assert ( - len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 - ), f"Failed for head {head}" - - # Verify scale and shift values - assert torch.allclose( - new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] - ), f"Failed for head {head}" - assert torch.allclose( - new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] - ), f"Failed for head {head}" - - # Test output consistency - single_head_data = data.AtomicData.from_config( - configs[head], z_table=z_table, cutoff=5.0, heads=[head] - ) - single_head_loader = torch_geometric.dataloader.DataLoader( - dataset=[single_head_data], batch_size=1, shuffle=False - ) - batch = next(iter(single_head_loader)) - new_output = new_model(batch.to_dict()) - - # Compare outputs - print( - original_outputs[head]["energy"], - new_output["energy"], - ) - torch.testing.assert_close( - original_outputs[head]["energy"], - new_output["energy"], - rtol=1e-5, - atol=1e-5, - msg=f"Energy mismatch for head {head}", - ) - torch.testing.assert_close( - original_outputs[head]["forces"], - new_output["forces"], - rtol=1e-5, - atol=1e-5, - msg=f"Forces mismatch for head {head}", - ) - - # Test error cases - with pytest.raises(ValueError, match="Head non_existent not found in model"): - remove_pt_head(model, head_to_keep="non_existent") - - # Test default behavior (first non-PT head) - default_model = remove_pt_head(model) - assert default_model.heads[0] == "DFT" - - # Additional test: check if each model's computation graph is independent - models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} - results = {} - - for head, head_model in models.items(): - single_head_data = data.AtomicData.from_config( - configs[head], z_table=z_table, cutoff=5.0, heads=[head] - ) - single_head_loader = torch_geometric.dataloader.DataLoader( - dataset=[single_head_data], batch_size=1, shuffle=False - ) - batch = next(iter(single_head_loader)) - results[head] = head_model(batch.to_dict()) - - # Verify each model produces different outputs - energies = torch.stack([results[head]["energy"] for head in model.heads]) - assert not torch.allclose( - energies[0], energies[1], rtol=1e-3 - ), "Different heads should produce different outputs" +from pathlib import Path + +import numpy as np +import pytest +import torch +import torch.nn.functional +from ase.build import molecule +from e3nn import o3 +from e3nn.util import jit +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.calculators import mace_mp, mace_off +from mace.tools import torch_geometric +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head +from 
mace.tools.utils import AtomicNumberTable + +MODEL_PATH = ( + Path(__file__).parent.parent + / "mace" + / "calculators" + / "foundations_models" + / "2023-12-03-mace-mp.model" +) + +torch.set_default_dtype(torch.float64) + +@pytest.skip("Problem with the float type", allow_module_level=True) +def test_foundations(): + # Create MACE model + config = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + + # Created the rotated environment + rot = R.from_euler("z", 60, degrees=True).as_matrix() + positions_rotated = np.array(rot @ config.positions.T).T + config_rotated = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=positions_rotated, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + table = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=0.1, + atomic_inter_shift=0.0, + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=6.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=6.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch.to_dict())["forces"] + forces = model(batch.to_dict())["forces"] + assert torch.allclose(forces, forces_loaded) + + +def test_multi_reference(): + config_multi = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + head="MP2", + ) + table_multi = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + table = tools.AtomicNumberTable([1, 6, 8]) + + + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + 
"RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=61, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.0], + heads=["MP2", "DFT"], + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config( + config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch.to_dict())["forces"] + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + atoms = molecule("H2COH") + atoms.info["head"] = "MP2" + atoms.calc = calc_foundation + forces = atoms.get_forces() + assert np.allclose( + forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 + ) + + +@pytest.mark.parametrize( + "calc", + [ + mace_mp(device="cpu", default_dtype="float64"), + mace_mp(model="small", device="cpu", default_dtype="float64"), + mace_mp(model="medium", device="cpu", default_dtype="float64"), + mace_mp(model="large", device="cpu", default_dtype="float64"), + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"), + mace_off(model="small", device="cpu", default_dtype="float64"), + mace_off(model="medium", device="cpu", default_dtype="float64"), + mace_off(model="large", device="cpu", default_dtype="float64"), + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"), + ], +) +def test_compile_foundation(calc): + model = calc.models[0] + atoms = molecule("CH4") + atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 + batch = calc._atoms_to_batch(atoms) # pylint: disable=protected-access + output_1 = model(batch.to_dict()) + model_compiled = jit.compile(model) + output = model_compiled(batch.to_dict()) + for key in output_1.keys(): + if isinstance(output_1[key], torch.Tensor): + assert torch.allclose(output_1[key], output[key], atol=1e-5) + + +@pytest.mark.parametrize( + "model", + [ + mace_mp(model="small", device="cpu", default_dtype="float64").models[0], + mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], + mace_mp(model="large", device="cpu", default_dtype="float64").models[0], + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + mace_off(model="small", device="cpu", default_dtype="float64").models[0], + mace_off(model="medium", device="cpu", default_dtype="float64").models[0], + mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + ], +) +def test_extract_config(model): + assert isinstance(model, modules.ScaleShiftMACE) + config = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": 
molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model(batch.to_dict()) + output_copy = model_copy(batch.to_dict()) + # assert all items of the output dicts are equal + for key in output.keys(): + if isinstance(output[key], torch.Tensor): + assert torch.allclose(output[key], output_copy[key], atol=1e-5) + + +def test_remove_pt_head(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT"], + "atomic_inter_scale": [1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test molecule + mol = molecule("H2O") + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, + property_weights={"forces": 1.0, "energy": 1.0}, + head="DFT", + ) + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + # Test original mode + output_orig = model(batch.to_dict()) + + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep="DFT") + + # Basic structure tests + assert len(new_model.heads) == 1 + assert new_model.heads[0] == "DFT" + assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + + # Test output consistency + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + output_new = new_model(batch.to_dict()) + torch.testing.assert_close( + output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 + ) + + +def test_remove_pt_head_multihead(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array( + [ + [1.0, 2.0], 
# H energies for each head + [3.0, 4.0], # O energies for each head + ] + * 2 + ) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT", "MP2", "CCSD"], + "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test configurations for each head + mol = molecule("H2O") + configs = {} + atomic_datas = {} + dataloaders = {} + original_outputs = {} + + # First get outputs from original model for each head + for head in model.heads: + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, + property_weights={"forces": 1.0, "energy": 1.0}, + head=head, + ) + configs[head] = config_pt_head + + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads + ) + atomic_datas[head] = atomic_data + + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + dataloaders[head] = dataloader + + batch = next(iter(dataloader)) + output = model(batch.to_dict()) + original_outputs[head] = output + + # Now test each head separately + for i, head in enumerate(model.heads): + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep=head) + + # Basic structure tests + assert len(new_model.heads) == 1, f"Failed for head {head}" + assert new_model.heads[0] == head, f"Failed for head {head}" + assert ( + new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + ), f"Failed for head {head}" + + # Verify scale and shift values + assert torch.allclose( + new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] + ), f"Failed for head {head}" + assert torch.allclose( + new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] + ), f"Failed for head {head}" + + # Test output consistency + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + new_output = new_model(batch.to_dict()) + + # Compare outputs + print( + original_outputs[head]["energy"], + new_output["energy"], + ) + torch.testing.assert_close( + original_outputs[head]["energy"], + new_output["energy"], + rtol=1e-5, + atol=1e-5, + msg=f"Energy mismatch for head {head}", + ) + torch.testing.assert_close( + original_outputs[head]["forces"], + new_output["forces"], + rtol=1e-5, + atol=1e-5, + msg=f"Forces mismatch for head {head}", + ) + + # Test error cases + with 
pytest.raises(ValueError, match="Head non_existent not found in model"): + remove_pt_head(model, head_to_keep="non_existent") + + # Test default behavior (first non-PT head) + default_model = remove_pt_head(model) + assert default_model.heads[0] == "DFT" + + # Additional test: check if each model's computation graph is independent + models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} + results = {} + + for head, head_model in models.items(): + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + results[head] = head_model(batch.to_dict()) + + # Verify each model produces different outputs + energies = torch.stack([results[head]["energy"] for head in model.heads]) + assert not torch.allclose( + energies[0], energies[1], rtol=1e-3 + ), "Different heads should produce different outputs" diff --git a/mace-bench/3rdparty/mace/tests/test_hessian.py b/mace-bench/3rdparty/mace/tests/test_hessian.py index 5e23e82a8d43aea32846109be347eb4c7c53a16b..53457335d5226b3da0a774d0e32ef330c8b36ba6 100644 --- a/mace-bench/3rdparty/mace/tests/test_hessian.py +++ b/mace-bench/3rdparty/mace/tests/test_hessian.py @@ -1,54 +1,54 @@ -import numpy as np -import pytest -from ase.build import fcc111 - -from mace.calculators import mace_mp - - -@pytest.fixture(name="setup_calculator_") -def setup_calculator(): - calc = mace_mp( - model="medium", dispersion=False, default_dtype="float64", device="cpu" - ) - return calc - - -@pytest.fixture(name="setup_structure_") -def setup_structure(setup_calculator_): - initial = fcc111("Pt", size=(4, 4, 1), vacuum=10.0, orthogonal=True) - initial.calc = setup_calculator_ - return initial - - -def test_potential_energy_and_hessian(setup_structure_): - initial = setup_structure_ - h_autograd = initial.calc.get_hessian(atoms=initial) - assert h_autograd.shape == (len(initial) * 3, len(initial), 3) - - -def test_finite_difference_hessian(setup_structure_): - initial = setup_structure_ - indicies = list(range(len(initial))) - delta, ndim = 1e-4, 3 - hessian = np.zeros((len(indicies) * ndim, len(indicies) * ndim)) - atoms_h = initial.copy() - for i, index in enumerate(indicies): - for j in range(ndim): - atoms_i = atoms_h.copy() - atoms_i.positions[index, j] += delta - atoms_i.calc = initial.calc - forces_i = atoms_i.get_forces() - - atoms_j = atoms_h.copy() - atoms_j.positions[index, j] -= delta - atoms_j.calc = initial.calc - forces_j = atoms_j.get_forces() - - hessian[:, i * ndim + j] = -(forces_i - forces_j)[indicies].flatten() / ( - 2 * delta - ) - - hessian = hessian.reshape((-1, len(initial), 3)) - h_autograd = initial.calc.get_hessian(atoms=initial) - is_close = np.allclose(h_autograd, hessian, atol=1e-6) - assert is_close +import numpy as np +import pytest +from ase.build import fcc111 + +from mace.calculators import mace_mp + + +@pytest.fixture(name="setup_calculator_") +def setup_calculator(): + calc = mace_mp( + model="medium", dispersion=False, default_dtype="float64", device="cpu" + ) + return calc + + +@pytest.fixture(name="setup_structure_") +def setup_structure(setup_calculator_): + initial = fcc111("Pt", size=(4, 4, 1), vacuum=10.0, orthogonal=True) + initial.calc = setup_calculator_ + return initial + + +def test_potential_energy_and_hessian(setup_structure_): + initial = setup_structure_ + h_autograd = 
initial.calc.get_hessian(atoms=initial) + assert h_autograd.shape == (len(initial) * 3, len(initial), 3) + + +def test_finite_difference_hessian(setup_structure_): + initial = setup_structure_ + indicies = list(range(len(initial))) + delta, ndim = 1e-4, 3 + hessian = np.zeros((len(indicies) * ndim, len(indicies) * ndim)) + atoms_h = initial.copy() + for i, index in enumerate(indicies): + for j in range(ndim): + atoms_i = atoms_h.copy() + atoms_i.positions[index, j] += delta + atoms_i.calc = initial.calc + forces_i = atoms_i.get_forces() + + atoms_j = atoms_h.copy() + atoms_j.positions[index, j] -= delta + atoms_j.calc = initial.calc + forces_j = atoms_j.get_forces() + + hessian[:, i * ndim + j] = -(forces_i - forces_j)[indicies].flatten() / ( + 2 * delta + ) + + hessian = hessian.reshape((-1, len(initial), 3)) + h_autograd = initial.calc.get_hessian(atoms=initial) + is_close = np.allclose(h_autograd, hessian, atol=1e-6) + assert is_close diff --git a/mace-bench/3rdparty/mace/tests/test_lmdb_database.py b/mace-bench/3rdparty/mace/tests/test_lmdb_database.py index 197661a3a6e7b203cc9f0ef8b31ae08c2db8d4da..0c7043a6d6f84d3576256a08ba3492e12118e542 100644 --- a/mace-bench/3rdparty/mace/tests/test_lmdb_database.py +++ b/mace-bench/3rdparty/mace/tests/test_lmdb_database.py @@ -1,134 +1,134 @@ -import os -import tempfile - -import numpy as np -import torch -from ase.build import molecule -from ase.calculators.singlepoint import SinglePointCalculator - -from mace.data.lmdb_dataset import LMDBDataset -from mace.tools import AtomicNumberTable, torch_geometric -from mace.tools.fairchem_dataset.lmdb_dataset_tools import LMDBDatabase - - -def test_lmdb_dataset(): - """Test the LMDBDataset by creating a fake database and verifying batch creation.""" - # Set default dtype to match typical MACE usage - torch.set_default_dtype(torch.float64) - - # Set random seed for reproducibility - np.random.seed(42) - - # Create temporary directories for the databases - with tempfile.TemporaryDirectory() as tmpdir: - # Create 3 folders for databases - db_paths = [] - for i in range(3): - folder_path = os.path.join(tmpdir, f"folder_{i}") - os.makedirs(folder_path, exist_ok=True) - - # Create LMDB database files in each folder (2 per folder) - for j in range(2): - db_path = os.path.join(folder_path, f"data_{j}.aselmdb") - db = LMDBDatabase(db_path, readonly=False) - - # Add 2 configurations to each database - for _ in range(2): - # Create a water molecule using ASE's build functionality - atoms = molecule("H2O") - - # Apply small random displacements to the positions - displacement = np.random.rand(*atoms.positions.shape) * 0.1 - atoms.positions += displacement - - # Set cell and PBC - atoms.set_cell(np.eye(3) * 5.0) - atoms.set_pbc(True) - - # Add random energy, forces, and stress - energy = np.random.uniform( - -15.0, -5.0 - ) # Random energy between -15 and -5 eV - forces = ( - np.random.randn(*atoms.positions.shape) * 0.5 - ) # Random forces - stress = np.random.randn(6) * 0.2 # Random stress in Voigt notation - - # Add calculator to atoms with results - calc = SinglePointCalculator( - atoms, energy=energy, forces=forces, stress=stress - ) - atoms.calc = calc - - # Store in database - db.write(atoms) - - db.close() - - # Add folder path to our list - db_paths.append(folder_path) - - # Create the dataset using paths joined with colons - paths_str = ":".join(db_paths) - z_table = AtomicNumberTable([1, 8]) # H and O - dataset = LMDBDataset(file_path=paths_str, r_max=5.0, z_table=z_table) - - # Check dataset size (3 
folders * 2 files * 2 configs = 12 entries) - assert len(dataset) == 12 - - # Test retrieving a single item - item = dataset[0] - print(item) - assert item.positions.shape == (3, 3) # 3 atoms, 3 coordinates - assert hasattr(item, "energy") - assert hasattr(item, "forces") - assert hasattr(item, "stress") - - # Create a dataloader - dataloader = torch_geometric.dataloader.DataLoader( - dataset=dataset, batch_size=4, shuffle=False, drop_last=False - ) - - # Get a batch and validate it - batch = next(iter(dataloader)) - - # Verify batch properties - should have 12 atoms (4 configs * 3 atoms per water) - assert batch.positions.shape == (12, 3) # 12 atoms, 3 coordinates - assert batch.energy.shape[0] == 4 # 4 energies (one per config) - assert batch.forces.shape == (12, 3) # Forces for each atom - print(batch.stress.shape) - assert batch.stress.shape == (4, 3, 3) # Stress for each config - - # Check batch has required attributes for MACE model processing - assert hasattr(batch, "batch") # Batch indices - assert batch.batch.shape[0] == 12 # One index per atom - assert hasattr(batch, "ptr") # Pointer for batch processing - assert batch.ptr.shape[0] == 5 # One pointer per config + 1 - - # Check that batch indices are correctly assigned - # First 3 atoms should be from config 0, next 3 from config 1, etc. - expected_batch = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) - assert torch.all(batch.batch == expected_batch) - - # Check ptr correctly points to start of each configuration - assert batch.ptr.tolist() == [0, 3, 6, 9, 12] - - # Create a batch dictionary that can be passed to a MACE model - batch_dict = batch.to_dict() - assert "positions" in batch_dict - assert "energy" in batch_dict - assert "forces" in batch_dict - assert "stress" in batch_dict - assert "batch" in batch_dict - assert "ptr" in batch_dict - - # Verify additional properties required by MACE - assert hasattr(batch, "edge_index") # Connectivity information - assert hasattr(batch, "shifts") # For periodic boundary conditions - assert hasattr(batch, "cell") # Unit cell information - - # Test that a full batch can be processed (without errors) - all_batches = list(dataloader) - assert ( - len(all_batches) == 3 - ) # Should have 3 batches (12 configs with batch size 4) +import os +import tempfile + +import numpy as np +import torch +from ase.build import molecule +from ase.calculators.singlepoint import SinglePointCalculator + +from mace.data.lmdb_dataset import LMDBDataset +from mace.tools import AtomicNumberTable, torch_geometric +from mace.tools.fairchem_dataset.lmdb_dataset_tools import LMDBDatabase + + +def test_lmdb_dataset(): + """Test the LMDBDataset by creating a fake database and verifying batch creation.""" + # Set default dtype to match typical MACE usage + torch.set_default_dtype(torch.float64) + + # Set random seed for reproducibility + np.random.seed(42) + + # Create temporary directories for the databases + with tempfile.TemporaryDirectory() as tmpdir: + # Create 3 folders for databases + db_paths = [] + for i in range(3): + folder_path = os.path.join(tmpdir, f"folder_{i}") + os.makedirs(folder_path, exist_ok=True) + + # Create LMDB database files in each folder (2 per folder) + for j in range(2): + db_path = os.path.join(folder_path, f"data_{j}.aselmdb") + db = LMDBDatabase(db_path, readonly=False) + + # Add 2 configurations to each database + for _ in range(2): + # Create a water molecule using ASE's build functionality + atoms = molecule("H2O") + + # Apply small random displacements to the positions + 
displacement = np.random.rand(*atoms.positions.shape) * 0.1 + atoms.positions += displacement + + # Set cell and PBC + atoms.set_cell(np.eye(3) * 5.0) + atoms.set_pbc(True) + + # Add random energy, forces, and stress + energy = np.random.uniform( + -15.0, -5.0 + ) # Random energy between -15 and -5 eV + forces = ( + np.random.randn(*atoms.positions.shape) * 0.5 + ) # Random forces + stress = np.random.randn(6) * 0.2 # Random stress in Voigt notation + + # Add calculator to atoms with results + calc = SinglePointCalculator( + atoms, energy=energy, forces=forces, stress=stress + ) + atoms.calc = calc + + # Store in database + db.write(atoms) + + db.close() + + # Add folder path to our list + db_paths.append(folder_path) + + # Create the dataset using paths joined with colons + paths_str = ":".join(db_paths) + z_table = AtomicNumberTable([1, 8]) # H and O + dataset = LMDBDataset(file_path=paths_str, r_max=5.0, z_table=z_table) + + # Check dataset size (3 folders * 2 files * 2 configs = 12 entries) + assert len(dataset) == 12 + + # Test retrieving a single item + item = dataset[0] + print(item) + assert item.positions.shape == (3, 3) # 3 atoms, 3 coordinates + assert hasattr(item, "energy") + assert hasattr(item, "forces") + assert hasattr(item, "stress") + + # Create a dataloader + dataloader = torch_geometric.dataloader.DataLoader( + dataset=dataset, batch_size=4, shuffle=False, drop_last=False + ) + + # Get a batch and validate it + batch = next(iter(dataloader)) + + # Verify batch properties - should have 12 atoms (4 configs * 3 atoms per water) + assert batch.positions.shape == (12, 3) # 12 atoms, 3 coordinates + assert batch.energy.shape[0] == 4 # 4 energies (one per config) + assert batch.forces.shape == (12, 3) # Forces for each atom + print(batch.stress.shape) + assert batch.stress.shape == (4, 3, 3) # Stress for each config + + # Check batch has required attributes for MACE model processing + assert hasattr(batch, "batch") # Batch indices + assert batch.batch.shape[0] == 12 # One index per atom + assert hasattr(batch, "ptr") # Pointer for batch processing + assert batch.ptr.shape[0] == 5 # One pointer per config + 1 + + # Check that batch indices are correctly assigned + # First 3 atoms should be from config 0, next 3 from config 1, etc. 
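# [editor's sketch, not part of the recorded diff] How the asserted `batch`
# vector follows from `ptr` for four 3-atom graphs; `torch.repeat_interleave`
# rebuilds the per-atom graph indices from the batch offsets:
#   ptr = torch.tensor([0, 3, 6, 9, 12])
#   torch.repeat_interleave(torch.arange(len(ptr) - 1), ptr.diff())
#   # -> tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])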
+ expected_batch = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) + assert torch.all(batch.batch == expected_batch) + + # Check ptr correctly points to start of each configuration + assert batch.ptr.tolist() == [0, 3, 6, 9, 12] + + # Create a batch dictionary that can be passed to a MACE model + batch_dict = batch.to_dict() + assert "positions" in batch_dict + assert "energy" in batch_dict + assert "forces" in batch_dict + assert "stress" in batch_dict + assert "batch" in batch_dict + assert "ptr" in batch_dict + + # Verify additional properties required by MACE + assert hasattr(batch, "edge_index") # Connectivity information + assert hasattr(batch, "shifts") # For periodic boundary conditions + assert hasattr(batch, "cell") # Unit cell information + + # Test that a full batch can be processed (without errors) + all_batches = list(dataloader) + assert ( + len(all_batches) == 3 + ) # Should have 3 batches (12 configs with batch size 4) diff --git a/mace-bench/3rdparty/mace/tests/test_models.py b/mace-bench/3rdparty/mace/tests/test_models.py index 9c1d2a006d865147c035fbc00a22f57a74f9dba0..40ff48c394af814c304e44d1d5a43d705428ed07 100644 --- a/mace-bench/3rdparty/mace/tests/test_models.py +++ b/mace-bench/3rdparty/mace/tests/test_models.py @@ -1,374 +1,374 @@ -import numpy as np -import torch -import torch.nn.functional -from ase import build -from e3nn import o3 -from e3nn.util import jit -from scipy.spatial.transform import Rotation as R - -from mace import data, modules, tools -from mace.tools import torch_geometric - -torch.set_default_dtype(torch.float64) -config = data.Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "charges": np.array([-2.0, 1.0, 1.0]), - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, -) -# Created the rotated environment -rot = R.from_euler("z", 60, degrees=True).as_matrix() -positions_rotated = np.array(rot @ config.positions.T).T -config_rotated = data.Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=positions_rotated, - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "charges": np.array([-2.0, 1.0, 1.0]), - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, -) -table = tools.AtomicNumberTable([1, 8]) -atomic_energies = np.array([1.0, 3.0], dtype=float) - - -def test_mace(): - # Create MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=5, - num_elements=2, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=8, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - ) - model = modules.MACE(**model_config) - model_compiled = jit.compile(model) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, 
z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) - assert torch.allclose(output1["energy"][0], output2["energy"][0]) - assert torch.allclose(output2["energy"][0], output2["energy"][1]) - - -def test_dipole_mace(): - # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=None, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - radial_type="gaussian", - ) - model = modules.AtomicDipolesMACE(**model_config) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model( - batch, - training=True, - ) - # sanity check of dipoles being the right shape - assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape - # test equivariance of output dipoles - assert np.allclose( - np.array(rot @ output["dipole"][0].detach().numpy()), - output["dipole"][1].detach().numpy(), - ) - - -def test_energy_dipole_mace(): - # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - ) - model = modules.EnergyDipolesMACE(**model_config) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model( - batch, - training=True, - ) - # sanity check of dipoles being the right shape - assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape - # test energy is invariant - assert torch.allclose(output["energy"][0], output["energy"][1]) - # test equivariance of output dipoles - assert np.allclose( - np.array(rot @ output["dipole"][0].detach().numpy()), - output["dipole"][1].detach().numpy(), - ) - - -def test_mace_multi_reference(): - atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=3, - interaction_cls=modules.interaction_classes[ - 
"RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("96x0e + 96x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies_multi, - avg_num_neighbors=8, - atomic_numbers=table.zs, - distance_transform=True, - pair_repulsion=True, - correlation=3, - heads=["Default", "dft"], - # radial_type="chebyshev", - atomic_inter_scale=[1.0, 1.0], - atomic_inter_shift=[0.0, 0.1], - ) - model = modules.ScaleShiftMACE(**model_config) - model_compiled = jit.compile(model) - config.head = "Default" - config_rotated.head = "dft" - atomic_data = data.AtomicData.from_config( - config, z_table=table, cutoff=3.0, heads=["Default", "dft"] - ) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) - assert torch.allclose(output1["energy"][0], output2["energy"][0]) - assert output2["energy"].shape[0] == 2 - - -def test_atomic_virials_stresses(): - """ - Test that atomic virials and stresses sum to the total virials and stress. - """ - # Set default dtype for reproducibility - torch.set_default_dtype(torch.float64) - - # Create a periodic cell with ASE - atoms = build.bulk("Si", "diamond", a=5.43) - # Apply strain to ensure non-zero stress - strain_tensor = np.eye(3) * 1.02 # 2% strain - atoms.set_cell(np.dot(atoms.get_cell(), strain_tensor), scale_atoms=True) - - # Add forces and energy for completeness - atoms.arrays["REF_forces"] = np.random.normal(0, 0.1, size=atoms.positions.shape) - atoms.info["REF_energy"] = np.random.normal(0, 1) - atoms.info["REF_stress"] = np.random.normal(0, 0.1, size=6) - - # Setup MACE model configuration - stress_z_table = tools.AtomicNumberTable([14]) # Silicon - stress_atomic_energies = np.array([0.0]) - - model_config = dict( - r_max=5.0, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=3, - num_elements=1, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=stress_atomic_energies, - avg_num_neighbors=4.0, - atomic_numbers=table.zs, - correlation=3, - atomic_inter_scale=1.0, - atomic_inter_shift=0.0, - ) - - # Create the model - model = modules.ScaleShiftMACE(**model_config) - - # Create atomic data - atomic_data = data.AtomicData.from_config( - data.config_from_atoms( - atoms, key_specification=data.KeySpecification.from_defaults() - ), - z_table=stress_z_table, - cutoff=5.0, - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch_dict = batch.to_dict() - - # Run the model with compute_atomic_stresses=True - output = model( - batch_dict, - compute_force=True, - compute_virials=True, - compute_stress=True, - compute_atomic_stresses=True, - ) - - # Get total virials/stress and atomic virials/stresses - total_virials = 
output["virials"] - atomic_virials = output["atomic_virials"] - total_stress = output["stress"] - atomic_stresses = output["atomic_stresses"] - - # Test that atomic values are not None - assert atomic_virials is not None, "Atomic virials were not computed" - assert atomic_stresses is not None, "Atomic stresses were not computed" - - # Test shape of atomic values - assert atomic_virials.shape[0] == len(atoms), "Wrong shape for atomic virials" - assert atomic_virials.shape[1:] == (3, 3), "Atomic virials should be 3x3 matrices" - assert atomic_stresses.shape[0] == len(atoms), "Wrong shape for atomic stresses" - assert atomic_stresses.shape[1:] == (3, 3), "Atomic stresses should be 3x3 matrices" - - # Compute sum of atomic values - summed_atomic_virials = torch.sum(atomic_virials, dim=0) - summed_atomic_stresses = torch.sum(atomic_stresses, dim=0) - - # Test that sums match total values - assert torch.allclose( - summed_atomic_virials, total_virials.squeeze(0), atol=1e-6 - ), f"Sum of atomic virials {summed_atomic_virials} does not match total virials {total_virials.squeeze(0)}" - - assert torch.allclose( - summed_atomic_stresses, total_stress.squeeze(0), atol=1e-6 - ), f"Sum of atomic stresses (normalized by volume) {summed_atomic_stresses} does not match total stress {total_stress.squeeze(0)}" +import numpy as np +import torch +import torch.nn.functional +from ase import build +from e3nn import o3 +from e3nn.util import jit +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.tools import torch_geometric + +torch.set_default_dtype(torch.float64) +config = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, +) +# Created the rotated environment +rot = R.from_euler("z", 60, degrees=True).as_matrix() +positions_rotated = np.array(rot @ config.positions.T).T +config_rotated = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=positions_rotated, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, +) +table = tools.AtomicNumberTable([1, 8]) +atomic_energies = np.array([1.0, 3.0], dtype=float) + + +def test_mace(): + # Create MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=5, + num_elements=2, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + ) + model = modules.MACE(**model_config) + model_compiled = jit.compile(model) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = 
data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert torch.allclose(output2["energy"][0], output2["energy"][1]) + + +def test_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=None, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="gaussian", + ) + model = modules.AtomicDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_energy_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + ) + model = modules.EnergyDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test energy is invariant + assert torch.allclose(output["energy"][0], output["energy"][1]) + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_mace_multi_reference(): + atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=3, + 
interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("96x0e + 96x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=8, + atomic_numbers=table.zs, + distance_transform=True, + pair_repulsion=True, + correlation=3, + heads=["Default", "dft"], + # radial_type="chebyshev", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.1], + ) + model = modules.ScaleShiftMACE(**model_config) + model_compiled = jit.compile(model) + config.head = "Default" + config_rotated.head = "dft" + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert output2["energy"].shape[0] == 2 + + +def test_atomic_virials_stresses(): + """ + Test that atomic virials and stresses sum to the total virials and stress. + """ + # Set default dtype for reproducibility + torch.set_default_dtype(torch.float64) + + # Create a periodic cell with ASE + atoms = build.bulk("Si", "diamond", a=5.43) + # Apply strain to ensure non-zero stress + strain_tensor = np.eye(3) * 1.02 # 2% strain + atoms.set_cell(np.dot(atoms.get_cell(), strain_tensor), scale_atoms=True) + + # Add forces and energy for completeness + atoms.arrays["REF_forces"] = np.random.normal(0, 0.1, size=atoms.positions.shape) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.info["REF_stress"] = np.random.normal(0, 0.1, size=6) + + # Setup MACE model configuration + stress_z_table = tools.AtomicNumberTable([14]) # Silicon + stress_atomic_energies = np.array([0.0]) + + model_config = dict( + r_max=5.0, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=3, + num_elements=1, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=stress_atomic_energies, + avg_num_neighbors=4.0, + atomic_numbers=table.zs, + correlation=3, + atomic_inter_scale=1.0, + atomic_inter_shift=0.0, + ) + + # Create the model + model = modules.ScaleShiftMACE(**model_config) + + # Create atomic data + atomic_data = data.AtomicData.from_config( + data.config_from_atoms( + atoms, key_specification=data.KeySpecification.from_defaults() + ), + z_table=stress_z_table, + cutoff=5.0, + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch_dict = batch.to_dict() + + # Run the model with compute_atomic_stresses=True + output = model( + batch_dict, + compute_force=True, + compute_virials=True, + compute_stress=True, + compute_atomic_stresses=True, + ) + + # Get total 
virials/stress and atomic virials/stresses + total_virials = output["virials"] + atomic_virials = output["atomic_virials"] + total_stress = output["stress"] + atomic_stresses = output["atomic_stresses"] + + # Test that atomic values are not None + assert atomic_virials is not None, "Atomic virials were not computed" + assert atomic_stresses is not None, "Atomic stresses were not computed" + + # Test shape of atomic values + assert atomic_virials.shape[0] == len(atoms), "Wrong shape for atomic virials" + assert atomic_virials.shape[1:] == (3, 3), "Atomic virials should be 3x3 matrices" + assert atomic_stresses.shape[0] == len(atoms), "Wrong shape for atomic stresses" + assert atomic_stresses.shape[1:] == (3, 3), "Atomic stresses should be 3x3 matrices" + + # Compute sum of atomic values + summed_atomic_virials = torch.sum(atomic_virials, dim=0) + summed_atomic_stresses = torch.sum(atomic_stresses, dim=0) + + # Test that sums match total values + assert torch.allclose( + summed_atomic_virials, total_virials.squeeze(0), atol=1e-6 + ), f"Sum of atomic virials {summed_atomic_virials} does not match total virials {total_virials.squeeze(0)}" + + assert torch.allclose( + summed_atomic_stresses, total_stress.squeeze(0), atol=1e-6 + ), f"Sum of atomic stresses (normalized by volume) {summed_atomic_stresses} does not match total stress {total_stress.squeeze(0)}" diff --git a/mace-bench/3rdparty/mace/tests/test_modules.py b/mace-bench/3rdparty/mace/tests/test_modules.py index 57ddc328362374189e2c46cf63c268cf270fb459..6afcccfb44252ee80b7c969748dfb914673d0e3e 100644 --- a/mace-bench/3rdparty/mace/tests/test_modules.py +++ b/mace-bench/3rdparty/mace/tests/test_modules.py @@ -1,268 +1,268 @@ -import numpy as np -import pytest -import torch -import torch.nn.functional -from e3nn import o3 - -from mace.data import AtomicData, Configuration -from mace.modules import ( - AtomicEnergiesBlock, - BesselBasis, - PolynomialCutoff, - SymmetricContraction, - WeightedEnergyForcesLoss, - WeightedHuberEnergyForcesStressLoss, - compute_mean_rms_energy_forces, - compute_statistics, -) -from mace.tools import AtomicNumberTable, scatter, to_numpy, torch_geometric -from mace.tools.scripts_utils import dict_to_array - - -@pytest.fixture(name="config") -def _config(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "stress": 1.0, - }, - ) - - -@pytest.fixture(name="table") -def _table(): - return AtomicNumberTable([1, 8]) - - -@pytest.fixture(name="config1") -def _config1(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - head="DFT", - ) - - -@pytest.fixture(name="config2") -def _config2(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.1, -1.9, 0.1], - [1.1, 0.1, 0.1], - [0.1, 1.1, 0.1], - ] - ), - properties={ - "forces": np.array( - [ - [0.1, -1.2, 0.1], - [1.1, 0.3, 0.1], - [0.1, 1.2, 0.4], - ] - ), - "energy": -1.4, - }, - 
property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - head="MP2", - ) - - -@pytest.fixture(name="atomic_data") -def _atomic_data(config1, config2, table): - atomic_data1 = AtomicData.from_config( - config1, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - atomic_data2 = AtomicData.from_config( - config2, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - return [atomic_data1, atomic_data2] - - -@pytest.fixture(name="data_loader") -def _data_loader(atomic_data): - return torch_geometric.dataloader.DataLoader( - dataset=atomic_data, - batch_size=2, - shuffle=False, - drop_last=False, - ) - - -@pytest.fixture(name="atomic_energies") -def _atomic_energies(): - atomic_energies_dict = { - "DFT": np.array([0.0, 0.0]), - "MP2": np.array([0.1, 0.1]), - } - return dict_to_array(atomic_energies_dict, ["DFT", "MP2"]) - - -@pytest.fixture(autouse=True) -def _set_torch_default_dtype(): - torch.set_default_dtype(torch.float64) - - -def test_weighted_loss(config, table): - loss1 = WeightedEnergyForcesLoss(energy_weight=1, forces_weight=10) - loss2 = WeightedHuberEnergyForcesStressLoss(energy_weight=1, forces_weight=10) - data = AtomicData.from_config(config, z_table=table, cutoff=3.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - pred = { - "energy": batch.energy, - "forces": batch.forces, - "stress": batch.stress, - } - out1 = loss1(batch, pred) - assert out1 == 0.0 - out2 = loss2(batch, pred) - assert out2 == 0.0 - - -def test_symmetric_contraction(): - operation = SymmetricContraction( - irreps_in=o3.Irreps("16x0e + 16x1o + 16x2e"), - irreps_out=o3.Irreps("16x0e + 16x1o"), - correlation=3, - num_elements=2, - ) - torch.manual_seed(123) - features = torch.randn(30, 16, 9) - one_hots = torch.nn.functional.one_hot(torch.arange(0, 30) % 2).to( - torch.get_default_dtype() - ) - out = operation(features, one_hots) - assert out.shape == (30, 64) - assert operation.contractions[0].weights_max.shape == (2, 11, 16) - - -def test_bessel_basis(): - d = torch.linspace(start=0.5, end=5.5, steps=10) - bessel_basis = BesselBasis(r_max=6.0, num_basis=5) - output = bessel_basis(d.unsqueeze(-1)) - assert output.shape == (10, 5) - - -def test_polynomial_cutoff(): - d = torch.linspace(start=0.5, end=5.5, steps=10) - cutoff_fn = PolynomialCutoff(r_max=5.0) - output = cutoff_fn(d) - assert output.shape == (10,) - - -def test_atomic_energies(config, table): - energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0])) - data = AtomicData.from_config(config, z_table=table, cutoff=3.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - energies = energies_block(batch.node_attrs).squeeze(-1) - out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") - out = to_numpy(out) - assert np.allclose(out, np.array([5.0, 5.0])) - - -def test_atomic_energies_multireference(config, table): - energies_block = AtomicEnergiesBlock( - atomic_energies=np.array([[1.0, 3.0], [2.0, 4.0]]) - ) - config.head = "MP2" - data = AtomicData.from_config( - config, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_heads = ( - 
batch["head"][batch["batch"]] - if "head" in batch - else torch.zeros_like(batch["batch"]) - ) - energies = energies_block(batch.node_attrs).squeeze(-1) - energies = energies[num_atoms_arange, node_heads] - out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") - out = to_numpy(out) - assert np.allclose(out, np.array([8.0, 8.0])) - - -def test_compute_mean_rms_energy_forces_multi_head(data_loader, atomic_energies): - mean, rms = compute_mean_rms_energy_forces(data_loader, atomic_energies) - assert isinstance(mean, np.ndarray) - assert isinstance(rms, np.ndarray) - assert mean.shape == (2,) - assert rms.shape == (2,) - assert np.all(rms >= 0) - assert rms[0] != rms[1] - - -def test_compute_statistics(data_loader, atomic_energies): - avg_num_neighbors, mean, std = compute_statistics(data_loader, atomic_energies) - assert isinstance(avg_num_neighbors, float) - assert isinstance(mean, np.ndarray) - assert isinstance(std, np.ndarray) - assert mean.shape == (2,) - assert std.shape == (2,) - assert avg_num_neighbors > 0 - assert np.all(mean != 0) - assert np.all(std > 0) - assert mean[0] != mean[1] - assert std[0] != std[1] +import numpy as np +import pytest +import torch +import torch.nn.functional +from e3nn import o3 + +from mace.data import AtomicData, Configuration +from mace.modules import ( + AtomicEnergiesBlock, + BesselBasis, + PolynomialCutoff, + SymmetricContraction, + WeightedEnergyForcesLoss, + WeightedHuberEnergyForcesStressLoss, + compute_mean_rms_energy_forces, + compute_statistics, +) +from mace.tools import AtomicNumberTable, scatter, to_numpy, torch_geometric +from mace.tools.scripts_utils import dict_to_array + + +@pytest.fixture(name="config") +def _config(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "stress": 1.0, + }, + ) + + +@pytest.fixture(name="table") +def _table(): + return AtomicNumberTable([1, 8]) + + +@pytest.fixture(name="config1") +def _config1(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + head="DFT", + ) + + +@pytest.fixture(name="config2") +def _config2(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.1, -1.9, 0.1], + [1.1, 0.1, 0.1], + [0.1, 1.1, 0.1], + ] + ), + properties={ + "forces": np.array( + [ + [0.1, -1.2, 0.1], + [1.1, 0.3, 0.1], + [0.1, 1.2, 0.4], + ] + ), + "energy": -1.4, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + head="MP2", + ) + + +@pytest.fixture(name="atomic_data") +def _atomic_data(config1, config2, table): + atomic_data1 = AtomicData.from_config( + config1, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + atomic_data2 = AtomicData.from_config( + config2, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + return [atomic_data1, atomic_data2] + + +@pytest.fixture(name="data_loader") +def _data_loader(atomic_data): + return torch_geometric.dataloader.DataLoader( + dataset=atomic_data, + 
batch_size=2, + shuffle=False, + drop_last=False, + ) + + +@pytest.fixture(name="atomic_energies") +def _atomic_energies(): + atomic_energies_dict = { + "DFT": np.array([0.0, 0.0]), + "MP2": np.array([0.1, 0.1]), + } + return dict_to_array(atomic_energies_dict, ["DFT", "MP2"]) + + +@pytest.fixture(autouse=True) +def _set_torch_default_dtype(): + torch.set_default_dtype(torch.float64) + + +def test_weighted_loss(config, table): + loss1 = WeightedEnergyForcesLoss(energy_weight=1, forces_weight=10) + loss2 = WeightedHuberEnergyForcesStressLoss(energy_weight=1, forces_weight=10) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + pred = { + "energy": batch.energy, + "forces": batch.forces, + "stress": batch.stress, + } + out1 = loss1(batch, pred) + assert out1 == 0.0 + out2 = loss2(batch, pred) + assert out2 == 0.0 + + +def test_symmetric_contraction(): + operation = SymmetricContraction( + irreps_in=o3.Irreps("16x0e + 16x1o + 16x2e"), + irreps_out=o3.Irreps("16x0e + 16x1o"), + correlation=3, + num_elements=2, + ) + torch.manual_seed(123) + features = torch.randn(30, 16, 9) + one_hots = torch.nn.functional.one_hot(torch.arange(0, 30) % 2).to( + torch.get_default_dtype() + ) + out = operation(features, one_hots) + assert out.shape == (30, 64) + assert operation.contractions[0].weights_max.shape == (2, 11, 16) + + +def test_bessel_basis(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + bessel_basis = BesselBasis(r_max=6.0, num_basis=5) + output = bessel_basis(d.unsqueeze(-1)) + assert output.shape == (10, 5) + + +def test_polynomial_cutoff(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + cutoff_fn = PolynomialCutoff(r_max=5.0) + output = cutoff_fn(d) + assert output.shape == (10,) + + +def test_atomic_energies(config, table): + energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0])) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + energies = energies_block(batch.node_attrs).squeeze(-1) + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([5.0, 5.0])) + + +def test_atomic_energies_multireference(config, table): + energies_block = AtomicEnergiesBlock( + atomic_energies=np.array([[1.0, 3.0], [2.0, 4.0]]) + ) + config.head = "MP2" + data = AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_heads = ( + batch["head"][batch["batch"]] + if "head" in batch + else torch.zeros_like(batch["batch"]) + ) + energies = energies_block(batch.node_attrs).squeeze(-1) + energies = energies[num_atoms_arange, node_heads] + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([8.0, 8.0])) + + +def test_compute_mean_rms_energy_forces_multi_head(data_loader, atomic_energies): + mean, rms = compute_mean_rms_energy_forces(data_loader, atomic_energies) + assert isinstance(mean, np.ndarray) 
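# [editor's note] The (2,) shapes asserted here follow from the two heads
# ("DFT" and "MP2") declared in the fixtures above: the statistics are
# computed per head, so each returned array carries one entry per head.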
+ assert isinstance(rms, np.ndarray) + assert mean.shape == (2,) + assert rms.shape == (2,) + assert np.all(rms >= 0) + assert rms[0] != rms[1] + + +def test_compute_statistics(data_loader, atomic_energies): + avg_num_neighbors, mean, std = compute_statistics(data_loader, atomic_energies) + assert isinstance(avg_num_neighbors, float) + assert isinstance(mean, np.ndarray) + assert isinstance(std, np.ndarray) + assert mean.shape == (2,) + assert std.shape == (2,) + assert avg_num_neighbors > 0 + assert np.all(mean != 0) + assert np.all(std > 0) + assert mean[0] != mean[1] + assert std[0] != std[1] diff --git a/mace-bench/3rdparty/mace/tests/test_multifiles.py b/mace-bench/3rdparty/mace/tests/test_multifiles.py index fb995189f7b69b23d620cf8f3d934728a977ce60..16eacc2f24dfce9ba28defbaed3a3cf4026e5b7d 100644 --- a/mace-bench/3rdparty/mace/tests/test_multifiles.py +++ b/mace-bench/3rdparty/mace/tests/test_multifiles.py @@ -1,1029 +1,1029 @@ -import json -import os -import shutil -import subprocess -import sys -import tempfile -import zlib -from pathlib import Path - -import lmdb -import numpy as np -import orjson -import pytest -import torch -import yaml -from ase.atoms import Atoms -from ase.calculators.singlepoint import SinglePointCalculator - -from mace.calculators import MACECalculator - - -def create_test_atoms(num_atoms=5, seed=42): - """Create random atoms for testing purposes with energy, forces, and stress.""" - # Set random seed for reproducibility - rng = np.random.RandomState(seed) - - # Create random positions - positions = rng.rand(num_atoms, 3) * 5.0 - - # Create random atomic numbers (H, C, N, O) - atomic_numbers = rng.choice([1, 6, 7, 8], size=num_atoms) - - # Create atoms object - atoms = Atoms( - numbers=atomic_numbers, - positions=positions, - cell=np.eye(3) * 10.0, # 10 Å periodic box - pbc=True, - ) - - # Add random energy, forces and stress - energy = float(rng.uniform(-15.0, -5.0)) - forces = rng.rand(num_atoms, 3) * 0.5 - 0.25 # Forces between -0.25 and 0.25 eV/Å - stress = rng.rand(6) * 0.2 - 0.1 # Stress tensor in Voigt notation - - # Add calculator to atoms with results - calc = SinglePointCalculator(atoms, energy=energy, forces=forces, stress=stress) - atoms.calc = calc - - # Mark isolated atoms with config_type - if num_atoms == 1: - atoms.info["config_type"] = "IsolatedAtom" - - return atoms - - -def create_xyz_file(atoms_list, filename): - """Write a list of atoms to an xyz file.""" - from ase.io import write - - write(filename, atoms_list, format="extxyz") - return filename - - -def create_e0s_file(e0s_dict, filename): - """Create an E0s JSON file with isolated atom energies.""" - # Convert keys to integers since MACE expects atomic numbers as integers - e0s_dict_int_keys = {int(k): v for k, v in e0s_dict.items()} - - with open(filename, "w", encoding="utf-8") as f: - json.dump(e0s_dict_int_keys, f) - return filename - - -def create_h5_dataset(xyz_file, output_dir, e0s_file=None, r_max=5.0, seed=42): - """ - Run MACE's preprocess_data.py script to convert an xyz file to h5 format. 
- - Args: - xyz_file: Path to the input xyz file - output_dir: Directory to store the preprocessed h5 files - e0s_file: Path to the E0s file with isolated atom energies - r_max: Cutoff radius - seed: Random seed - - Returns: - The output directory containing the h5 files - """ - # Make sure output directory exists - os.makedirs(output_dir, exist_ok=True) - - # Find the path to the preprocess_data.py script - preprocess_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" - ) - - # Set up command to run preprocess_data.py - cmd = [ - sys.executable, - str(preprocess_script), - f"--train_file={xyz_file}", - f"--r_max={r_max}", - f"--h5_prefix={output_dir}/", - f"--seed={seed}", - "--compute_statistics", # Generate statistics file - "--num_process=2", # Create 2 files for testing sharded loading - ] - - # Add E0s file if provided - if e0s_file: - cmd.append(f"--E0s={e0s_file}") - - # Set up environment - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the script - print(f"Running preprocess command: {' '.join(cmd)}") - try: - process = subprocess.run( - cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True - ) - # Print output for debugging - print("Preprocess stdout:", process.stdout.decode()) - print("Preprocess stderr:", process.stderr.decode()) - except subprocess.CalledProcessError as e: - print("Preprocess failed with error:", e) - print("Stdout:", e.stdout.decode() if e.stdout else "") - print("Stderr:", e.stderr.decode() if e.stderr else "") - raise - - return output_dir - - -def create_lmdb_dataset(atoms_list, folder_path, head_name="Default"): - """Create an LMDB dataset from a list of atoms objects that MACE can read.""" - # Create the folder if it doesn't exist - os.makedirs(folder_path, exist_ok=True) - - # Create the LMDB database file - db_path = os.path.join(folder_path, "data.aselmdb") - - # Initialize LMDB environment - env = lmdb.open( - db_path, - map_size=1099511627776, # 1TB - subdir=False, - meminit=False, - map_async=True, - ) - - # Open a transaction - with env.begin(write=True) as txn: - # Store metadata - metadata = {"format_version": 1} - txn.put( - "metadata".encode("ascii"), - zlib.compress(orjson.dumps(metadata, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store nextid - nextid = len(atoms_list) + 1 - txn.put( - "nextid".encode("ascii"), - zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store deleted_ids (empty) - txn.put( - "deleted_ids".encode("ascii"), - zlib.compress(orjson.dumps([], option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store each atom - for i, atoms in enumerate(atoms_list): - id_num = i + 1 # Start from 1 - - # Convert atoms to dictionary - positions = atoms.get_positions() - cell = atoms.get_cell() - - # Create a dictionary with all necessary fields - dct = { - "numbers": atoms.get_atomic_numbers().tolist(), - "positions": positions.tolist(), - "cell": cell.tolist(), - "pbc": atoms.get_pbc().tolist(), - "ctime": 0.0, # Creation time - "mtime": 0.0, # Modification time - "user": "test", - "energy": atoms.calc.results["energy"], - "forces": atoms.calc.results["forces"].tolist(), - "stress": atoms.calc.results["stress"].tolist(), - "key_value_pairs": { - "config_type": atoms.info.get("config_type", "Default"), - "head": head_name, - }, - } - - # Store the atom in LMDB - txn.put( - f"{id_num}".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) 
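# [editor's sketch, assuming the db_path naming used in this helper] Reading a
# record back mirrors the write path above: fetch the compressed JSON blob by
# its integer key, then decompress and deserialize it.
#   env = lmdb.open(db_path, subdir=False, readonly=True, lock=False)
#   with env.begin() as txn:
#       dct = orjson.loads(zlib.decompress(txn.get(b"1")))
#   dct["energy"], dct["key_value_pairs"]["head"]  # per-config labels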
- - # Close the environment - env.close() - - return folder_path - - -@pytest.mark.slow -def test_multifile_training(): - """Test training with multiple file formats per head""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - xyz_file1 = os.path.join(temp_dir, "data1.xyz") - xyz_file2 = os.path.join(temp_dir, "data2.xyz") - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - h5_folder = os.path.join(temp_dir, "h5_data") - lmdb_folder1 = os.path.join( - temp_dir, "lmdb_data1_lmdb" - ) # Add _lmdb suffix for LMDB recognition - lmdb_folder2 = os.path.join( - temp_dir, "lmdb_data2_lmdb" - ) # Add _lmdb suffix for LMDB recognition - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=5) - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - atom.info["REF_energy"] = energy # Make sure energy is in the right place - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create 10 atoms for each dataset - xyz_atoms1 = [ - create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(10) - ] - xyz_atoms2 = [ - create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(10) - ] - - # Create h5 data directly - first convert the xyz file to a format with REF_ keys - for atom in xyz_atoms1: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - - for atom in xyz_atoms2: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - - # Save isolated atoms to xyz files first, then create the h5 datasets - create_xyz_file(xyz_atoms1, xyz_file1) - create_xyz_file(xyz_atoms2, xyz_file2) - - # Create h5 data from xyz file, using both isolated atoms and real data - all_atoms_for_h5 = isolated_atoms + xyz_atoms2 - all_atoms_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") - create_xyz_file(all_atoms_for_h5, all_atoms_xyz) - create_h5_dataset(all_atoms_xyz, h5_folder) - - # Create LMDB datasets - lmdb_atoms1 = [ - create_test_atoms(num_atoms=5, seed=seeds[3] + i) for i in range(10) - ] - lmdb_atoms2 = [ - create_test_atoms(num_atoms=5, seed=seeds[4] + i) for i in range(10) - ] - create_lmdb_dataset(lmdb_atoms1, lmdb_folder1, head_name="head1") - 
create_lmdb_dataset(lmdb_atoms2, lmdb_folder2, head_name="head2") - - # Create config.yaml for training with proper format specification - config = { - "name": "multifile_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "head1": { - "train_file": [lmdb_folder1, xyz_file1], - "valid_file": xyz_file1, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - }, - "head2": { - "train_file": [h5_folder + "/train", xyz_file2], - "valid_file": xyz_file2, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, # Don't raise exception on non-zero exit, we'll check manually - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multifile_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator(model_paths=model_path, device="cpu", head="head1") - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_multiple_xyz_per_head(): - """Test training with multiple XYZ files per head for train, valid and test sets""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - create multiple xyz files for each dataset - train_xyz_files = [ - os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 4) - ] # 3 train files - valid_xyz_files = [ - os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 3) - ] # 2 valid files - test_xyz_files = [ - os.path.join(temp_dir, f"test_data{i}.xyz") 
for i in range(1, 3) - ] # 2 test files - - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create atoms for each train dataset - use different seeds for variety - train_datasets = [] - for i, file in enumerate(train_xyz_files): - # Create atoms with different seeds - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) - ] - create_xyz_file(atoms, file) - train_datasets.append(atoms) - - # Create atoms for validation datasets - valid_datasets = [] - for i, file in enumerate(valid_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - valid_datasets.append(atoms) - - # Create atoms for test datasets - test_datasets = [] - for i, file in enumerate(test_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - test_datasets.append(atoms) - - # Create config.yaml for training with multiple xyz files per dataset - config = { - "name": "multi_xyz_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "multi_xyz_head": { - # Using lists of multiple xyz files for each dataset - "train_file": train_xyz_files, - "valid_file": valid_xyz_files, - "test_file": test_xyz_files, - "energy_key": "energy", - "forces_key": "forces", - "stress_key": "stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), 
f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multi_xyz_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator( - model_paths=model_path, device="cpu", head="multi_xyz_head" - ) - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_single_xyz_per_head(): - """Test training with multiple XYZ files per head for train, valid and test sets""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - create multiple xyz files for each dataset - train_xyz_files = [ - os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 2) - ] # 3 train files - valid_xyz_files = [ - os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 2) - ] # 2 valid files - test_xyz_files = [ - os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 2) - ] # 2 test files - - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz 
file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create atoms for each train dataset - use different seeds for variety - train_datasets = [] - for i, file in enumerate(train_xyz_files): - # Create atoms with different seeds - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) - ] - create_xyz_file(atoms, file) - train_datasets.append(atoms) - - # Create atoms for validation datasets - valid_datasets = [] - for i, file in enumerate(valid_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - valid_datasets.append(atoms) - - # Create atoms for test datasets - test_datasets = [] - for i, file in enumerate(test_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - test_datasets.append(atoms) - - # Create config.yaml for training with multiple xyz files per dataset - config = { - "name": "multi_xyz_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "multi_xyz_head": { - # Using lists of multiple xyz files for each dataset - "train_file": train_xyz_files, - "valid_file": valid_xyz_files, - "test_file": test_xyz_files, - "energy_key": "energy", - "forces_key": "forces", - "stress_key": "stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multi_xyz_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator( - model_paths=model_path, device="cpu", head="multi_xyz_head" - ) - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert 
np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_multihead_finetuning_different_formats(): - """Test multihead finetuning with different file formats for each head.""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - xyz_file = os.path.join(temp_dir, "finetuning_xyz.xyz") - h5_folder = os.path.join(temp_dir, "h5_data") - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data with different seeds - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=3) - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - atom.info["REF_energy"] = energy # Make sure energy is in the right place - atom.arrays["REF_forces"] = forces - atom.info["REF_stress"] = stress - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create XYZ data for xyz_head - xyz_atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(30) - ] - # Add REF_ properties - for atom in xyz_atoms: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - atom.info["head"] = "xyz_head" # Assign head - create_xyz_file(xyz_atoms, xyz_file) - - # Create H5 data for h5_head - h5_atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(30) - ] - # Add REF_ properties - for atom in h5_atoms: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - atom.info["head"] = "h5_head" # Assign head - - h5_atoms_xyz = os.path.join(temp_dir, "h5_atoms.xyz") - create_xyz_file(h5_atoms, h5_atoms_xyz) - # Include isolated atoms for E0s in the h5 dataset - all_atoms_for_h5 = h5_atoms + isolated_atoms - all_atoms_h5_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") - create_xyz_file(all_atoms_for_h5, all_atoms_h5_xyz) - create_h5_dataset(all_atoms_h5_xyz, h5_folder) - - # Create config.yaml for multihead finetuning - heads = { - "xyz_head": { - "train_file": xyz_file, - "valid_fraction": 0.2, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "E0s": e0s_file, - }, - "h5_head": { - "train_file": os.path.join(h5_folder, "train"), - "valid_file": os.path.join(h5_folder, "val"), - "energy_key": "REF_energy", - "forces_key": 
"REF_forces", - "stress_key": "REF_stress", - "E0s": e0s_file, - }, - } - - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - - with open(config_path, "w", encoding="utf-8") as f: - f.write(yaml_str) - - # Now perform multihead finetuning - finetuning_params = { - "name": "multihead_finetuned", - "config": config_path, - "foundation_model": "small", # Use the small foundation model - "energy_weight": 1.0, - "forces_weight": 10.0, - "model": "MACE", - "hidden_irreps": "128x0e", # Match foundation model - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 2, # Just do a quick finetuning for test - "device": "cpu", - "seed": 42, - "loss": "weighted", - "default_dtype": "float64", - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "results_dir": results_dir, - "atomic_numbers": "[" + ",".join(map(str, z_table_elements)) + "]", - "multiheads_finetuning": True, - "filter_type_pt": "combinations", - "subselect_pt": "random", - "num_samples_pt": 10, # Small number for testing - "force_mh_ft_lr": True, # Force using specified learning rate - } - - # Run finetuning - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - cmd = [sys.executable, str(run_train_script)] - for k, v in finetuning_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Finetuning failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multihead_finetuned.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Load model and verify it has the expected heads - model = torch.load(model_path, map_location="cpu") - assert hasattr(model, "heads"), "Model does not have heads attribute" - assert set(["xyz_head", "h5_head", "pt_head"]).issubset( - set(model.heads) - ), "Expected heads not found in model" - - # Try to run the model with both heads - # For xyz_head - calc_xyz = MACECalculator( - model_paths=model_path, - device="cpu", - head="xyz_head", - default_dtype="float64", - ) - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc_xyz - energy_xyz = test_atom.get_potential_energy() - forces_xyz = test_atom.get_forces() - - # For h5_head - calc_h5 = MACECalculator( - model_paths=model_path, - device="cpu", - head="h5_head", - default_dtype="float64", - ) - test_atom.calc = calc_h5 - energy_h5 = test_atom.get_potential_energy() - forces_h5 = test_atom.get_forces() - - # Verify results - assert np.isfinite(energy_xyz), "xyz_head produced non-finite energy" - assert np.all(np.isfinite(forces_xyz)), "xyz_head produced non-finite forces" - assert np.isfinite(energy_h5), "h5_head produced non-finite energy" - assert np.all(np.isfinite(forces_h5)), "h5_head produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) +import json +import os 
+import shutil +import subprocess +import sys +import tempfile +import zlib +from pathlib import Path + +import lmdb +import numpy as np +import orjson +import pytest +import torch +import yaml +from ase.atoms import Atoms +from ase.calculators.singlepoint import SinglePointCalculator + +from mace.calculators import MACECalculator + + +def create_test_atoms(num_atoms=5, seed=42): + """Create random atoms for testing purposes with energy, forces, and stress.""" + # Set random seed for reproducibility + rng = np.random.RandomState(seed) + + # Create random positions + positions = rng.rand(num_atoms, 3) * 5.0 + + # Create random atomic numbers (H, C, N, O) + atomic_numbers = rng.choice([1, 6, 7, 8], size=num_atoms) + + # Create atoms object + atoms = Atoms( + numbers=atomic_numbers, + positions=positions, + cell=np.eye(3) * 10.0, # 10 Å periodic box + pbc=True, + ) + + # Add random energy, forces and stress + energy = float(rng.uniform(-15.0, -5.0)) + forces = rng.rand(num_atoms, 3) * 0.5 - 0.25 # Forces between -0.25 and 0.25 eV/Å + stress = rng.rand(6) * 0.2 - 0.1 # Stress tensor in Voigt notation + + # Add calculator to atoms with results + calc = SinglePointCalculator(atoms, energy=energy, forces=forces, stress=stress) + atoms.calc = calc + + # Mark isolated atoms with config_type + if num_atoms == 1: + atoms.info["config_type"] = "IsolatedAtom" + + return atoms + + +def create_xyz_file(atoms_list, filename): + """Write a list of atoms to an xyz file.""" + from ase.io import write + + write(filename, atoms_list, format="extxyz") + return filename + + +def create_e0s_file(e0s_dict, filename): + """Create an E0s JSON file with isolated atom energies.""" + # Convert keys to integers since MACE expects atomic numbers as integers + e0s_dict_int_keys = {int(k): v for k, v in e0s_dict.items()} + + with open(filename, "w", encoding="utf-8") as f: + json.dump(e0s_dict_int_keys, f) + return filename + + +def create_h5_dataset(xyz_file, output_dir, e0s_file=None, r_max=5.0, seed=42): + """ + Run MACE's preprocess_data.py script to convert an xyz file to h5 format. 
+ + Args: + xyz_file: Path to the input xyz file + output_dir: Directory to store the preprocessed h5 files + e0s_file: Path to the E0s file with isolated atom energies + r_max: Cutoff radius + seed: Random seed + + Returns: + The output directory containing the h5 files + """ + # Make sure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Find the path to the preprocess_data.py script + preprocess_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" + ) + + # Set up command to run preprocess_data.py + cmd = [ + sys.executable, + str(preprocess_script), + f"--train_file={xyz_file}", + f"--r_max={r_max}", + f"--h5_prefix={output_dir}/", + f"--seed={seed}", + "--compute_statistics", # Generate statistics file + "--num_process=2", # Create 2 files for testing sharded loading + ] + + # Add E0s file if provided + if e0s_file: + cmd.append(f"--E0s={e0s_file}") + + # Set up environment + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the script + print(f"Running preprocess command: {' '.join(cmd)}") + try: + process = subprocess.run( + cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + # Print output for debugging + print("Preprocess stdout:", process.stdout.decode()) + print("Preprocess stderr:", process.stderr.decode()) + except subprocess.CalledProcessError as e: + print("Preprocess failed with error:", e) + print("Stdout:", e.stdout.decode() if e.stdout else "") + print("Stderr:", e.stderr.decode() if e.stderr else "") + raise + + return output_dir + + +def create_lmdb_dataset(atoms_list, folder_path, head_name="Default"): + """Create an LMDB dataset from a list of atoms objects that MACE can read.""" + # Create the folder if it doesn't exist + os.makedirs(folder_path, exist_ok=True) + + # Create the LMDB database file + db_path = os.path.join(folder_path, "data.aselmdb") + + # Initialize LMDB environment + env = lmdb.open( + db_path, + map_size=1099511627776, # 1TB + subdir=False, + meminit=False, + map_async=True, + ) + + # Open a transaction + with env.begin(write=True) as txn: + # Store metadata + metadata = {"format_version": 1} + txn.put( + "metadata".encode("ascii"), + zlib.compress(orjson.dumps(metadata, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store nextid + nextid = len(atoms_list) + 1 + txn.put( + "nextid".encode("ascii"), + zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store deleted_ids (empty) + txn.put( + "deleted_ids".encode("ascii"), + zlib.compress(orjson.dumps([], option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store each atom + for i, atoms in enumerate(atoms_list): + id_num = i + 1 # Start from 1 + + # Convert atoms to dictionary + positions = atoms.get_positions() + cell = atoms.get_cell() + + # Create a dictionary with all necessary fields + dct = { + "numbers": atoms.get_atomic_numbers().tolist(), + "positions": positions.tolist(), + "cell": cell.tolist(), + "pbc": atoms.get_pbc().tolist(), + "ctime": 0.0, # Creation time + "mtime": 0.0, # Modification time + "user": "test", + "energy": atoms.calc.results["energy"], + "forces": atoms.calc.results["forces"].tolist(), + "stress": atoms.calc.results["stress"].tolist(), + "key_value_pairs": { + "config_type": atoms.info.get("config_type", "Default"), + "head": head_name, + }, + } + + # Store the atom in LMDB + txn.put( + f"{id_num}".encode("ascii"), + zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), + ) 
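+
+        # Note: this hand-rolled writer mimics the "aselmdb" key layout the
+        # reader side appears to expect -- zlib-compressed orjson payloads
+        # stored under "metadata", "nextid", "deleted_ids", plus one 1-based
+        # integer key per stored structure. The exact schema used here is an
+        # assumption made for test fixtures, not a documented public format.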
+ + # Close the environment + env.close() + + return folder_path + + +@pytest.mark.slow +def test_multifile_training(): + """Test training with multiple file formats per head""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths + xyz_file1 = os.path.join(temp_dir, "data1.xyz") + xyz_file2 = os.path.join(temp_dir, "data2.xyz") + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + h5_folder = os.path.join(temp_dir, "h5_data") + lmdb_folder1 = os.path.join( + temp_dir, "lmdb_data1_lmdb" + ) # Add _lmdb suffix for LMDB recognition + lmdb_folder2 = os.path.join( + temp_dir, "lmdb_data2_lmdb" + ) # Add _lmdb suffix for LMDB recognition + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=5) + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + atom.info["REF_energy"] = energy # Make sure energy is in the right place + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create 10 atoms for each dataset + xyz_atoms1 = [ + create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(10) + ] + xyz_atoms2 = [ + create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(10) + ] + + # Create h5 data directly - first convert the xyz file to a format with REF_ keys + for atom in xyz_atoms1: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + + for atom in xyz_atoms2: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + + # Save isolated atoms to xyz files first, then create the h5 datasets + create_xyz_file(xyz_atoms1, xyz_file1) + create_xyz_file(xyz_atoms2, xyz_file2) + + # Create h5 data from xyz file, using both isolated atoms and real data + all_atoms_for_h5 = isolated_atoms + xyz_atoms2 + all_atoms_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") + create_xyz_file(all_atoms_for_h5, all_atoms_xyz) + create_h5_dataset(all_atoms_xyz, h5_folder) + + # Create LMDB datasets + lmdb_atoms1 = [ + create_test_atoms(num_atoms=5, seed=seeds[3] + i) for i in range(10) + ] + lmdb_atoms2 = [ + create_test_atoms(num_atoms=5, seed=seeds[4] + i) for i in range(10) + ] + create_lmdb_dataset(lmdb_atoms1, lmdb_folder1, head_name="head1") + 
create_lmdb_dataset(lmdb_atoms2, lmdb_folder2, head_name="head2") + + # Create config.yaml for training with proper format specification + config = { + "name": "multifile_test", + "seed": 42, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 5, + "max_num_epochs": 2, + "patience": 5, + "device": "cpu", + "energy_weight": 1.0, + "forces_weight": 10.0, + "loss": "weighted", + "optimizer": "adam", + "default_dtype": "float64", + "lr": 0.01, + "swa": False, + "work_dir": temp_dir, + "results_dir": results_dir, + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "E0s": e0s_file, + "atomic_numbers": str(z_table_elements), + "heads": { + "head1": { + "train_file": [lmdb_folder1, xyz_file1], + "valid_file": xyz_file1, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + }, + "head2": { + "train_file": [h5_folder + "/train", xyz_file2], + "valid_file": xyz_file2, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + }, + }, + } + + # Write config file + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f) + + # Import the modified run_train from our local module + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + + # Run training with subprocess + cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] + + # Set environment to add the current path to PYTHONPATH + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, # Don't raise exception on non-zero exit, we'll check manually + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Training failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multifile_test.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Try to load and run the model + model = torch.load(model_path, map_location="cpu") + assert model is not None, "Failed to load model" + + # Create a calculator + calc = MACECalculator(model_paths=model_path, device="cpu", head="head1") + + # Run prediction on a test atom + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc + energy = test_atom.get_potential_energy() + forces = test_atom.get_forces() + + # Assert we got sensible outputs + assert np.isfinite(energy), "Model produced non-finite energy" + assert np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_multiple_xyz_per_head(): + """Test training with multiple XYZ files per head for train, valid and test sets""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths - create multiple xyz files for each dataset + train_xyz_files = [ + os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 4) + ] # 3 train files + valid_xyz_files = [ + os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 3) + ] # 2 valid files + test_xyz_files = [ + os.path.join(temp_dir, f"test_data{i}.xyz") 
for i in range(1, 3) + ] # 2 test files + + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create atoms for each train dataset - use different seeds for variety + train_datasets = [] + for i, file in enumerate(train_xyz_files): + # Create atoms with different seeds + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) + ] + create_xyz_file(atoms, file) + train_datasets.append(atoms) + + # Create atoms for validation datasets + valid_datasets = [] + for i, file in enumerate(valid_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + valid_datasets.append(atoms) + + # Create atoms for test datasets + test_datasets = [] + for i, file in enumerate(test_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + test_datasets.append(atoms) + + # Create config.yaml for training with multiple xyz files per dataset + config = { + "name": "multi_xyz_test", + "seed": 42, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 5, + "max_num_epochs": 2, + "patience": 5, + "device": "cpu", + "energy_weight": 1.0, + "forces_weight": 10.0, + "loss": "weighted", + "optimizer": "adam", + "default_dtype": "float64", + "lr": 0.01, + "swa": False, + "work_dir": temp_dir, + "results_dir": results_dir, + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "E0s": e0s_file, + "atomic_numbers": str(z_table_elements), + "heads": { + "multi_xyz_head": { + # Using lists of multiple xyz files for each dataset + "train_file": train_xyz_files, + "valid_file": valid_xyz_files, + "test_file": test_xyz_files, + "energy_key": "energy", + "forces_key": "forces", + "stress_key": "stress", + }, + }, + } + + # Write config file + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f) + + # Import the modified run_train from our local module + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + + # Run training with subprocess + cmd = [sys.executable, str(run_train_script), 
f"--config={config_path}"] + + # Set environment to add the current path to PYTHONPATH + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Training failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multi_xyz_test.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Try to load and run the model + model = torch.load(model_path, map_location="cpu") + assert model is not None, "Failed to load model" + + # Create a calculator + calc = MACECalculator( + model_paths=model_path, device="cpu", head="multi_xyz_head" + ) + + # Run prediction on a test atom + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc + energy = test_atom.get_potential_energy() + forces = test_atom.get_forces() + + # Assert we got sensible outputs + assert np.isfinite(energy), "Model produced non-finite energy" + assert np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_single_xyz_per_head(): + """Test training with multiple XYZ files per head for train, valid and test sets""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths - create multiple xyz files for each dataset + train_xyz_files = [ + os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 2) + ] # 3 train files + valid_xyz_files = [ + os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 2) + ] # 2 valid files + test_xyz_files = [ + os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 2) + ] # 2 test files + + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz 
+
+        # Locate the run_train CLI script from the local checkout
+        run_train_script = (
+            Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"
+        )
+
+        # Run training with subprocess
+        cmd = [sys.executable, str(run_train_script), f"--config={config_path}"]
+
+        # Set environment to add the current path to PYTHONPATH
+        env = os.environ.copy()
+        env["PYTHONPATH"] = (
+            str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "")
+        )
+
+        # Run the process
+        process = subprocess.run(
+            cmd,
+            env=env,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            check=False,
+        )
+
+        # Print output for debugging
+        print("\n" + "=" * 40 + " STDOUT " + "=" * 40)
+        print(process.stdout.decode())
+        print("\n" + "=" * 40 + " STDERR " + "=" * 40)
+        print(process.stderr.decode())
+
+        # Check that process completed successfully
+        assert (
+            process.returncode == 0
+        ), f"Training failed with error: {process.stderr.decode()}"
+
+        # Check that model was created
+        model_path = os.path.join(model_dir, "single_xyz_test.model")
+        assert os.path.exists(model_path), f"Model was not created at {model_path}"
+
+        # Try to load and run the model
+        model = torch.load(model_path, map_location="cpu")
+        assert model is not None, "Failed to load model"
+
+        # Create a calculator
+        calc = MACECalculator(
+            model_paths=model_path, device="cpu", head="single_xyz_head"
+        )
+
+        # Run prediction on a test atom
+        test_atom = create_test_atoms(num_atoms=5, seed=99999)
+        test_atom.calc = calc
+        energy = test_atom.get_potential_energy()
+        forces = test_atom.get_forces()
+
+        # Assert we got sensible outputs
+        assert np.isfinite(energy), "Model produced non-finite energy"
+        assert
np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_multihead_finetuning_different_formats(): + """Test multihead finetuning with different file formats for each head.""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths + xyz_file = os.path.join(temp_dir, "finetuning_xyz.xyz") + h5_folder = os.path.join(temp_dir, "h5_data") + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data with different seeds + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=3) + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + atom.info["REF_energy"] = energy # Make sure energy is in the right place + atom.arrays["REF_forces"] = forces + atom.info["REF_stress"] = stress + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create XYZ data for xyz_head + xyz_atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(30) + ] + # Add REF_ properties + for atom in xyz_atoms: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + atom.info["head"] = "xyz_head" # Assign head + create_xyz_file(xyz_atoms, xyz_file) + + # Create H5 data for h5_head + h5_atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(30) + ] + # Add REF_ properties + for atom in h5_atoms: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + atom.info["head"] = "h5_head" # Assign head + + h5_atoms_xyz = os.path.join(temp_dir, "h5_atoms.xyz") + create_xyz_file(h5_atoms, h5_atoms_xyz) + # Include isolated atoms for E0s in the h5 dataset + all_atoms_for_h5 = h5_atoms + isolated_atoms + all_atoms_h5_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") + create_xyz_file(all_atoms_for_h5, all_atoms_h5_xyz) + create_h5_dataset(all_atoms_h5_xyz, h5_folder) + + # Create config.yaml for multihead finetuning + heads = { + "xyz_head": { + "train_file": xyz_file, + "valid_fraction": 0.2, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "E0s": e0s_file, + }, + "h5_head": { + "train_file": os.path.join(h5_folder, "train"), + "valid_file": os.path.join(h5_folder, "val"), + "energy_key": "REF_energy", + "forces_key": 
"REF_forces", + "stress_key": "REF_stress", + "E0s": e0s_file, + }, + } + + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + + with open(config_path, "w", encoding="utf-8") as f: + f.write(yaml_str) + + # Now perform multihead finetuning + finetuning_params = { + "name": "multihead_finetuned", + "config": config_path, + "foundation_model": "small", # Use the small foundation model + "energy_weight": 1.0, + "forces_weight": 10.0, + "model": "MACE", + "hidden_irreps": "128x0e", # Match foundation model + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 2, # Just do a quick finetuning for test + "device": "cpu", + "seed": 42, + "loss": "weighted", + "default_dtype": "float64", + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "results_dir": results_dir, + "atomic_numbers": "[" + ",".join(map(str, z_table_elements)) + "]", + "multiheads_finetuning": True, + "filter_type_pt": "combinations", + "subselect_pt": "random", + "num_samples_pt": 10, # Small number for testing + "force_mh_ft_lr": True, # Force using specified learning rate + } + + # Run finetuning + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + cmd = [sys.executable, str(run_train_script)] + for k, v in finetuning_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Finetuning failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multihead_finetuned.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Load model and verify it has the expected heads + model = torch.load(model_path, map_location="cpu") + assert hasattr(model, "heads"), "Model does not have heads attribute" + assert set(["xyz_head", "h5_head", "pt_head"]).issubset( + set(model.heads) + ), "Expected heads not found in model" + + # Try to run the model with both heads + # For xyz_head + calc_xyz = MACECalculator( + model_paths=model_path, + device="cpu", + head="xyz_head", + default_dtype="float64", + ) + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc_xyz + energy_xyz = test_atom.get_potential_energy() + forces_xyz = test_atom.get_forces() + + # For h5_head + calc_h5 = MACECalculator( + model_paths=model_path, + device="cpu", + head="h5_head", + default_dtype="float64", + ) + test_atom.calc = calc_h5 + energy_h5 = test_atom.get_potential_energy() + forces_h5 = test_atom.get_forces() + + # Verify results + assert np.isfinite(energy_xyz), "xyz_head produced non-finite energy" + assert np.all(np.isfinite(forces_xyz)), "xyz_head produced non-finite forces" + assert np.isfinite(energy_h5), "h5_head produced non-finite energy" + assert np.all(np.isfinite(forces_h5)), "h5_head produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) diff --git 
a/mace-bench/3rdparty/mace/tests/test_preprocess.py b/mace-bench/3rdparty/mace/tests/test_preprocess.py index f976ff1c035e45ba803efbf988ee195fc535407d..5b070e5df37528e43e799ef11fe16666b221ad2a 100644 --- a/mace-bench/3rdparty/mace/tests/test_preprocess.py +++ b/mace-bench/3rdparty/mace/tests/test_preprocess.py @@ -1,206 +1,206 @@ -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import yaml -from ase.atoms import Atoms - -pytest_mace_dir = Path(__file__).parent.parent -preprocess_data = Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" - - -@pytest.fixture(name="sample_configs") -def fixture_sample_configs(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - configs[0].info["REF_energy"] = 0.0 - configs[0].info["config_type"] = "IsolatedAtom" - configs[1].info["REF_energy"] = 0.0 - configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(10): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - configs.append(c) - - return configs - - -def test_preprocess_data(tmp_path, sample_configs): - ase.io.write(tmp_path / "sample.xyz", sample_configs) - - preprocess_params = { - "train_file": tmp_path / "sample.xyz", - "r_max": 5.0, - "config_type_weights": "{'Default':1.0}", - "num_process": 2, - "valid_fraction": 0.1, - "h5_prefix": tmp_path / "preprocessed_", - "compute_statistics": None, - "seed": 42, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - } - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(preprocess_data) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in preprocess_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Check if the output files are created - assert (tmp_path / "preprocessed_train").is_dir() - assert (tmp_path / "preprocessed_val").is_dir() - assert (tmp_path / "preprocessed_statistics.json").is_file() - - # Check if the correct number of files are created - train_files = list((tmp_path / "preprocessed_train").glob("*.h5")) - val_files = list((tmp_path / "preprocessed_val").glob("*.h5")) - assert len(train_files) == preprocess_params["num_process"] - assert len(val_files) == preprocess_params["num_process"] - - # Example of checking statistics file content: - import json - - with open(tmp_path / "preprocessed_statistics.json", "r", encoding="utf-8") as f: - statistics = json.load(f) - assert "atomic_energies" in statistics - assert "avg_num_neighbors" in statistics - assert "mean" in statistics - assert "std" in statistics - assert "atomic_numbers" in statistics - assert "r_max" in statistics - - # Example of checking H5 file content: - import h5py - - with h5py.File(train_files[0], "r") as f: - assert "config_batch_0" in f - config = f["config_batch_0"]["config_0"] - assert "atomic_numbers" in config - 
assert "positions" in config - assert "energy" in config["properties"] - assert "forces" in config["properties"] - - original_energies = [ - config.info["REF_energy"] - for config in sample_configs[2:] - if "REF_energy" in config.info - ] - original_forces = [ - config.arrays["REF_forces"] - for config in sample_configs[2:] - if "REF_forces" in config.arrays - ] - - h5_energies = [] - h5_forces = [] - - for train_file in train_files: - with h5py.File(train_file, "r") as f: - for _, batch in f.items(): - for config_key in batch.keys(): - config = batch[config_key] - assert "atomic_numbers" in config - assert "positions" in config - assert "energy" in config["properties"] - assert "forces" in config["properties"] - - h5_energies.append(config["properties"]["energy"][()]) - h5_forces.append(config["properties"]["forces"][()]) - - for val_file in val_files: - with h5py.File(val_file, "r") as f: - for _, batch in f.items(): - for config_key in batch.keys(): - config = batch[config_key] - h5_energies.append(config["properties"]["energy"][()]) - h5_forces.append(config["properties"]["forces"][()]) - - print("Original energies", original_energies) - print("H5 energies", h5_energies) - print("Original forces", original_forces) - print("H5 forces", h5_forces) - original_energies.sort() - h5_energies.sort() - original_forces = np.concatenate(original_forces).flatten() - h5_forces = np.concatenate(h5_forces).flatten() - original_forces.sort() - h5_forces.sort() - - # Compare energies and forces - np.testing.assert_allclose(original_energies, h5_energies, rtol=1e-5, atol=1e-8) - np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) - - print("All checks passed successfully!") - - -def test_preprocess_config(tmp_path, sample_configs): - ase.io.write(tmp_path / "sample.xyz", sample_configs) - - preprocess_params = { - "train_file": str(tmp_path / "sample.xyz"), - "r_max": 5.0, - "config_type_weights": "{'Default':1.0}", - "num_process": 2, - "valid_fraction": 0.1, - "h5_prefix": str(tmp_path / "preprocessed_"), - "compute_statistics": None, - "seed": 42, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - } - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - yaml.dump(preprocess_params, file) - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(preprocess_data) - + " " - + "--config" - + " " - + str(filename) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import yaml +from ase.atoms import Atoms + +pytest_mace_dir = Path(__file__).parent.parent +preprocess_data = Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" + + +@pytest.fixture(name="sample_configs") +def fixture_sample_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + configs[0].info["REF_energy"] = 0.0 + configs[0].info["config_type"] = "IsolatedAtom" + configs[1].info["REF_energy"] = 0.0 + configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) 
+ for _ in range(10): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + configs.append(c) + + return configs + + +def test_preprocess_data(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": tmp_path / "sample.xyz", + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": tmp_path / "preprocessed_", + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in preprocess_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Check if the output files are created + assert (tmp_path / "preprocessed_train").is_dir() + assert (tmp_path / "preprocessed_val").is_dir() + assert (tmp_path / "preprocessed_statistics.json").is_file() + + # Check if the correct number of files are created + train_files = list((tmp_path / "preprocessed_train").glob("*.h5")) + val_files = list((tmp_path / "preprocessed_val").glob("*.h5")) + assert len(train_files) == preprocess_params["num_process"] + assert len(val_files) == preprocess_params["num_process"] + + # Example of checking statistics file content: + import json + + with open(tmp_path / "preprocessed_statistics.json", "r", encoding="utf-8") as f: + statistics = json.load(f) + assert "atomic_energies" in statistics + assert "avg_num_neighbors" in statistics + assert "mean" in statistics + assert "std" in statistics + assert "atomic_numbers" in statistics + assert "r_max" in statistics + + # Example of checking H5 file content: + import h5py + + with h5py.File(train_files[0], "r") as f: + assert "config_batch_0" in f + config = f["config_batch_0"]["config_0"] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] + + original_energies = [ + config.info["REF_energy"] + for config in sample_configs[2:] + if "REF_energy" in config.info + ] + original_forces = [ + config.arrays["REF_forces"] + for config in sample_configs[2:] + if "REF_forces" in config.arrays + ] + + h5_energies = [] + h5_forces = [] + + for train_file in train_files: + with h5py.File(train_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] + + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) + + for val_file in val_files: + with h5py.File(val_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) + + print("Original energies", original_energies) + print("H5 
energies", h5_energies) + print("Original forces", original_forces) + print("H5 forces", h5_forces) + original_energies.sort() + h5_energies.sort() + original_forces = np.concatenate(original_forces).flatten() + h5_forces = np.concatenate(h5_forces).flatten() + original_forces.sort() + h5_forces.sort() + + # Compare energies and forces + np.testing.assert_allclose(original_energies, h5_energies, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) + + print("All checks passed successfully!") + + +def test_preprocess_config(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": str(tmp_path / "sample.xyz"), + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": str(tmp_path / "preprocessed_"), + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + yaml.dump(preprocess_params, file) + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + "--config" + + " " + + str(filename) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 diff --git a/mace-bench/3rdparty/mace/tests/test_run_train.py b/mace-bench/3rdparty/mace/tests/test_run_train.py index 921d12e094eb89b775f246ded36c733ad12ea6cf..ddb849686ba3bb0a1085c2a44583c58e978d079a 100644 --- a/mace-bench/3rdparty/mace/tests/test_run_train.py +++ b/mace-bench/3rdparty/mace/tests/test_run_train.py @@ -1,1458 +1,1458 @@ -import json -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import torch -from ase.atoms import Atoms - -from mace.calculators import MACECalculator, mace_mp - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -@pytest.fixture(name="fitting_configs") -def fixture_fitting_configs(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - fit_configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - fit_configs[0].info["REF_energy"] = 0.0 - fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = 0.0 - fit_configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(20): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - print(c.info["REF_energy"]) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - fit_configs.append(c) - - return fit_configs - - -@pytest.fixture(name="pretraining_configs") -def fixture_pretraining_configs(): - configs = [] - for _ in range(10): - atoms = Atoms( - numbers=[8, 1, 1], - positions=np.random.rand(3, 3) * 3, - cell=[5, 5, 5], - pbc=[True] * 3, - ) - atoms.info["REF_energy"] = np.random.normal(0, 1) - 
atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) - atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) - configs.append(atoms) - configs.append( - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), - ) - configs.append( - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) - ) - configs[-2].info["REF_energy"] = -2.0 - configs[-2].info["config_type"] = "IsolatedAtom" - configs[-1].info["REF_energy"] = -4.0 - configs[-1].info["config_type"] = "IsolatedAtom" - return configs - - -_mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, -} - - -def test_run_train(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - 0.0, - 0.0, - -0.039181344585828524, - -0.0915223395136733, - -0.14953484236456582, - -0.06662480820063998, - -0.09983737353050133, - 0.12477442296789745, - -0.06486086271762856, - -0.1460607988519944, - 0.12886334908465508, - -0.14000990081920373, - -0.05319886578958313, - 0.07780520158391, - -0.08895480281886901, - -0.15474719614734422, - 0.007756765146527644, - -0.044879267197498685, - -0.036065736712447574, - -0.24413743841886623, - -0.0838104612106429, - -0.14751978636626545, - ] - - assert np.allclose(Es, ref_Es) - - -def test_run_train_missing_data(tmp_path, fitting_configs): - del fitting_configs[5].info["REF_energy"] - del fitting_configs[6].arrays["REF_forces"] - del fitting_configs[7].info["REF_stress"] - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for 
k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - 0.0, - 0.0, - -0.05464025113696155, - -0.11272131295940478, - 0.039200919331076826, - -0.07517990972827505, - -0.13504202474582666, - 0.0292022872055344, - -0.06541099574579018, - -0.1497824717832886, - 0.19397709360828813, - -0.13587609467143014, - -0.05242956276828463, - -0.0504862057364953, - -0.07095795959430119, - -0.2463753796753703, - -0.002031543147676121, - -0.03864918790300681, - -0.13680153117705554, - -0.23418951968636786, - -0.11790833839379238, - -0.14930562311066484, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_no_stress(tmp_path, fitting_configs): - del fitting_configs[5].info["REF_energy"] - del fitting_configs[6].arrays["REF_forces"] - del fitting_configs[7].info["REF_stress"] - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["loss"] = "weighted" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 - ref_Es = [ - 0.0, - 0.0, - -0.05450093218377135, - -0.11235475232750518, - 0.03914558031854152, - -0.07500839914816063, - -0.13469160624431492, - 0.029384214243251838, - -0.06521819204166135, - -0.14944896282001804, - 0.19413948083049481, - -0.13543541860473626, - -0.05235495076237124, - -0.049556206595684105, - -0.07080758913030646, - -0.24571898386301153, - -0.002070636306950905, - -0.03863113401320783, - -0.13620291339913712, - -0.23383074855679695, - -0.11776449630199368, - -0.1489441490225184, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - fitting_configs_ccd = [] - for _, c in enumerate(fitting_configs): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - - c_ccd = c.copy() - c_ccd.info["head"] = "CCD" - fitting_configs_ccd.append(c_ccd) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) - - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": 
f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["loss"] = "weighted" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["config"] = tmp_path / "config.yaml" - mace_params["batch_size"] = 2 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head="CCD", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 02/09/2024 on develop branch - ref_Es = [ - 0.0, - 0.0, - 0.10637113905361611, - -0.012499594026624754, - 0.08983077108171753, - 0.21071322543112597, - -0.028921849222784398, - -0.02423359575741567, - 0.022923252188079057, - -0.02048334610058991, - 0.4349711162741364, - -0.04455577015569887, - -0.09765806785570091, - 0.16013134616829822, - 0.0758442928017698, - -0.05931856557011721, - 0.33964473532953265, - 0.134338442158641, - 0.18024119757783053, - -0.18914740992058765, - -0.06503477155294624, - 0.03436649147415213, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_foundation(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["multiheads_finetuning"] = False - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - - Es = [] - for at in fitting_configs: - 
at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 - ref_Es = [ - 1.6780993938446045, - 0.8916864395141602, - 0.7290308475494385, - 0.6194742918014526, - 0.6697757840156555, - 0.7025266289710999, - 0.5818213224411011, - 0.7897703647613525, - 0.6558921337127686, - 0.5071806907653809, - 3.581131935119629, - 0.691562294960022, - 0.6257331967353821, - 0.9560437202453613, - 0.7716934680938721, - 0.6730310916900635, - 0.8297463655471802, - 0.8053972721099854, - 0.8337507247924805, - 0.4107491970062256, - 0.6019601821899414, - 0.7301387786865234, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_foundation_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - fitting_configs_dft.append(c) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - elif i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - 
print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - Es = [] - for at in fitting_configs: - config_head = at.info.get("head", "MP2") - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head=config_head, - ) - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - - if i in (0, 1): - continue # skip isolated atoms, as energies specified by json files below - if i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # write E0s to json files - E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - - heads = { - "DFT": { - "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", - }, - "MP2": { - "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", - }, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - 
sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - Es = [] - for at in fitting_configs: - config_head = at.info.get("head", "MP2") - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head=config_head, - ) - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_multihead_replay_custum_finetuning( - tmp_path, fitting_configs, pretraining_configs -): - ase.io.write(tmp_path / "pretrain.xyz", pretraining_configs) - - foundation_params = { - "name": "foundation", - "train_file": os.path.join(tmp_path, "pretrain.xyz"), - "valid_fraction": 0.2, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 5, - "swa": None, - "start_swa": 3, - "device": "cpu", - "seed": 42, - "loss": "weighted", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "default_dtype": "float64", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - } - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - cmd = [sys.executable, str(run_train)] - for k, v in foundation_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - p = subprocess.run(cmd, env=run_env, check=True) - assert p.returncode == 0 - - # Step 3: Create finetuning set - fitting_configs_dft = [] - fitting_configs_mp2 = [] - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - fitting_configs_dft.append(c) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - elif i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # Step 4: Finetune the pretrained model with multihead replay - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" 
{key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - - finetuning_params = { - "name": "finetuned", - "valid_fraction": 0.1, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 5, - "device": "cpu", - "seed": 42, - "loss": "weighted", - "default_dtype": "float64", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "foundation_model": os.path.join(tmp_path, "foundation.model"), - "config": os.path.join(tmp_path, "config.yaml"), - "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), - "num_samples_pt": 3, - "subselect_pt": "random", - "force_mh_ft_lr": True, - } - - cmd = [sys.executable, str(run_train)] - for k, v in finetuning_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - p = subprocess.run(cmd, env=run_env, check=True) - assert p.returncode == 0 - - # Load and test the finetuned model - calc = MACECalculator( - model_paths=tmp_path / "finetuned.model", - device="cpu", - default_dtype="float64", - head="pt_head", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Energies:", Es) - - # Add some basic checks - assert len(Es) == len(fitting_configs) - assert all(isinstance(E, float) for E in Es) - assert len(set(Es)) > 1 # Ens - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_run_train_cueq(tmp_path, fitting_configs): - torch.set_default_dtype(torch.float64) - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["enable_cueq"] = True - mace_params["default_dtype"] = "float64" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cuda") - Es = [] - for at in fitting_configs[2:]: - at.calc = calc - Es.append(at.get_potential_energy()) - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True - ) - Es_cueq = [] - for at in fitting_configs[2:]: - at.calc = calc - Es_cueq.append(at.get_potential_energy()) - - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - -0.039181344585828524, - -0.0915223395136733, - -0.14953484236456582, - -0.06662480820063998, - -0.09983737353050133, - 
0.12477442296789745, - -0.06486086271762856, - -0.1460607988519944, - 0.12886334908465508, - -0.14000990081920373, - -0.05319886578958313, - 0.07780520158391, - -0.08895480281886901, - -0.15474719614734422, - 0.007756765146527644, - -0.044879267197498685, - -0.036065736712447574, - -0.24413743841886623, - -0.0838104612106429, - -0.14751978636626545, - ] - - assert np.allclose(Es, ref_Es) - assert np.allclose(ref_Es, Es_cueq) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - - if i in (0, 1): - continue # skip isolated atoms, as energies specified by json files below - if i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # write E0s to json files - E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - - heads = { - "DFT": { - "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", - }, - "MP2": { - "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", - }, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["enable_cueq"] = True - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["device"] = "cuda" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except 
subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cuda", - default_dtype="float64", - head="DFT", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_lbfgs(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["lbfgs"] = None - mace_params["max_num_epochs"] = 2 - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 14/03/2025 - ref_Es = [ - 0.0, - 0.0, - -0.1874197850340979, - -0.25991775038059006, - 0.18263492399322268, - -0.15026829765490662, - -0.2403061362015996, - 0.1689257170630718, - -0.2095568077455055, - -0.2957758160829075, - -0.0035370913684985364, - -0.2195416610745775, - -0.25405549447739517, - -0.06201390990366806, - -0.13332219494388334, - -0.19633181702040337, - 0.013014932630445699, - -0.08808335967147174, - -0.06664444189210728, - -0.4230467426992034, - -0.2348250569553676, - -0.17593904833220647, - ] - assert np.allclose(Es, ref_Es, atol=1e-2) - - -def test_run_train_foundation_elements(tmp_path, fitting_configs): - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - base_params = { - "name": "MACE", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "train_file": tmp_path / "fit.xyz", - "loss": "weighted", - "foundation_model": "small", - "hidden_irreps": "128x0e", - "r_max": 6.0, - "default_dtype": "float64", - "max_num_epochs": 5, - "num_radial_basis": 10, - "interaction_first": "RealAgnosticResidualInteractionBlock", - "multiheads_finetuning": False, - } - - # Run environment setup - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - # First run: without foundation_model_elements (default behavior) - mace_params = base_params.copy() 
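
Aside: every test in this file flattens its parameter dict into a single command string using the idiom (f"--{k}={v}" if v is not None else f"--{k}"), where a None value marks a bare flag such as --swa or --ema. A minimal self-contained sketch of that convention follows; the params values here are illustrative, not taken from the suite.

import sys

def flatten_params(params: dict) -> list:
    # None-valued entries become bare flags; everything else becomes --key=value,
    # mirroring the inline comprehension used throughout these tests.
    return [f"--{k}" if v is None else f"--{k}={v}" for k, v in params.items()]

params = {"max_num_epochs": 10, "swa": None, "device": "cpu"}  # illustrative only
argv = [sys.executable, "run_train.py"] + flatten_params(params)
# argv -> [..., 'run_train.py', '--max_num_epochs=10', '--swa', '--device=cpu']

Note that the string-plus-cmd.split() variant used in most of these tests breaks if a value contains spaces; the tests that instead build cmd as a list and pass it directly to subprocess.run avoid that pitfall.
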
- cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") - filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) - assert filtered_elements == {1, 8} # Only H and O should be present - - # Second run: with foundation_model_elements - mace_params = base_params.copy() - mace_params["name"] = "MACE_all_elements" - mace_params["foundation_model_elements"] = True # Flag-only argument - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") - all_elements = set(int(z) for z in model_all.atomic_numbers) - - # Get elements from foundation model for comparison - calc = mace_mp(model="small", device="cpu") - foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) - - # Check that all foundation model elements are preserved - assert all_elements == foundation_elements - assert len(all_elements) > len(filtered_elements) - - # Check that both models can make predictions - at = fitting_configs[2].copy() - - # Test filtered model - calc_filtered = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - at.calc = calc_filtered - e1 = at.get_potential_energy() - - # Test all-elements model - calc_all = MACECalculator( - model_paths=tmp_path / "MACE_all_elements.model", - device="cpu", - default_dtype="float64", - ) - at.calc = calc_all - e2 = at.get_potential_energy() - - # Energies should be different since the models are trained differently, - # but both should give reasonable results - assert np.isfinite(e1) - assert np.isfinite(e2) - - -def test_run_train_foundation_elements_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - if i % 2 == 0: - c_copy = c.copy() - c_copy.info["head"] = "DFT" - fitting_configs_dft.append(c_copy) - else: - c_copy = c.copy() - c_copy.info["head"] = "MP2" - fitting_configs_mp2.append(c_copy) - - ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2) - - # Create multihead configuration - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - config_file = tmp_path / "config.yaml" - with open(config_file, "w", encoding="utf-8") as file: - file.write(yaml_str) - - base_params = { - "name": "MACE", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "config": str(config_file), - "loss": 
"weighted", - "foundation_model": "small", - "hidden_irreps": "128x0e", - "r_max": 6.0, - "default_dtype": "float64", - "max_num_epochs": 5, - "num_radial_basis": 10, - "interaction_first": "RealAgnosticResidualInteractionBlock", - "force_mh_ft_lr": True, - "batch_size": 1, - "num_samples_pt": 50, - "subselect_pt": "random", - "atomic_numbers": "[" + ",".join(map(str, atomic_numbers)) + "]", - "filter_type_pt": "combinations", - "valid_fraction": 0.1, - "valid_batch_size": 1, - } - - # Run environment setup - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - # First run: without foundation_model_elements (default behavior) - mace_params = base_params.copy() - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - # Load model and check elements - model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") - filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) - assert filtered_elements == {1, 8} # Only H and O should be present - assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2 - - # Second run: with foundation_model_elements - mace_params = base_params.copy() - mace_params["name"] = "MACE_all_elements" - mace_params["foundation_model_elements"] = True - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") - all_elements = set(int(z) for z in model_all.atomic_numbers) - - # Get elements from foundation model for comparison - calc = mace_mp(model="small", device="cpu") - foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) - - # Check that all foundation model elements are preserved - assert all_elements == foundation_elements - assert len(all_elements) > len(filtered_elements) - assert len(model_all.heads) == 3 # pt_head + DFT + MP2 - - # Check that both models can make predictions - at = fitting_configs_dft[2].copy() - - # Test filtered model - calc_filtered = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head="DFT", - ) - at.calc = calc_filtered - e1 = at.get_potential_energy() - - # Test all-elements model - calc_all = MACECalculator( - model_paths=tmp_path / "MACE_all_elements.model", - device="cpu", - default_dtype="float64", - head="DFT", - ) - at.calc = calc_all - e2 = at.get_potential_energy() - - assert np.isfinite(e1) - assert np.isfinite(e2) +import json +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase.atoms import Atoms + +from mace.calculators import MACECalculator, mace_mp + +try: + import cuequivariance as cue 
# pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(name="fitting_configs") +def fixture_fitting_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + print(c.info["REF_energy"]) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(name="pretraining_configs") +def fixture_pretraining_configs(): + configs = [] + for _ in range(10): + atoms = Atoms( + numbers=[8, 1, 1], + positions=np.random.rand(3, 3) * 3, + cell=[5, 5, 5], + pbc=[True] * 3, + ) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) + atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) + configs.append(atoms) + configs.append( + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), + ) + configs.append( + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) + ) + configs[-2].info["REF_energy"] = -2.0 + configs[-2].info["config_type"] = "IsolatedAtom" + configs[-1].info["REF_energy"] = -4.0 + configs[-1].info["config_type"] = "IsolatedAtom" + return configs + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, +} + + +def test_run_train(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 
967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) + + +def test_run_train_missing_data(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.05464025113696155, + -0.11272131295940478, + 0.039200919331076826, + -0.07517990972827505, + -0.13504202474582666, + 0.0292022872055344, + -0.06541099574579018, + -0.1497824717832886, + 0.19397709360828813, + -0.13587609467143014, + -0.05242956276828463, + -0.0504862057364953, + -0.07095795959430119, + -0.2463753796753703, + -0.002031543147676121, + -0.03864918790300681, + -0.13680153117705554, + -0.23418951968636786, + -0.11790833839379238, + -0.14930562311066484, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_no_stress(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + 
print("Es", Es) + # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 + ref_Es = [ + 0.0, + 0.0, + -0.05450093218377135, + -0.11235475232750518, + 0.03914558031854152, + -0.07500839914816063, + -0.13469160624431492, + 0.029384214243251838, + -0.06521819204166135, + -0.14944896282001804, + 0.19413948083049481, + -0.13543541860473626, + -0.05235495076237124, + -0.049556206595684105, + -0.07080758913030646, + -0.24571898386301153, + -0.002070636306950905, + -0.03863113401320783, + -0.13620291339913712, + -0.23383074855679695, + -0.11776449630199368, + -0.1489441490225184, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + fitting_configs_ccd = [] + for _, c in enumerate(fitting_configs): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + + c_ccd = c.copy() + c_ccd.info["head"] = "CCD" + fitting_configs_ccd.append(c_ccd) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["config"] = tmp_path / "config.yaml" + mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="CCD", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 02/09/2024 on develop branch + ref_Es = [ + 0.0, + 0.0, + 0.10637113905361611, + -0.012499594026624754, + 0.08983077108171753, + 0.21071322543112597, + -0.028921849222784398, + -0.02423359575741567, + 0.022923252188079057, + -0.02048334610058991, + 0.4349711162741364, + -0.04455577015569887, + -0.09765806785570091, + 0.16013134616829822, + 0.0758442928017698, + 
-0.05931856557011721, + 0.33964473532953265, + 0.134338442158641, + 0.18024119757783053, + -0.18914740992058765, + -0.06503477155294624, + 0.03436649147415213, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 + ref_Es = [ + 1.6780993938446045, + 0.8916864395141602, + 0.7290308475494385, + 0.6194742918014526, + 0.6697757840156555, + 0.7025266289710999, + 0.5818213224411011, + 0.7897703647613525, + 0.6558921337127686, + 0.5071806907653809, + 3.581131935119629, + 0.691562294960022, + 0.6257331967353821, + 0.9560437202453613, + 0.7716934680938721, + 0.6730310916900635, + 0.8297463655471802, + 0.8053972721099854, + 0.8337507247924805, + 0.4107491970062256, + 0.6019601821899414, + 0.7301387786865234, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + 
mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + Es = [] + for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" 
+ for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + Es = [] + for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_multihead_replay_custum_finetuning( + tmp_path, fitting_configs, pretraining_configs +): + ase.io.write(tmp_path / "pretrain.xyz", pretraining_configs) + + foundation_params = { + "name": "foundation", + "train_file": os.path.join(tmp_path, "pretrain.xyz"), + "valid_fraction": 0.2, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "swa": None, + "start_swa": 3, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "default_dtype": 
"float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = [sys.executable, str(run_train)] + for k, v in foundation_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Step 3: Create finetuning set + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # Step 4: Finetune the pretrained model with multihead replay + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + finetuning_params = { + "name": "finetuned", + "valid_fraction": 0.1, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "default_dtype": "float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "foundation_model": os.path.join(tmp_path, "foundation.model"), + "config": os.path.join(tmp_path, "config.yaml"), + "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), + "num_samples_pt": 3, + "subselect_pt": "random", + "force_mh_ft_lr": True, + } + + cmd = [sys.executable, str(run_train)] + for k, v in finetuning_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Load and test the finetuned model + calc = MACECalculator( + model_paths=tmp_path / "finetuned.model", + device="cpu", + default_dtype="float64", + head="pt_head", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Energies:", Es) + + # Add some basic checks + assert len(Es) == len(fitting_configs) + assert all(isinstance(E, float) for E in Es) + assert len(set(Es)) > 1 # Ens + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_run_train_cueq(tmp_path, fitting_configs): + torch.set_default_dtype(torch.float64) + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["enable_cueq"] = True + mace_params["default_dtype"] = "float64" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, 
str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cuda") + Es = [] + for at in fitting_configs[2:]: + at.calc = calc + Es.append(at.get_potential_energy()) + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True + ) + Es_cueq = [] + for at in fitting_configs[2:]: + at.calc = calc + Es_cueq.append(at.get_potential_energy()) + + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) + assert np.allclose(ref_Es, Es_cueq) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path 
/ "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["enable_cueq"] = True + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["device"] = "cuda" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cuda", + default_dtype="float64", + head="DFT", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_lbfgs(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["lbfgs"] = None + mace_params["max_num_epochs"] = 2 + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 14/03/2025 + 
ref_Es = [ + 0.0, + 0.0, + -0.1874197850340979, + -0.25991775038059006, + 0.18263492399322268, + -0.15026829765490662, + -0.2403061362015996, + 0.1689257170630718, + -0.2095568077455055, + -0.2957758160829075, + -0.0035370913684985364, + -0.2195416610745775, + -0.25405549447739517, + -0.06201390990366806, + -0.13332219494388334, + -0.19633181702040337, + 0.013014932630445699, + -0.08808335967147174, + -0.06664444189210728, + -0.4230467426992034, + -0.2348250569553676, + -0.17593904833220647, + ] + assert np.allclose(Es, ref_Es, atol=1e-2) + + +def test_run_train_foundation_elements(tmp_path, fitting_configs): + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "train_file": tmp_path / "fit.xyz", + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "multiheads_finetuning": False, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True # Flag-only argument + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + + # Check that both models can make predictions + at = fitting_configs[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + # Energies should be different since the models are trained differently, + # but both should give reasonable results + assert np.isfinite(e1) + assert np.isfinite(e2) + + +def 
test_run_train_foundation_elements_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + if i % 2 == 0: + c_copy = c.copy() + c_copy.info["head"] = "DFT" + fitting_configs_dft.append(c_copy) + else: + c_copy = c.copy() + c_copy.info["head"] = "MP2" + fitting_configs_mp2.append(c_copy) + + ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2) + + # Create multihead configuration + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + config_file = tmp_path / "config.yaml" + with open(config_file, "w", encoding="utf-8") as file: + file.write(yaml_str) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "config": str(config_file), + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "force_mh_ft_lr": True, + "batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "atomic_numbers": "[" + ",".join(map(str, atomic_numbers)) + "]", + "filter_type_pt": "combinations", + "valid_fraction": 0.1, + "valid_batch_size": 1, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2 + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + 
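# torch.load unpickles the full model object; acceptable here because the file was just written by the run_train subprocess above
+    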
model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + assert len(model_all.heads) == 3 # pt_head + DFT + MP2 + + # Check that both models can make predictions + at = fitting_configs_dft[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="DFT", + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + head="DFT", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + assert np.isfinite(e1) + assert np.isfinite(e2) diff --git a/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py b/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py index 1c102173b91a08dd23d5dd2adbfbbcc8c1612a70..1d59190c28c7e63604476565db2bd4a5420dfc24 100644 --- a/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py +++ b/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py @@ -1,468 +1,468 @@ -import os -import subprocess -import sys -from copy import deepcopy -from pathlib import Path - -import ase.io -import numpy as np -import pytest -from ase.atoms import Atoms - -from mace.calculators.mace import MACECalculator -from mace.cli.run_train import run as run_mace_train -from mace.data.utils import KeySpecification -from mace.tools import build_default_arg_parser - -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -_mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "device": "cpu", - "seed": 5, - "loss": "weighted", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "interaction_first": "RealAgnosticResidualInteractionBlock", - "batch_size": 1, - "valid_batch_size": 1, - "num_samples_pt": 50, - "subselect_pt": "random", - "eval_interval": 2, - "num_radial_basis": 10, - "r_max": 6.0, - "default_dtype": "float64", -} - - -def configs_numbered_keys(): - np.random.seed(0) - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - - energies = list(np.random.normal(0.1, size=15)) - forces = list(np.random.normal(0.1, size=(15, 3, 3))) - - trial_configs_lists = [] - # some keys present, some not - keys_to_use = ( - ["REF_energy"] - + ["2_energy"] * 2 - + ["3_energy"] * 3 - + ["4_energy"] * 4 - + ["5_energy"] * 5 - ) - - force_keys_to_use = ( - ["REF_forces"] - + ["2_forces"] * 2 - + ["3_forces"] * 3 - + ["4_forces"] * 4 - + ["5_forces"] * 5 - ) - - for ind in range(15): - c = deepcopy(water) - c.info[keys_to_use[ind]] = energies[ind] - c.arrays[force_keys_to_use[ind]] = forces[ind] - c.positions += np.random.normal(0.1, size=(3, 3)) - trial_configs_lists.append(c) - - return trial_configs_lists - - -def trial_yamls_and_and_expected(): - yamls = 
{} - command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"} - - yamls["no_heads"] = {} - - yamls["one_head_no_dicts"] = { - "heads": { - "Default": { - "energy_key": "3_energy", - } - } - } - - yamls["one_head_with_dicts"] = { - "heads": { - "Default": { - "info_keys": { - "energy": "3_energy", - }, - "arrays_keys": { - "forces": "3_forces", - }, - } - } - } - - yamls["two_heads_no_dicts"] = { - "heads": { - "dft": { - "train_file": "fit_multihead_dft.xyz", - "energy_key": "3_energy", - }, - "mp2": { - "train_file": "fit_multihead_mp2.xyz", - "energy_key": "4_energy", - }, - } - } - - yamls["two_heads_mixed"] = { - "heads": { - "dft": { - "train_file": "fit_multihead_dft.xyz", - "info_keys": { - "energy": "3_energy", - }, - "arrays_keys": { - "forces": "3_forces", - }, - "forces_key": "4_forces", - }, - "mp2": { - "train_file": "fit_multihead_mp2.xyz", - "energy_key": "4_energy", - }, - } - } - all_arg_sets = { - "with_command_line": { - key: {**command_line_kwargs, **value} for key, value in yamls.items() - }, - "without_command_line": yamls, - } - - all_expected_outputs = { - "with_command_line": { - "no_heads": [ - 1.0037831178668188, - 1.0183291323603265, - 1.0120784084221528, - 0.9935695881012243, - 1.0021641561865526, - 0.9999135609205868, - 0.9809440616323108, - 1.0025784765050076, - 1.0017901145495376, - 1.0136913185404515, - 1.006798563238269, - 1.0187758397828384, - 1.0180201540775071, - 1.0132368725061702, - 0.9998734173248169, - ], - "one_head_no_dicts": [ - 1.0028437510688613, - 1.0514693378041775, - 1.059933403321331, - 1.034719940573569, - 1.0438040675561824, - 1.019719477728329, - 0.9841759692947915, - 1.0435266573857496, - 1.0339501989779065, - 1.0501795448530264, - 1.0402594216704781, - 1.0604998765679152, - 1.0633411200246015, - 1.0539071190201297, - 1.0393496428177804, - ], - "one_head_with_dicts": [ - 0.8638341551096959, - 1.0078341354784144, - 1.0149701178418595, - 0.9945723048460148, - 1.0184158011731292, - 0.9992135295205004, - 0.8943420783639198, - 1.0327920054084088, - 0.9905731198078909, - 0.9838325204450648, - 1.0018725575620482, - 1.007263052421034, - 1.0335213929231966, - 1.0033503312511205, - 1.0174433894759563, - ], - "two_heads_no_dicts": [ - 0.9836377578288774, - 1.0196844186291318, - 1.0151628222871238, - 0.957307281711648, - 0.985574141310865, - 0.9629670134047853, - 0.9242583185138095, - 0.9807770070311039, - 0.9973679440479541, - 1.0221127246963275, - 1.0031807967874216, - 1.0358701219543687, - 1.0434208761164758, - 1.0235606028124515, - 0.9797494630655053, - ], - "two_heads_mixed": [ - 0.8664108574741868, - 0.9907166576278023, - 1.0051969372365164, - 0.978702477000018, - 1.025500166764692, - 0.9940095566375018, - 0.9034029726954119, - 1.0391739502744488, - 0.9717327061183668, - 0.972292103670355, - 1.0012510461663253, - 0.9978051155885286, - 1.0378611651753475, - 1.0003207628186224, - 1.0209509292189651, - ], - }, - "without_command_line": { - "no_heads": [ - 0.9352605307451007, - 0.991084559389268, - 0.9940350095024881, - 0.9953849198103668, - 0.9954705498032904, - 0.9964815693808411, - 0.9663142667436776, - 0.9947223808739147, - 0.9897776682803257, - 0.989027769690667, - 0.9910280920241263, - 0.992067980667518, - 0.9917276132506404, - 0.9902848752169671, - 0.9928585982942544, - ], - "one_head_no_dicts": [ - 0.9425342207393083, - 1.0149788456087416, - 1.0249228965652788, - 1.0247924743285792, - 1.02732103964481, - 1.0168852937950326, - 0.9771283495170653, - 1.0261776335561517, - 1.0130461033368028, - 1.0162619153561783, - 
1.019995179866916, - 1.0209512298344965, - 1.0219971755636952, - 1.0195791901659124, - 1.0234662527729408, - ], - "one_head_with_dicts": [ - 0.8638341551096959, - 1.0078341354784144, - 1.0149701178418595, - 0.9945723048460148, - 1.0184158011731292, - 0.9992135295205004, - 0.8943420783639198, - 1.0327920054084088, - 0.9905731198078909, - 0.9838325204450648, - 1.0018725575620482, - 1.007263052421034, - 1.0335213929231966, - 1.0033503312511205, - 1.0174433894759563, - ], - "two_heads_no_dicts": [ - 0.9933763730233168, - 0.9986480398559268, - 1.0042486164355315, - 1.0025568793877726, - 1.0032598081704625, - 0.9926714183717912, - 0.9920385249670881, - 1.0020278841030676, - 1.0012474150830537, - 1.0039289677261019, - 1.0022718878661814, - 1.003586385624809, - 1.003436450009097, - 1.003805673887942, - 1.001450261102316, - ], - "two_heads_mixed": [ - 0.8781767864616707, - 0.9843563603794138, - 1.0145197579049248, - 0.9835060778675391, - 1.0419060462994596, - 0.9917393978520056, - 0.9091521032773944, - 1.0605463095070453, - 0.9685381713826684, - 0.9866493058823766, - 1.00305061187164, - 1.0051273128414386, - 1.037964258398104, - 1.0106663924241408, - 1.0274351814133602, - ], - }, - } - - list_of_all = [] - for key, value in all_arg_sets.items(): - for key2, value2 in value.items(): - list_of_all.append( - (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) - ) - - return list_of_all - - -def dict_to_yaml_str(data, indent=0): - yaml_str = "" - for key, value in data.items(): - yaml_str += " " * indent + str(key) + ":" - if isinstance(value, dict): - yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) - else: - yaml_str += " " + str(value) + "\n" - return yaml_str - - -_trial_yamls_and_and_expected = trial_yamls_and_and_expected() - - -@pytest.mark.parametrize( - "yaml_contents, name, expected_value", _trial_yamls_and_and_expected -) -def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value): - fitting_configs = configs_numbered_keys() - - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) - ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = "fit_multihead_dft.xyz" - mace_params["E0s"] = "{1:0.0,8:1.0}" - mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" - del mace_params["valid_fraction"] - mace_params["max_num_epochs"] = 1 # many tests to do - del mace_params["energy_key"] - del mace_params["forces_key"] - del mace_params["stress_key"] - - mace_params["name"] = "MACE_" - - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(dict_to_yaml_str(yaml_contents)) - if len(yaml_contents) > 0: - mace_params["config"] = str(tmp_path / "config.yaml") - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) - assert p.returncode == 0 - - if "heads" in yaml_contents: - headname = list(yaml_contents["heads"].keys())[0] - 
else: - headname = "Default" - - calc = MACECalculator( - tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print(name) - print("Es", Es) - - assert np.allclose( - np.asarray(Es), expected_value, rtol=1e-8, atol=1e-8 - ), f"Expected {expected_value} but got {Es} with error {np.max(np.abs(Es - expected_value))}" - - -def test_multihead_finetuning_does_not_modify_default_keyspec(tmp_path): - fitting_configs = configs_numbered_keys() - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) - - args = build_default_arg_parser().parse_args( - [ - "--name", - "_MACE_", - "--train_file", - str(tmp_path / "fit_multihead_dft.xyz"), - "--foundation_model", - "small", - "--device", - "cpu", - "--E0s", - "{1:0.0,8:1.0}", - "--energy_key", - "2_energy", - "--dry_run", - ] - ) - default_key_spec = KeySpecification.from_defaults() - default_key_spec.info_keys["energy"] = "2_energy" - run_mace_train(args) - assert args.key_specification == default_key_spec - -# for creating values -def make_output(): - outputs = {} - for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: - if name[0] not in outputs: - outputs[name[0]] = {} - expected = test_key_specification_methods( - Path("."), yaml_contents, name, expected_value, debug_test=False - ) - outputs[name[0]][name[1]] = expected - print(outputs) +import os +import subprocess +import sys +from copy import deepcopy +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms + +from mace.calculators.mace import MACECalculator +from mace.cli.run_train import run as run_mace_train +from mace.data.utils import KeySpecification +from mace.tools import build_default_arg_parser + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "device": "cpu", + "seed": 5, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "interaction_first": "RealAgnosticResidualInteractionBlock", + "batch_size": 1, + "valid_batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "eval_interval": 2, + "num_radial_basis": 10, + "r_max": 6.0, + "default_dtype": "float64", +} + + +def configs_numbered_keys(): + np.random.seed(0) + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + + energies = list(np.random.normal(0.1, size=15)) + forces = list(np.random.normal(0.1, size=(15, 3, 3))) + + trial_configs_lists = [] + # some keys present, some not + keys_to_use = ( + ["REF_energy"] + + ["2_energy"] * 2 + + ["3_energy"] * 3 + + ["4_energy"] * 4 + + ["5_energy"] * 5 + ) + + force_keys_to_use = ( + ["REF_forces"] + + ["2_forces"] * 2 + + ["3_forces"] * 3 + + ["4_forces"] * 4 + + ["5_forces"] * 5 + ) + + for ind in range(15): + c = deepcopy(water) + c.info[keys_to_use[ind]] = energies[ind] + c.arrays[force_keys_to_use[ind]] = forces[ind] + c.positions += np.random.normal(0.1, size=(3, 3)) + trial_configs_lists.append(c) + + return trial_configs_lists + + +def trial_yamls_and_and_expected(): + yamls = {} + command_line_kwargs = 
{"energy_key": "2_energy", "forces_key": "2_forces"} + + yamls["no_heads"] = {} + + yamls["one_head_no_dicts"] = { + "heads": { + "Default": { + "energy_key": "3_energy", + } + } + } + + yamls["one_head_with_dicts"] = { + "heads": { + "Default": { + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + } + } + } + + yamls["two_heads_no_dicts"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "energy_key": "3_energy", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + + yamls["two_heads_mixed"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + "forces_key": "4_forces", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + all_arg_sets = { + "with_command_line": { + key: {**command_line_kwargs, **value} for key, value in yamls.items() + }, + "without_command_line": yamls, + } + + all_expected_outputs = { + "with_command_line": { + "no_heads": [ + 1.0037831178668188, + 1.0183291323603265, + 1.0120784084221528, + 0.9935695881012243, + 1.0021641561865526, + 0.9999135609205868, + 0.9809440616323108, + 1.0025784765050076, + 1.0017901145495376, + 1.0136913185404515, + 1.006798563238269, + 1.0187758397828384, + 1.0180201540775071, + 1.0132368725061702, + 0.9998734173248169, + ], + "one_head_no_dicts": [ + 1.0028437510688613, + 1.0514693378041775, + 1.059933403321331, + 1.034719940573569, + 1.0438040675561824, + 1.019719477728329, + 0.9841759692947915, + 1.0435266573857496, + 1.0339501989779065, + 1.0501795448530264, + 1.0402594216704781, + 1.0604998765679152, + 1.0633411200246015, + 1.0539071190201297, + 1.0393496428177804, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9836377578288774, + 1.0196844186291318, + 1.0151628222871238, + 0.957307281711648, + 0.985574141310865, + 0.9629670134047853, + 0.9242583185138095, + 0.9807770070311039, + 0.9973679440479541, + 1.0221127246963275, + 1.0031807967874216, + 1.0358701219543687, + 1.0434208761164758, + 1.0235606028124515, + 0.9797494630655053, + ], + "two_heads_mixed": [ + 0.8664108574741868, + 0.9907166576278023, + 1.0051969372365164, + 0.978702477000018, + 1.025500166764692, + 0.9940095566375018, + 0.9034029726954119, + 1.0391739502744488, + 0.9717327061183668, + 0.972292103670355, + 1.0012510461663253, + 0.9978051155885286, + 1.0378611651753475, + 1.0003207628186224, + 1.0209509292189651, + ], + }, + "without_command_line": { + "no_heads": [ + 0.9352605307451007, + 0.991084559389268, + 0.9940350095024881, + 0.9953849198103668, + 0.9954705498032904, + 0.9964815693808411, + 0.9663142667436776, + 0.9947223808739147, + 0.9897776682803257, + 0.989027769690667, + 0.9910280920241263, + 0.992067980667518, + 0.9917276132506404, + 0.9902848752169671, + 0.9928585982942544, + ], + "one_head_no_dicts": [ + 0.9425342207393083, + 1.0149788456087416, + 1.0249228965652788, + 1.0247924743285792, + 1.02732103964481, + 1.0168852937950326, + 0.9771283495170653, + 1.0261776335561517, + 1.0130461033368028, + 1.0162619153561783, + 1.019995179866916, + 
1.0209512298344965, + 1.0219971755636952, + 1.0195791901659124, + 1.0234662527729408, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9933763730233168, + 0.9986480398559268, + 1.0042486164355315, + 1.0025568793877726, + 1.0032598081704625, + 0.9926714183717912, + 0.9920385249670881, + 1.0020278841030676, + 1.0012474150830537, + 1.0039289677261019, + 1.0022718878661814, + 1.003586385624809, + 1.003436450009097, + 1.003805673887942, + 1.001450261102316, + ], + "two_heads_mixed": [ + 0.8781767864616707, + 0.9843563603794138, + 1.0145197579049248, + 0.9835060778675391, + 1.0419060462994596, + 0.9917393978520056, + 0.9091521032773944, + 1.0605463095070453, + 0.9685381713826684, + 0.9866493058823766, + 1.00305061187164, + 1.0051273128414386, + 1.037964258398104, + 1.0106663924241408, + 1.0274351814133602, + ], + }, + } + + list_of_all = [] + for key, value in all_arg_sets.items(): + for key2, value2 in value.items(): + list_of_all.append( + (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) + ) + + return list_of_all + + +def dict_to_yaml_str(data, indent=0): + yaml_str = "" + for key, value in data.items(): + yaml_str += " " * indent + str(key) + ":" + if isinstance(value, dict): + yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) + else: + yaml_str += " " + str(value) + "\n" + return yaml_str + + +_trial_yamls_and_and_expected = trial_yamls_and_and_expected() + + +@pytest.mark.parametrize( + "yaml_contents, name, expected_value", _trial_yamls_and_and_expected +) +def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value): + fitting_configs = configs_numbered_keys() + + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) + ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = "fit_multihead_dft.xyz" + mace_params["E0s"] = "{1:0.0,8:1.0}" + mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" + del mace_params["valid_fraction"] + mace_params["max_num_epochs"] = 1 # many tests to do + del mace_params["energy_key"] + del mace_params["forces_key"] + del mace_params["stress_key"] + + mace_params["name"] = "MACE_" + + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(dict_to_yaml_str(yaml_contents)) + if len(yaml_contents) > 0: + mace_params["config"] = str(tmp_path / "config.yaml") + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) + assert p.returncode == 0 + + if "heads" in yaml_contents: + headname = list(yaml_contents["heads"].keys())[0] + else: + headname = 
"Default" + + calc = MACECalculator( + tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print(name) + print("Es", Es) + + assert np.allclose( + np.asarray(Es), expected_value, rtol=1e-8, atol=1e-8 + ), f"Expected {expected_value} but got {Es} with error {np.max(np.abs(Es - expected_value))}" + + +def test_multihead_finetuning_does_not_modify_default_keyspec(tmp_path): + fitting_configs = configs_numbered_keys() + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + + args = build_default_arg_parser().parse_args( + [ + "--name", + "_MACE_", + "--train_file", + str(tmp_path / "fit_multihead_dft.xyz"), + "--foundation_model", + "small", + "--device", + "cpu", + "--E0s", + "{1:0.0,8:1.0}", + "--energy_key", + "2_energy", + "--dry_run", + ] + ) + default_key_spec = KeySpecification.from_defaults() + default_key_spec.info_keys["energy"] = "2_energy" + run_mace_train(args) + assert args.key_specification == default_key_spec + +# for creating values +def make_output(): + outputs = {} + for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: + if name[0] not in outputs: + outputs[name[0]] = {} + expected = test_key_specification_methods( + Path("."), yaml_contents, name, expected_value, debug_test=False + ) + outputs[name[0]][name[1]] = expected + print(outputs) diff --git a/mace-bench/3rdparty/mace/tests/test_schedulefree.py b/mace-bench/3rdparty/mace/tests/test_schedulefree.py index d84163c36f6c42d691004fe9e926d0172a5e0216..00b207507e3b8fc7acd0c93049f1deef9967c2a7 100644 --- a/mace-bench/3rdparty/mace/tests/test_schedulefree.py +++ b/mace-bench/3rdparty/mace/tests/test_schedulefree.py @@ -1,127 +1,127 @@ -import tempfile -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 - -from mace import data, modules, tools -from mace.tools import scripts_utils, torch_geometric - -try: - import schedulefree -except ImportError: - pytest.skip( - "Skipping schedulefree tests due to ImportError", allow_module_level=True - ) - -torch.set_default_dtype(torch.float64) - -table = tools.AtomicNumberTable([6]) -atomic_energies = np.array([1.0], dtype=float) -cutoff = 5.0 - - -def create_mace(device: str, seed: int = 1702): - torch_geometric.seed_everything(seed) - - model_config = { - "r_max": cutoff, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": o3.Irreps("8x0e + 8x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": atomic_energies, - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - } - model = modules.MACE(**model_config) - return model.to(device) - - -def create_batch(device: str): - from ase import build - - size = 2 - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms_list = [atoms.repeat((size, size, size))] - print("Number of atoms", len(atoms_list[0])) - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) - for config in configs - ], - 
batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch = batch.to(device) - batch = batch.to_dict() - return batch - - -def do_optimization_step( - model, - optimizer, - device, -): - batch = create_batch(device) - model.train() - optimizer.train() - optimizer.zero_grad() - output = model(batch, training=True, compute_force=False) - loss = output["energy"].mean() - loss.backward() - optimizer.step() - model.eval() - optimizer.eval() - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_can_load_checkpoint(device): - model = create_mace(device) - optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters()) - args = MagicMock() - args.optimizer = "schedulefree" - args.scheduler = "ExponentialLR" - args.lr_scheduler_gamma = 0.9 - lr_scheduler = scripts_utils.LRScheduler(optimizer, args) - with tempfile.TemporaryDirectory() as d: - checkpoint_handler = tools.CheckpointHandler( - directory=d, keep=False, tag="schedulefree" - ) - for _ in range(10): - do_optimization_step(model, optimizer, device) - batch = create_batch(device) - output = model(batch) - energy = output["energy"].detach().cpu().numpy() - - state = tools.CheckpointState( - model=model, optimizer=optimizer, lr_scheduler=lr_scheduler - ) - checkpoint_handler.save(state, epochs=0, keep_last=False) - checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - ) - batch = create_batch(device) - output = model(batch) - new_energy = output["energy"].detach().cpu().numpy() - assert np.allclose(energy, new_energy, atol=1e-9) +import tempfile +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 + +from mace import data, modules, tools +from mace.tools import scripts_utils, torch_geometric + +try: + import schedulefree +except ImportError: + pytest.skip( + "Skipping schedulefree tests due to ImportError", allow_module_level=True + ) + +torch.set_default_dtype(torch.float64) + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("8x0e + 8x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + } + model = modules.MACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def 
do_optimization_step( + model, + optimizer, + device, +): + batch = create_batch(device) + model.train() + optimizer.train() + optimizer.zero_grad() + output = model(batch, training=True, compute_force=False) + loss = output["energy"].mean() + loss.backward() + optimizer.step() + model.eval() + optimizer.eval() + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_can_load_checkpoint(device): + model = create_mace(device) + optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters()) + args = MagicMock() + args.optimizer = "schedulefree" + args.scheduler = "ExponentialLR" + args.lr_scheduler_gamma = 0.9 + lr_scheduler = scripts_utils.LRScheduler(optimizer, args) + with tempfile.TemporaryDirectory() as d: + checkpoint_handler = tools.CheckpointHandler( + directory=d, keep=False, tag="schedulefree" + ) + for _ in range(10): + do_optimization_step(model, optimizer, device) + batch = create_batch(device) + output = model(batch) + energy = output["energy"].detach().cpu().numpy() + + state = tools.CheckpointState( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler + ) + checkpoint_handler.save(state, epochs=0, keep_last=False) + checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + ) + batch = create_batch(device) + output = model(batch) + new_energy = output["energy"].detach().cpu().numpy() + assert np.allclose(energy, new_energy, atol=1e-9) diff --git a/mace-bench/3rdparty/mace/tests/test_tools.py b/mace-bench/3rdparty/mace/tests/test_tools.py index 50c1ee8987a8859453d929e2feca6b85916f8cfb..227a1bfc9341f3ce0c5d7480f183c89fdd23fc4f 100644 --- a/mace-bench/3rdparty/mace/tests/test_tools.py +++ b/mace-bench/3rdparty/mace/tests/test_tools.py @@ -1,48 +1,48 @@ -import tempfile - -import numpy as np -import torch -import torch.nn.functional -from torch import nn, optim - -from mace.tools import ( - AtomicNumberTable, - CheckpointHandler, - CheckpointState, - atomic_numbers_to_indices, -) - - -def test_atomic_number_table(): - table = AtomicNumberTable(zs=[1, 8]) - array = np.array([8, 8, 1]) - indices = atomic_numbers_to_indices(array, z_table=table) - expected = np.array([1, 1, 0], dtype=int) - assert np.allclose(expected, indices) - - -class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 4) - - def forward(self, x): - return torch.nn.functional.relu(self.linear(x)) - - -def test_save_load(): - model = MyModel() - initial_lr = 0.001 - optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9) - scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99) - - with tempfile.TemporaryDirectory() as directory: - handler = CheckpointHandler(directory=directory, tag="test", keep=True) - handler.save(state=CheckpointState(model, optimizer, scheduler), epochs=50) - - optimizer.step() - scheduler.step() - assert not np.isclose(optimizer.param_groups[0]["lr"], initial_lr) - - handler.load_latest(state=CheckpointState(model, optimizer, scheduler)) - assert np.isclose(optimizer.param_groups[0]["lr"], initial_lr) +import tempfile + +import numpy as np +import torch +import torch.nn.functional +from torch import nn, optim + +from mace.tools import ( + AtomicNumberTable, + CheckpointHandler, + CheckpointState, + atomic_numbers_to_indices, +) + + +def test_atomic_number_table(): + table = AtomicNumberTable(zs=[1, 8]) + array = np.array([8, 8, 1]) + indices = atomic_numbers_to_indices(array, z_table=table) + expected = np.array([1, 1, 
0], dtype=int)
+    assert np.allclose(expected, indices)
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 4)
+
+    def forward(self, x):
+        return torch.nn.functional.relu(self.linear(x))
+
+
+def test_save_load():
+    model = MyModel()
+    initial_lr = 0.001
+    optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
+    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)
+
+    with tempfile.TemporaryDirectory() as directory:
+        handler = CheckpointHandler(directory=directory, tag="test", keep=True)
+        handler.save(state=CheckpointState(model, optimizer, scheduler), epochs=50)
+
+        optimizer.step()
+        scheduler.step()
+        assert not np.isclose(optimizer.param_groups[0]["lr"], initial_lr)
+
+        handler.load_latest(state=CheckpointState(model, optimizer, scheduler))
+        assert np.isclose(optimizer.param_groups[0]["lr"], initial_lr)
diff --git a/mace-bench/reproduce/init_7net.sh b/mace-bench/reproduce/init_7net.sh
index 70aa632ee19110f7d5814b201834f922214360b1..22f19f12ca633f726ddc11f0d010790ecd5cebba 100644
--- a/mace-bench/reproduce/init_7net.sh
+++ b/mace-bench/reproduce/init_7net.sh
@@ -1,11 +1,11 @@
-#!/bin/bash
-
-pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install -r requirements.txt
-pip install -e 3rdparty/SevenNet
-pip install -e .
-pip install ase==3.23.0
-pip install ninja
+#!/bin/bash
+
+pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install -r requirements.txt
+pip install -e 3rdparty/SevenNet
+pip install -e .
+pip install ase==3.23.0
+pip install ninja
 pip install rdkit==2024.3.5
\ No newline at end of file
diff --git a/mace-bench/reproduce/init_mace.sh b/mace-bench/reproduce/init_mace.sh
index d7e9cbae198f313400d53ed6454f9b136a9c8321..b491a652e291f1c1afb12bfe8a9bf37588112533 100644
--- a/mace-bench/reproduce/init_mace.sh
+++ b/mace-bench/reproduce/init_mace.sh
@@ -1,14 +1,14 @@
-#!/bin/bash
-
-pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
-pip install -r requirements.txt
-pip install -e 3rdparty/mace
-pip install -e .
-pip install e3nn==0.4.4
-pip install ase==3.23.0
-pip install ninja
-
-# for python_CSP
-pip install rdkit-pypi
+#!/bin/bash
+
+pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
+pip install -r requirements.txt
+pip install -e 3rdparty/mace
+pip install -e .
+pip install e3nn==0.4.4
+pip install ase==3.23.0
+pip install ninja
+
+# for python_CSP
+pip install rdkit-pypi
diff --git a/mace-bench/reproduce/mace_opt_new.py b/mace-bench/reproduce/mace_opt_new.py
index 5b17949321b72f1ec008c3d061751b678e7d3678..4a6d9d7faeaf22a1f67968dc362ae6f56e79b5e1 100644
--- a/mace-bench/reproduce/mace_opt_new.py
+++ b/mace-bench/reproduce/mace_opt_new.py
@@ -1,300 +1,300 @@
-"""
-Copyright (c) 2025 Ma Zhaojia
-
-This source code is licensed under the MIT license found in the
-LICENSE file in the root directory of this source tree.
-"""
-
-import os
-os.environ['OMP_NUM_THREADS'] = '1'
-os.environ['MKL_NUM_THREADS'] = '1'
-os.environ['OPENBLAS_NUM_THREADS'] = '1'
-import sys
-# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace')
-from mace.calculators import mace_off, mace_mp
-from ase.io import read, write
-from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton
-from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter
-import re
-import io
-from contextlib import redirect_stdout
-import os
-import pandas as pd
-from joblib import Parallel, delayed
-import json
-import torch
-import numpy as np
-import random
-import argparse
-import time
-import pathlib
-import logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
-#####################################################################
-os.environ['PYTHONHASHSEED'] = '1'
-torch.manual_seed(1)
-np.random.seed(1)
-random.seed(1)
-torch.cuda.manual_seed(1)
-torch.cuda.manual_seed_all(1)
-#####################################################################
-# n_jobs=32
-# # n_jobs=2
-# path = './'
-# molecule_single = 64
-# target_folder = "/data_raw/"
-#####################################################################
-
-def calculate_density(crystal):
-    # Total mass: ASE's get_masses returns an array with the mass of every atom
-    total_mass = sum(crystal.get_masses())  # in atomic mass units (amu)
-
-    # Cell volume: ASE's get_volume returns the unit-cell volume in Å^3
-    # 1 Å^3 = 1e-24 cm^3
-    volume = crystal.get_volume()  # converted to cm^3 in the expression below
-
-    # Density = mass / volume; dividing by Avogadro's number converts amu to grams
-    density = total_mass / (volume*10**-24)/(6.022140857*10**23)  # in g/cm^3
-    return density
-
-def run_calculation_one(path,file,target_folder,molecule_single,idx):
-    # os.environ['OMP_NUM_THREADS'] = '1'
-    # os.environ['MKL_NUM_THREADS'] = '1'
-    # os.environ['OPENBLAS_NUM_THREADS'] = '1'
-    if reproduce:
-        print("Reproducing deterministic results.")
-        torch.use_deterministic_algorithms(True)
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-        np.set_printoptions(precision=17, suppress=False)
-        torch.set_printoptions(precision=17, sci_mode=False, linewidth=200)
-    if multithread and (not reproduce):
-        print("Using OMP and MKL multithreads will introduce non-deterministic results.")
-    else:
-        os.environ['OMP_NUM_THREADS'] = '1'
-        os.environ['MKL_NUM_THREADS'] = '1'
-        os.environ['OPENBLAS_NUM_THREADS'] = '1'
-    os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset)
-
-    with io.StringIO() as buf, redirect_stdout(buf):
-        crystal = read(path+target_folder+file)
-        if molecule_single < 0:
-            molecule_single = int(file.split('_')[-1].split('.')[0])
-        molecule_count = len(crystal.get_atomic_numbers())/molecule_single
-        calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq)
-        crystal.calc = calc
-        if filter1 == "UnitCellFilter":
-            sf = UnitCellFilter(crystal,scalar_pressure=0.0006)
-        elif filter1 == "FrechetCellFilter":
-            sf = FrechetCellFilter(crystal,scalar_pressure=0.0006)
-        else:
-            raise ValueError(f"Unrecognized filter 
type '{filter1}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type1 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type1 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type1 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - - if use_nsys or use_torch_profiler : # warmup for profiling - optimizer.run(fmax=0.01,steps=100) - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time1 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time1 = time.time() - - if use_torch_profiler: - profiler.stop() - - crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") - output_1 = buf.getvalue() - # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) - step_used_1 = optimizer.nsteps - if use_nsys or use_torch_profiler : - step_used_1 = step_used_1 - 100 - total_time1 = end_time1 - start_time1 - avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 - - crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") - crystal.calc = calc - if filter2 == "UnitCellFilter": - sf = UnitCellFilter(crystal) - elif filter2 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal) - else: - raise ValueError(f"Unrecognized filter type '{filter2}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type2 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type2 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type2 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. 
" - "Supported types are 'BFGS' and 'LBFGS'.") - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time2 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time2 = time.time() - - if use_torch_profiler: - profiler.stop() - - density = calculate_density(crystal) - crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") - output_2 = buf.getvalue() - energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) - # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) - step_used_2 = optimizer.nsteps - energy_per_mol = energy / molecule_count * 96.485 - total_time2 = end_time2 - start_time2 - avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 - - new_row = { - 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, - 'step_used_1': step_used_1, 'step_used_2': step_used_2, - 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, - 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 - } - - print(f'output_2: {output_2}') - with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: - json.dump(new_row, json_file, indent=4) - return new_row - - -def already_have_calculation_one(path, file, target_folder, molecule_single, idx): - logging.info(f"reading on structure {file}") - print(f"reading on structure {file}") - with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: - old_row = json.load(file) - return old_row - -def run(): - df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) - for root, dirs, files in os.walk(path + target_folder): - old_row = Parallel(n_jobs=n_jobs)( - delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) - - filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] - new_row = Parallel(n_jobs=n_jobs)( - delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(filtered_files)) - # show the length of new_row - print(f'new_row length: {len(new_row)}') - print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') - for row in new_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - for row in old_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - df.to_csv(path + '/result.csv') - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") - parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") - parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") - parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") - parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") - parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") - parser.add_argument("--cueq", action='store_true', help="Whether to use 
cuEquivariance Library (default: False)") - parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") - parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") - parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") - parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") - parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") - parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") - parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") - parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") - parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") - - args = parser.parse_args() - - n_jobs = args.n_jobs - target_folder = args.target_folder - path = args.path - molecule_single = args.molecule_single - n_gpus = args.n_gpus - cueq = args.cueq - max_steps = args.max_steps - use_torch_profiler = args.use_torch_profiler - use_nsys = args.use_nsys - model_path = args.model - optimizer_type = args.optimizer - use_cuda_eigh = args.use_cuda_eigh - gpu_offset = args.gpu_offset - multithread = args.multithread - reproduce = args.reproduce - filter1 = args.filter1 - filter2 = args.filter2 - optimizer_type1 = args.optimizer1 - optimizer_type2 = args.optimizer2 - - - try: - os.mkdir("./cif_result_press") - os.mkdir("./cif_result_final") - except: - pass - try: - os.mkdir("./json_result") - except: - pass - - start_time_all = time.time() - - - iter = 0 - while iter < 100: - iter += 1 - try: - run() - break - except Exception as e: - print(f"Error occurred: {e}") - print("Retrying...") - time.sleep(10) - - end_time_all = time.time() - total_time_all = end_time_all - start_time_all - print('dataset,total_time_all_s,attempts') - print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") - with open(path + 'timing.csv', 'w') as f: - f.write('dataset,total_time_all_s,attempts\n') +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
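+
+Baseline driver: each crystal structure is optimized independently, with the
+input files fanned out over per-structure worker processes via joblib.Parallel.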
+""" + +import os +os.environ['OMP_NUM_THREADS'] = '1' +os.environ['MKL_NUM_THREADS'] = '1' +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import sys +# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') +from mace.calculators import mace_off, mace_mp +from ase.io import read, write +from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton +from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter +import re +import io +from contextlib import redirect_stdout +import os +import pandas as pd +from joblib import Parallel, delayed +import json +import torch +import numpy as np +import random +import argparse +import time +import pathlib +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) +##################################################################### +os.environ['PYTHONHASHSEED'] = '1' +torch.manual_seed(1) +np.random.seed(1) +random.seed(1) +torch.cuda.manual_seed(1) +torch.cuda.manual_seed_all(1) +##################################################################### +# n_jobs=32 +# # n_jobs=2 +# path = './' +# molecule_single = 64 +# target_folder = "/data_raw/" +##################################################################### + +def calculate_density(crystal): + # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 + total_mass = sum(crystal.get_masses()) # 转换为克 + + # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 + # 1 Å^3 = 1e-24 cm^3 + volume = crystal.get_volume() # 转换为立方厘米 + + # 计算密度,质量除以体积 + density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 + return density + +def run_calculation_one(path,file,target_folder,molecule_single,idx): + # os.environ['OMP_NUM_THREADS'] = '1' + # os.environ['MKL_NUM_THREADS'] = '1' + # os.environ['OPENBLAS_NUM_THREADS'] = '1' + if reproduce: + print("Reproducing deterministic results.") + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + np.set_printoptions(precision=17, suppress=False) + torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) + if multithread and (not reproduce): + print("Using OMP and MKL multithreads will introduce non-deterministic results.") + else: + os.environ['OMP_NUM_THREADS'] = '1' + os.environ['MKL_NUM_THREADS'] = '1' + os.environ['OPENBLAS_NUM_THREADS'] = '1' + os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) + + with io.StringIO() as buf, redirect_stdout(buf): + crystal = read(path+target_folder+file) + if molecule_single < 0: + molecule_single = int(file.split('_')[-1].split('.')[0]) + molecule_count = len(crystal.get_atomic_numbers())/molecule_single + calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) + crystal.calc = calc + if filter1 == "UnitCellFilter": + sf = UnitCellFilter(crystal,scalar_pressure=0.0006) + elif filter1 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) + else: + raise ValueError(f"Unrecognized filter type '{filter1}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type1 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type1 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type1 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. 
" + "Supported types are 'BFGS' and 'LBFGS'.") + + if use_nsys or use_torch_profiler : # warmup for profiling + optimizer.run(fmax=0.01,steps=100) + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time1 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time1 = time.time() + + if use_torch_profiler: + profiler.stop() + + crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") + output_1 = buf.getvalue() + # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) + step_used_1 = optimizer.nsteps + if use_nsys or use_torch_profiler : + step_used_1 = step_used_1 - 100 + total_time1 = end_time1 - start_time1 + avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 + + crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") + crystal.calc = calc + if filter2 == "UnitCellFilter": + sf = UnitCellFilter(crystal) + elif filter2 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal) + else: + raise ValueError(f"Unrecognized filter type '{filter2}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type2 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type2 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type2 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time2 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time2 = time.time() + + if use_torch_profiler: + profiler.stop() + + density = calculate_density(crystal) + crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") + output_2 = buf.getvalue() + energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) + # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) + step_used_2 = optimizer.nsteps + energy_per_mol = energy / molecule_count * 96.485 + total_time2 = end_time2 - start_time2 + avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 + + new_row = { + 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, + 'step_used_1': step_used_1, 'step_used_2': step_used_2, + 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, + 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 + } + + print(f'output_2: {output_2}') + with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: + json.dump(new_row, json_file, indent=4) + return new_row + + +def already_have_calculation_one(path, file, target_folder, molecule_single, idx): + logging.info(f"reading on structure {file}") + print(f"reading on structure {file}") + with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: + old_row = json.load(file) + return old_row + +def run(): + df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 
'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) + for root, dirs, files in os.walk(path + target_folder): + old_row = Parallel(n_jobs=n_jobs)( + delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) + + filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] + new_row = Parallel(n_jobs=n_jobs)( + delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(filtered_files)) + # show the length of new_row + print(f'new_row length: {len(new_row)}') + print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') + for row in new_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + for row in old_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + df.to_csv(path + '/result.csv') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") + parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") + parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") + parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") + parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") + parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") + parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") + parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") + parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") + parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") + parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") + parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") + parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") + parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") + parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") + parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") + + args = parser.parse_args() + + n_jobs = args.n_jobs + target_folder = args.target_folder + path = args.path + molecule_single = 
args.molecule_single + n_gpus = args.n_gpus + cueq = args.cueq + max_steps = args.max_steps + use_torch_profiler = args.use_torch_profiler + use_nsys = args.use_nsys + model_path = args.model + optimizer_type = args.optimizer + use_cuda_eigh = args.use_cuda_eigh + gpu_offset = args.gpu_offset + multithread = args.multithread + reproduce = args.reproduce + filter1 = args.filter1 + filter2 = args.filter2 + optimizer_type1 = args.optimizer1 + optimizer_type2 = args.optimizer2 + + + try: + os.mkdir("./cif_result_press") + os.mkdir("./cif_result_final") + except: + pass + try: + os.mkdir("./json_result") + except: + pass + + start_time_all = time.time() + + + iter = 0 + while iter < 100: + iter += 1 + try: + run() + break + except Exception as e: + print(f"Error occurred: {e}") + print("Retrying...") + time.sleep(10) + + end_time_all = time.time() + total_time_all = end_time_all - start_time_all + print('dataset,total_time_all_s,attempts') + print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") + with open(path + 'timing.csv', 'w') as f: + f.write('dataset,total_time_all_s,attempts\n') f.write(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}\n") \ No newline at end of file diff --git a/mace-bench/reproduce/mace_opt_origin.py b/mace-bench/reproduce/mace_opt_origin.py index f6796e80e11f9dfacc3819216516a4080bdfb2f1..c059fb1d46c3d7d3d7901fd006fc8160cb3b47e3 100644 --- a/mace-bench/reproduce/mace_opt_origin.py +++ b/mace-bench/reproduce/mace_opt_origin.py @@ -1,297 +1,297 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import os -import sys -# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') -from mace.calculators import mace_off, mace_mp -from ase.io import read, write -from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton -from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter -import re -import io -from contextlib import redirect_stdout -import os -import pandas as pd -from joblib import Parallel, delayed -import json -import torch -import numpy as np -import random -import argparse -import time -import pathlib -import logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) -##################################################################### -os.environ['PYTHONHASHSEED'] = '1' -torch.manual_seed(1) -np.random.seed(1) -random.seed(1) -torch.cuda.manual_seed(1) -torch.cuda.manual_seed_all(1) -##################################################################### -# n_jobs=32 -# # n_jobs=2 -# path = './' -# molecule_single = 64 -# target_folder = "/data_raw/" -##################################################################### - -def calculate_density(crystal): - # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 - total_mass = sum(crystal.get_masses()) # 转换为克 - - # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 - # 1 Å^3 = 1e-24 cm^3 - volume = crystal.get_volume() # 转换为立方厘米 - - # 计算密度,质量除以体积 - density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 - return density - -def run_calculation_one(path,file,target_folder,molecule_single,idx): - # os.environ['OMP_NUM_THREADS'] = '1' - # os.environ['MKL_NUM_THREADS'] = '1' - # os.environ['OPENBLAS_NUM_THREADS'] = '1' - if reproduce: - print("Reproducing deterministic results.") - torch.use_deterministic_algorithms(True) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - 
np.set_printoptions(precision=17, suppress=False) - torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) - if multithread and (not reproduce): - print("Using OMP and MKL multithreads will introduce non-deterministic results.") - else: - os.environ['OMP_NUM_THREADS'] = '1' - os.environ['MKL_NUM_THREADS'] = '1' - os.environ['OPENBLAS_NUM_THREADS'] = '1' - os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) - - with io.StringIO() as buf, redirect_stdout(buf): - crystal = read(path+target_folder+file) - if molecule_single < 0: - molecule_single = int(file.split('_')[-1].split('.')[0]) - molecule_count = len(crystal.get_atomic_numbers())/molecule_single - calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) - crystal.calc = calc - if filter1 == "UnitCellFilter": - sf = UnitCellFilter(crystal,scalar_pressure=0.0006) - elif filter1 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) - else: - raise ValueError(f"Unrecognized filter type '{filter1}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type1 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type1 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type1 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - - if use_nsys or use_torch_profiler : # warmup for profiling - optimizer.run(fmax=0.01,steps=100) - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time1 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time1 = time.time() - - if use_torch_profiler: - profiler.stop() - - crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") - output_1 = buf.getvalue() - # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) - step_used_1 = optimizer.nsteps - if use_nsys or use_torch_profiler : - step_used_1 = step_used_1 - 100 - total_time1 = end_time1 - start_time1 - avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 - - crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") - crystal.calc = calc - if filter2 == "UnitCellFilter": - sf = UnitCellFilter(crystal) - elif filter2 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal) - else: - raise ValueError(f"Unrecognized filter type '{filter2}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type2 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type2 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type2 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. 
" - "Supported types are 'BFGS' and 'LBFGS'.") - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time2 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time2 = time.time() - - if use_torch_profiler: - profiler.stop() - - density = calculate_density(crystal) - crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") - output_2 = buf.getvalue() - energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) - # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) - step_used_2 = optimizer.nsteps - energy_per_mol = energy / molecule_count * 96.485 - total_time2 = end_time2 - start_time2 - avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 - - new_row = { - 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, - 'step_used_1': step_used_1, 'step_used_2': step_used_2, - 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, - 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 - } - - print(f'output_2: {output_2}') - with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: - json.dump(new_row, json_file, indent=4) - return new_row - - -def already_have_calculation_one(path, file, target_folder, molecule_single, idx): - logging.info(f"reading on structure {file}") - print(f"reading on structure {file}") - with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: - old_row = json.load(file) - return old_row - -def run(): - df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) - for root, dirs, files in os.walk(path + target_folder): - old_row = Parallel(n_jobs=n_jobs)( - delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) - - filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] - new_row = Parallel(n_jobs=n_jobs)( - delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(filtered_files)) - # show the length of new_row - print(f'new_row length: {len(new_row)}') - print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') - for row in new_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - for row in old_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - df.to_csv(path + '/result.csv') - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") - parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") - parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") - parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") - parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") - parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") - parser.add_argument("--cueq", action='store_true', help="Whether to use 
cuEquivariance Library (default: False)") - parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") - parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") - parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") - parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") - parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") - parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") - parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") - parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") - parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") - - args = parser.parse_args() - - n_jobs = args.n_jobs - target_folder = args.target_folder - path = args.path - molecule_single = args.molecule_single - n_gpus = args.n_gpus - cueq = args.cueq - max_steps = args.max_steps - use_torch_profiler = args.use_torch_profiler - use_nsys = args.use_nsys - model_path = args.model - optimizer_type = args.optimizer - use_cuda_eigh = args.use_cuda_eigh - gpu_offset = args.gpu_offset - multithread = args.multithread - reproduce = args.reproduce - filter1 = args.filter1 - filter2 = args.filter2 - optimizer_type1 = args.optimizer1 - optimizer_type2 = args.optimizer2 - - - try: - os.mkdir("./cif_result_press") - os.mkdir("./cif_result_final") - except: - pass - try: - os.mkdir("./json_result") - except: - pass - - start_time_all = time.time() - - - iter = 0 - while iter < 100: - iter += 1 - try: - run() - break - except Exception as e: - print(f"Error occurred: {e}") - print("Retrying...") - time.sleep(10) - - end_time_all = time.time() - total_time_all = end_time_all - start_time_all - print('dataset,total_time_all_s,attempts') - print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") - with open(path + 'timing.csv', 'w') as f: - f.write('dataset,total_time_all_s,attempts\n') +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
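+
+Variant of mace_opt_new.py that does not pin the OMP/MKL/OpenBLAS thread
+counts at import time; otherwise it follows the same two-stage optimization flow.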
+""" + +import os +import sys +# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') +from mace.calculators import mace_off, mace_mp +from ase.io import read, write +from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton +from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter +import re +import io +from contextlib import redirect_stdout +import os +import pandas as pd +from joblib import Parallel, delayed +import json +import torch +import numpy as np +import random +import argparse +import time +import pathlib +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) +##################################################################### +os.environ['PYTHONHASHSEED'] = '1' +torch.manual_seed(1) +np.random.seed(1) +random.seed(1) +torch.cuda.manual_seed(1) +torch.cuda.manual_seed_all(1) +##################################################################### +# n_jobs=32 +# # n_jobs=2 +# path = './' +# molecule_single = 64 +# target_folder = "/data_raw/" +##################################################################### + +def calculate_density(crystal): + # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 + total_mass = sum(crystal.get_masses()) # 转换为克 + + # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 + # 1 Å^3 = 1e-24 cm^3 + volume = crystal.get_volume() # 转换为立方厘米 + + # 计算密度,质量除以体积 + density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 + return density + +def run_calculation_one(path,file,target_folder,molecule_single,idx): + # os.environ['OMP_NUM_THREADS'] = '1' + # os.environ['MKL_NUM_THREADS'] = '1' + # os.environ['OPENBLAS_NUM_THREADS'] = '1' + if reproduce: + print("Reproducing deterministic results.") + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + np.set_printoptions(precision=17, suppress=False) + torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) + if multithread and (not reproduce): + print("Using OMP and MKL multithreads will introduce non-deterministic results.") + else: + os.environ['OMP_NUM_THREADS'] = '1' + os.environ['MKL_NUM_THREADS'] = '1' + os.environ['OPENBLAS_NUM_THREADS'] = '1' + os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) + + with io.StringIO() as buf, redirect_stdout(buf): + crystal = read(path+target_folder+file) + if molecule_single < 0: + molecule_single = int(file.split('_')[-1].split('.')[0]) + molecule_count = len(crystal.get_atomic_numbers())/molecule_single + calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) + crystal.calc = calc + if filter1 == "UnitCellFilter": + sf = UnitCellFilter(crystal,scalar_pressure=0.0006) + elif filter1 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) + else: + raise ValueError(f"Unrecognized filter type '{filter1}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type1 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type1 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type1 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. 
" + "Supported types are 'BFGS' and 'LBFGS'.") + + if use_nsys or use_torch_profiler : # warmup for profiling + optimizer.run(fmax=0.01,steps=100) + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time1 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time1 = time.time() + + if use_torch_profiler: + profiler.stop() + + crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") + output_1 = buf.getvalue() + # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) + step_used_1 = optimizer.nsteps + if use_nsys or use_torch_profiler : + step_used_1 = step_used_1 - 100 + total_time1 = end_time1 - start_time1 + avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 + + crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") + crystal.calc = calc + if filter2 == "UnitCellFilter": + sf = UnitCellFilter(crystal) + elif filter2 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal) + else: + raise ValueError(f"Unrecognized filter type '{filter2}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type2 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type2 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type2 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time2 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time2 = time.time() + + if use_torch_profiler: + profiler.stop() + + density = calculate_density(crystal) + crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") + output_2 = buf.getvalue() + energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) + # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) + step_used_2 = optimizer.nsteps + energy_per_mol = energy / molecule_count * 96.485 + total_time2 = end_time2 - start_time2 + avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 + + new_row = { + 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, + 'step_used_1': step_used_1, 'step_used_2': step_used_2, + 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, + 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 + } + + print(f'output_2: {output_2}') + with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: + json.dump(new_row, json_file, indent=4) + return new_row + + +def already_have_calculation_one(path, file, target_folder, molecule_single, idx): + logging.info(f"reading on structure {file}") + print(f"reading on structure {file}") + with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: + old_row = json.load(file) + return old_row + +def run(): + df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 
'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) + for root, dirs, files in os.walk(path + target_folder): + old_row = Parallel(n_jobs=n_jobs)( + delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) + + filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] + new_row = Parallel(n_jobs=n_jobs)( + delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(filtered_files)) + # show the length of new_row + print(f'new_row length: {len(new_row)}') + print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') + for row in new_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + for row in old_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + df.to_csv(path + '/result.csv') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") + parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") + parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") + parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") + parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") + parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") + parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") + parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") + parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") + parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") + parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") + parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") + parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") + parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") + parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") + parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") + + args = parser.parse_args() + + n_jobs = args.n_jobs + target_folder = args.target_folder + path = args.path + molecule_single = 
args.molecule_single
+    n_gpus = args.n_gpus
+    cueq = args.cueq
+    max_steps = args.max_steps
+    use_torch_profiler = args.use_torch_profiler
+    use_nsys = args.use_nsys
+    model_path = args.model
+    optimizer_type = args.optimizer
+    use_cuda_eigh = args.use_cuda_eigh
+    gpu_offset = args.gpu_offset
+    multithread = args.multithread
+    reproduce = args.reproduce
+    filter1 = args.filter1
+    filter2 = args.filter2
+    optimizer_type1 = args.optimizer1
+    optimizer_type2 = args.optimizer2
+
+
+    try:
+        os.mkdir("./cif_result_press")
+        os.mkdir("./cif_result_final")
+    except:
+        pass
+    try:
+        os.mkdir("./json_result")
+    except:
+        pass
+
+    start_time_all = time.time()
+
+
+    iter = 0
+    while iter < 100:
+        iter += 1
+        try:
+            run()
+            break
+        except Exception as e:
+            print(f"Error occurred: {e}")
+            print("Retrying...")
+            time.sleep(10)
+
+    end_time_all = time.time()
+    total_time_all = end_time_all - start_time_all
+    print('dataset,total_time_all_s,attempts')
+    print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}")
+    with open(path + 'timing.csv', 'w') as f:
+        f.write('dataset,total_time_all_s,attempts\n')
         f.write(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}\n")
\ No newline at end of file
diff --git a/mace-bench/reproduce/perf_v2_base/run_mace.sh b/mace-bench/reproduce/perf_v2_base/run_mace.sh
index b886fe1e6502184b0228b8190ea694970dc4f129..c9e81c83441007a14a9234efb9d17cc1df8f254b 100644
--- a/mace-bench/reproduce/perf_v2_base/run_mace.sh
+++ b/mace-bench/reproduce/perf_v2_base/run_mace.sh
@@ -1,5 +1,5 @@
-#!/bin/bash
-
-python ../mace_opt_new.py --n_jobs 64 --molecule_single 46 \
-    --target_folder ../../data/perf_v2/ --model small --n_gpus 4 --gpu_offset 0 \
+#!/bin/bash
+
+python ../mace_opt_new.py --n_jobs 64 --molecule_single 46 \
+    --target_folder ../../data/perf_v2/ --model small --n_gpus 4 --gpu_offset 0 \
     --optimizer1 QuasiNewton --filter1 UnitCellFilter --filter2 UnitCellFilter
\ No newline at end of file
diff --git a/mace-bench/reproduce/perf_v2_batch/opt.sh b/mace-bench/reproduce/perf_v2_batch/opt.sh
index 51e8f977e7246000c8e37da704be1fb7e5d933f0..9c8c172661524cbbf2d6ec47fcbbf06529f819bb 100644
--- a/mace-bench/reproduce/perf_v2_batch/opt.sh
+++ b/mace-bench/reproduce/perf_v2_batch/opt.sh
@@ -1,6 +1,6 @@
-#!/bin/bash
-
-rm -r *_result_*
-
-python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 40 --batch_size 0 \
+#!/bin/bash
+
+rm -r *_result_*
+
+python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 40 --batch_size 0 \
     --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true --use_ordered_files true
\ No newline at end of file
diff --git a/mace-bench/reproduce/subtest.sh b/mace-bench/reproduce/subtest.sh
index f9c74c7e20c56ed0c2b4f7b3b376141fcb612d14..8aaa9302691b4978b8520f884a5a6089734fa4c9 100644
--- a/mace-bench/reproduce/subtest.sh
+++ b/mace-bench/reproduce/subtest.sh
@@ -1,25 +1,25 @@
-#!/bin/bash
-
-top_dir=$(pwd)
-
-natoms_nw_bs=(
-    "92 48 25"
-    "184 40 12"
-    "368 40 5"
-)
-
-for config in "${natoms_nw_bs[@]}"; do
-    read natoms nw bs <<< "$config"
-
-    dir="$top_dir/subtest_BATCH_${natoms}_g4_j${nw}_bs${bs}_cueq_cupbc"
-    mkdir -p "$dir"
-    cd "$dir" || continue
-
-    pwd
-    python ../../scripts/mace_opt_batch.py \
-        --target_folder "../../data/perf_v2_sorted/perf_v2_${natoms}" \
-        --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers ${nw} 
--batch_size ${bs} \
-        --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
-        --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 \
-        --use_ordered_files true --cueq true > opt.log 2>&1
+#!/bin/bash
+
+top_dir=$(pwd)
+
+natoms_nw_bs=(
+    "92 48 25"
+    "184 40 12"
+    "368 40 5"
+)
+
+for config in "${natoms_nw_bs[@]}"; do
+    read natoms nw bs <<< "$config"
+
+    dir="$top_dir/subtest_BATCH_${natoms}_g4_j${nw}_bs${bs}_cueq_cupbc"
+    mkdir -p "$dir"
+    cd "$dir" || continue
+
+    pwd
+    python ../../scripts/mace_opt_batch.py \
+        --target_folder "../../data/perf_v2_sorted/perf_v2_${natoms}" \
+        --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers ${nw} --batch_size ${bs} \
+        --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
+        --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 \
+        --use_ordered_files true --cueq true > opt.log 2>&1
 done
\ No newline at end of file
diff --git a/mace-bench/reproduce/subtest_baseline.sh b/mace-bench/reproduce/subtest_baseline.sh
index 64761d46d30ff000affdd86f2fdd708074925e45..08e4c4954e976acc0fecebddc18adb2bf7200e24 100644
--- a/mace-bench/reproduce/subtest_baseline.sh
+++ b/mace-bench/reproduce/subtest_baseline.sh
@@ -1,24 +1,24 @@
-#!/bin/bash
-
-top_dir=$(pwd)
-
-natoms_nw_bs=(
-    "92 64"
-    "184 64"
-    "368 64"
-)
-
-for config in "${natoms_nw_bs[@]}"; do
-    read natoms nw <<< "$config"
-
-    dir="$top_dir/subtest_BASE_${natoms}_g4_j${nw}"
-    mkdir -p "$dir"
-    cd "$dir" || continue
-
-    pwd
-
-    python ../mace_opt_new.py --n_jobs ${nw} --molecule_single 46 \
-        --target_folder ../../data/perf_v2_sorted/perf_v2_${natoms}/ --model small --n_gpus 4 \
-        --gpu_offset 0 --optimizer1 QuasiNewton --filter1 UnitCellFilter \
-        --filter2 UnitCellFilter --max_steps 3000 > opt.log 2>&1
+#!/bin/bash
+
+top_dir=$(pwd)
+
+natoms_nw_bs=(
+    "92 64"
+    "184 64"
+    "368 64"
+)
+
+for config in "${natoms_nw_bs[@]}"; do
+    read natoms nw <<< "$config"
+
+    dir="$top_dir/subtest_BASE_${natoms}_g4_j${nw}"
+    mkdir -p "$dir"
+    cd "$dir" || continue
+
+    pwd
+
+    python ../mace_opt_new.py --n_jobs ${nw} --molecule_single 46 \
+        --target_folder ../../data/perf_v2_sorted/perf_v2_${natoms}/ --model small --n_gpus 4 \
+        --gpu_offset 0 --optimizer1 QuasiNewton --filter1 UnitCellFilter \
+        --filter2 UnitCellFilter --max_steps 3000 > opt.log 2>&1
 done
\ No newline at end of file
diff --git a/mace-bench/requirements.txt b/mace-bench/requirements.txt
index aa57e9540bc3b26ca61ad99b9290ada94eb134b7..135b56b83c1b222dc2d286a3eb8ae480d095a430 100644
--- a/mace-bench/requirements.txt
+++ b/mace-bench/requirements.txt
@@ -1,137 +1,137 @@
---extra-index-url https://download.pytorch.org/whl/cu121
-absl-py==2.1.0
-aiohappyeyeballs==2.4.4
-aiohttp==3.11.11
-aiosignal==1.3.2
-annotated-types==0.7.0
-antlr4-python3-runtime==4.9.3
-# -e git+https://gitlab.com/ase/ase.git@72c50c76bac2396c7d58385b231c65bd07458279#egg=ase&subdirectory=../../../3rdparty/ase
-async-timeout==5.0.1
-attrs==24.3.0
-certifi==2024.8.30
-cfgv==3.4.0
-charset-normalizer==3.4.0
-click==8.1.8
-cloudpickle==3.1.0
-ConfigArgParse==1.7
-contourpy==1.3.1
-coverage==7.6.9
-cuequivariance==0.4.0
-cuequivariance-ops-torch-cu12==0.4.0
-cuequivariance-ops-cu12==0.4.0
-cuequivariance-torch==0.4.0
-cycler==0.12.1
-distlib==0.3.9
-docker-pycreds==0.4.0
-e3nn==0.4.4
-exceptiongroup==1.2.2
-# -e git+https://github.com/mazhaojia123/fairchem.git@f50db9d5b29debdfb265d9c3fad394f18e16cab8#egg=fairchem_core&subdirectory=../../../3rdparty/fairchem/packages/fairchem-core
-filelock==3.13.1
-fonttools==4.55.1
-frozenlist==1.5.0
-fsspec==2024.2.0
-gitdb==4.0.11 -GitPython==3.1.43 -grpcio==1.68.1 -h5py==3.12.1 -hydra-core==1.3.2 -identify==2.6.3 -idna==3.10 -iniconfig==2.0.0 -Jinja2==3.1.3 -joblib==1.4.2 -kiwisolver==1.4.7 -latexcodec==3.0.0 -lightning-utilities==0.11.9 -llvmlite==0.43.0 -lmdb==1.5.1 -# -e git+https://github.com/mazhaojia123/mace.git@edd6b479f4974d0b8162712872ad2eed1aa2fb75#egg=mace_torch&subdirectory=../../../3rdparty/mace -Markdown==3.7 -MarkupSafe==2.1.5 -matplotlib==3.9.3 -matscipy==1.1.1 -monty==2024.10.21 -mpmath==1.3.0 -multidict==6.1.0 -networkx==3.2.1 -nodeenv==1.9.1 -numba==0.60.0 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.1.105 -nvidia-nvtx-cu12==12.1.105 -omegaconf==2.3.0 -opt-einsum-fx==0.1.4 -opt_einsum==3.4.0 -orjson==3.10.12 -packaging==24.2 -palettable==3.3.3 -pandas==2.2.3 -pillow==11.0.0 -platformdirs==4.3.6 -plotly==5.24.1 -pluggy==1.5.0 -pre_commit==4.0.1 -prettytable==3.12.0 -propcache==0.2.1 -protobuf==5.29.2 -psutil==6.1.1 -pybtex==0.24.0 -pydantic==2.10.4 -pydantic_core==2.27.2 -pymatgen==2024.11.13 -pyparsing==3.2.0 -pytest==8.3.4 -pytest-cov==6.0.0 -python-dateutil==2.9.0.post0 -python-hostlist==2.0.0 -pytz==2024.2 -PyYAML==6.0.2 -requests==2.32.3 -ruamel.yaml==0.18.6 -ruamel.yaml.clib==0.2.12 -ruff==0.5.1 -scipy==1.14.1 -sentry-sdk==2.19.2 -setproctitle==1.3.4 -six==1.16.0 -smmap==5.0.1 -spglib==2.5.0 -submitit==1.5.2 -sympy==1.13.1 -syrupy==4.8.0 -tabulate==0.9.0 -tenacity==9.0.0 -tensorboard==2.18.0 -tensorboard-data-server==0.7.2 -tomli==2.2.1 -torch==2.4.1+cu121 -# ./torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl -torch-dftd==0.5.1 -torch-ema==0.3 -torch-geometric==2.6.1 -# torch_scatter==2.1.2+pt24cu121 -# torch_sparse==0.6.18+pt24cu121 -# torch_spline_conv==1.2.2+pt24cu121 -torchmetrics==1.6.0 -tqdm==4.67.1 -triton==3.0.0 -typing_extensions==4.12.2 -tzdata==2024.2 -uncertainties==3.2.2 -urllib3==2.2.3 -virtualenv==20.28.0 -wandb==0.19.1 -wcwidth==0.2.13 -Werkzeug==3.1.3 -yarl==1.18.3 +--extra-index-url https://download.pytorch.org/whl/cu121 +absl-py==2.1.0 +aiohappyeyeballs==2.4.4 +aiohttp==3.11.11 +aiosignal==1.3.2 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +# -e git+https://gitlab.com/ase/ase.git@72c50c76bac2396c7d58385b231c65bd07458279#egg=ase&subdirectory=../../../3rdparty/ase +async-timeout==5.0.1 +attrs==24.3.0 +certifi==2024.8.30 +cfgv==3.4.0 +charset-normalizer==3.4.0 +click==8.1.8 +cloudpickle==3.1.0 +ConfigArgParse==1.7 +contourpy==1.3.1 +coverage==7.6.9 +cuequivariance==0.4.0 +cuequivariance-ops-torch-cu12==0.4.0 +cuequivariance-ops-cu12==0.4.0 +cuequivariance-torch==0.4.0 +cycler==0.12.1 +distlib==0.3.9 +docker-pycreds==0.4.0 +e3nn==0.4.4 +exceptiongroup==1.2.2 +# -e git+https://github.com/mazhaojia123/fairchem.git@f50db9d5b29debdfb265d9c3fad394f18e16cab8#egg=fairchem_core&subdirectory=../../../3rdparty/fairchem/packages/fairchem-core +filelock==3.13.1 +fonttools==4.55.1 +frozenlist==1.5.0 +fsspec==2024.2.0 +gitdb==4.0.11 +GitPython==3.1.43 +grpcio==1.68.1 +h5py==3.12.1 +hydra-core==1.3.2 +identify==2.6.3 +idna==3.10 +iniconfig==2.0.0 +Jinja2==3.1.3 +joblib==1.4.2 +kiwisolver==1.4.7 +latexcodec==3.0.0 +lightning-utilities==0.11.9 +llvmlite==0.43.0 +lmdb==1.5.1 +# -e 
git+https://github.com/mazhaojia123/mace.git@edd6b479f4974d0b8162712872ad2eed1aa2fb75#egg=mace_torch&subdirectory=../../../3rdparty/mace +Markdown==3.7 +MarkupSafe==2.1.5 +matplotlib==3.9.3 +matscipy==1.1.1 +monty==2024.10.21 +mpmath==1.3.0 +multidict==6.1.0 +networkx==3.2.1 +nodeenv==1.9.1 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.1.105 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +opt-einsum-fx==0.1.4 +opt_einsum==3.4.0 +orjson==3.10.12 +packaging==24.2 +palettable==3.3.3 +pandas==2.2.3 +pillow==11.0.0 +platformdirs==4.3.6 +plotly==5.24.1 +pluggy==1.5.0 +pre_commit==4.0.1 +prettytable==3.12.0 +propcache==0.2.1 +protobuf==5.29.2 +psutil==6.1.1 +pybtex==0.24.0 +pydantic==2.10.4 +pydantic_core==2.27.2 +pymatgen==2024.11.13 +pyparsing==3.2.0 +pytest==8.3.4 +pytest-cov==6.0.0 +python-dateutil==2.9.0.post0 +python-hostlist==2.0.0 +pytz==2024.2 +PyYAML==6.0.2 +requests==2.32.3 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.12 +ruff==0.5.1 +scipy==1.14.1 +sentry-sdk==2.19.2 +setproctitle==1.3.4 +six==1.16.0 +smmap==5.0.1 +spglib==2.5.0 +submitit==1.5.2 +sympy==1.13.1 +syrupy==4.8.0 +tabulate==0.9.0 +tenacity==9.0.0 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tomli==2.2.1 +torch==2.4.1+cu121 +# ./torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl +torch-dftd==0.5.1 +torch-ema==0.3 +torch-geometric==2.6.1 +# torch_scatter==2.1.2+pt24cu121 +# torch_sparse==0.6.18+pt24cu121 +# torch_spline_conv==1.2.2+pt24cu121 +torchmetrics==1.6.0 +tqdm==4.67.1 +triton==3.0.0 +typing_extensions==4.12.2 +tzdata==2024.2 +uncertainties==3.2.2 +urllib3==2.2.3 +virtualenv==20.28.0 +wandb==0.19.1 +wcwidth==0.2.13 +Werkzeug==3.1.3 +yarl==1.18.3 torch-tb-profiler==0.4.3 \ No newline at end of file diff --git a/mace-bench/scripts/mace_opt_batch.py b/mace-bench/scripts/mace_opt_batch.py index e5c53fafaad502db3f19307239576c3a32313249..6d6ab1071de0f15982c87c9a0104fa53e00fb33f 100644 --- a/mace-bench/scripts/mace_opt_batch.py +++ b/mace-bench/scripts/mace_opt_batch.py @@ -1,112 +1,112 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
-""" - -import os -import argparse - -parser = argparse.ArgumentParser(description="Run batch optimization on molecular crystals.") -parser.add_argument("--target_folder", type=str, required=True, help="Target folder containing crystal files") -parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to distribute the files to") -parser.add_argument("--n_gpus", type=int, default=1, help="Number of GPUs to use for the optimization") -parser.add_argument("--gpu_offset", type=int, default=0, help="Offset for GPU numbering") -parser.add_argument("--batch_size", type=int, default=4, help="Number of files to process in a single batch") -parser.add_argument("--run_baseline", type=bool, default=False, help="Run baseline optimization using LBFGS from ase.optimize") -parser.add_argument("--max_steps", type=int, default=100, help="Number of max steps to run the optimization (default: 100)") -parser.add_argument("--filter1", type=str, default="UnitCellFilter", - choices=[None, "UnitCellFilter"], - help="Type of cell filter to use in first optimization") -parser.add_argument("--filter2", type=str, default="UnitCellFilter", - choices=[None, "UnitCellFilter"], - help="Type of cell filter to use in second optimization") -parser.add_argument("--optimizer1", type=str, default="BFGS", - choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], - help="First optimizer to use (default: BFGS)") -parser.add_argument("--optimizer2", type=str, default="BFGS", - choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], - help="Second optimizer to use (default: LBFGS)") -parser.add_argument("--skip_second_stage", type=bool, default=False, help="Skip the second optimization stage") -parser.add_argument("--scalar_pressure", type=float, default=0.0006, - help="Scalar pressure for cell optimization (default: 0.0006)") -parser.add_argument("--compile_mode", type=str, default=None, - choices=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], - help="Compile mode for MACE calculator") -parser.add_argument("--profile", type=str, default="False", - help="Enable profiling. Set to 'True' for basic profiling or provide a JSON string with profiler config options for wait, warmup, active, and repeat") -parser.add_argument("--num_threads", type=int, default=16, help="Number of cpu threads per process to use while running the optimization") -parser.add_argument("--bind_cores", type=str, default=None, - help=("Specify a comma-separated list of core ranges (e.g., '0-15,16-31,...') for each worker. 
The number of ranges must equal --num_workers.")) -parser.add_argument("--cueq", type=bool, default=False, help="Whether to use cuEquivariance Library (default: False)") -parser.add_argument("--molecule_single", type=int, default=64, help="Number of atoms per molecule (default: 64)") -parser.add_argument("--output_path", type=str, default="./", help="Absolute path for output files") -parser.add_argument("--model", type=str, default="mace", choices=["mace", "chgnet", "sevennet"], help="Model to use for optimization") -parser.add_argument("--use_ordered_files", type=bool, default=False, - help="Whether to sort files by atomic number in descending order before optimization") -args = parser.parse_args() - -os.environ['OMP_NUM_THREADS'] = str(args.num_threads) -os.environ['MKL_NUM_THREADS'] = str(args.num_threads) - -import pathlib -import logging -from batchopt import Scheduler, ensure_directory, run_baseline, count_atoms_cif -logging.basicConfig( - level=logging.WARNING, - format='%(asctime)s - %(process)d - %(levelname)s - %(message)s', - datefmt='%H:%M:%S', - force=True -) - -if __name__ == '__main__': - target_folder = pathlib.Path(args.target_folder) - files = [str(file) for file in target_folder.glob("*.cif")] - devices = [f"cuda:{i}" for i in range(args.gpu_offset, args.gpu_offset + args.n_gpus)] - - logging.info("Starting batch optimization.") - logging.info(f"Use devices: {devices}") - logging.info(f"files: {files}") - - output_path = args.output_path - if not os.path.isabs(output_path): - output_path = os.path.abspath(output_path) - logging.info(f"Output path: {output_path}") - - for output_dir in ["cif_result_press", "cif_result_final", "json_result_press", "json_result_final", "worker_results", "log"]: - dir_path = os.path.join(output_path, output_dir) - ensure_directory(dir_path) - - import time - start_time = time.perf_counter() - - use_ordered_files = args.use_ordered_files - if use_ordered_files: - logging.info(f"Use ordered files.") - if files[0].endswith("cif"): - files = sorted(files, key=count_atoms_cif, reverse=True) - else: - logging.error(f"No support for the file type in {target_folder}.") - end_time = time.perf_counter() - logging.info(f"atomic sorting time: {end_time - start_time:.4f} seconds.") - - if args.run_baseline: - run_baseline(files, args.num_workers, devices, args.max_steps, - args.filter1, args.filter2, args.skip_second_stage, - args.scalar_pressure, args.optimizer1, args.optimizer2, - output_path=output_path) - else: - scheduler = Scheduler(files=files, num_workers=args.num_workers, devices=devices, - batch_size=args.batch_size, max_steps=args.max_steps, - filter1=args.filter1, filter2=args.filter2, - skip_second_stage=args.skip_second_stage, - scalar_pressure=args.scalar_pressure, optimizer1=args.optimizer1, optimizer2=args.optimizer2, - compile_mode=args.compile_mode, profile=args.profile, - num_threads=args.num_threads, bind_cores=args.bind_cores, - cueq=args.cueq, molecule_single=args.molecule_single, - output_path=output_path, model=args.model) - scheduler.run() - - logging.info("Batch optimization completed.") - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
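+
+Entry point for batched optimization: CIF files are distributed across GPU
+workers through the batchopt Scheduler, or handed to run_baseline for the
+non-batched path when --run_baseline is set.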
+""" + +import os +import argparse + +parser = argparse.ArgumentParser(description="Run batch optimization on molecular crystals.") +parser.add_argument("--target_folder", type=str, required=True, help="Target folder containing crystal files") +parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to distribute the files to") +parser.add_argument("--n_gpus", type=int, default=1, help="Number of GPUs to use for the optimization") +parser.add_argument("--gpu_offset", type=int, default=0, help="Offset for GPU numbering") +parser.add_argument("--batch_size", type=int, default=4, help="Number of files to process in a single batch") +parser.add_argument("--run_baseline", type=bool, default=False, help="Run baseline optimization using LBFGS from ase.optimize") +parser.add_argument("--max_steps", type=int, default=100, help="Number of max steps to run the optimization (default: 100)") +parser.add_argument("--filter1", type=str, default="UnitCellFilter", + choices=[None, "UnitCellFilter"], + help="Type of cell filter to use in first optimization") +parser.add_argument("--filter2", type=str, default="UnitCellFilter", + choices=[None, "UnitCellFilter"], + help="Type of cell filter to use in second optimization") +parser.add_argument("--optimizer1", type=str, default="BFGS", + choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], + help="First optimizer to use (default: BFGS)") +parser.add_argument("--optimizer2", type=str, default="BFGS", + choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], + help="Second optimizer to use (default: LBFGS)") +parser.add_argument("--skip_second_stage", type=bool, default=False, help="Skip the second optimization stage") +parser.add_argument("--scalar_pressure", type=float, default=0.0006, + help="Scalar pressure for cell optimization (default: 0.0006)") +parser.add_argument("--compile_mode", type=str, default=None, + choices=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], + help="Compile mode for MACE calculator") +parser.add_argument("--profile", type=str, default="False", + help="Enable profiling. Set to 'True' for basic profiling or provide a JSON string with profiler config options for wait, warmup, active, and repeat") +parser.add_argument("--num_threads", type=int, default=16, help="Number of cpu threads per process to use while running the optimization") +parser.add_argument("--bind_cores", type=str, default=None, + help=("Specify a comma-separated list of core ranges (e.g., '0-15,16-31,...') for each worker. 
The number of ranges must equal --num_workers."))
+parser.add_argument("--cueq", type=bool, default=False, help="Whether to use the cuEquivariance library (default: False)")
+parser.add_argument("--molecule_single", type=int, default=64, help="Number of atoms per molecule (default: 64)")
+parser.add_argument("--output_path", type=str, default="./", help="Path for output files (relative paths are made absolute)")
+parser.add_argument("--model", type=str, default="mace", choices=["mace", "chgnet", "sevennet"], help="Model to use for optimization")
+parser.add_argument("--use_ordered_files", type=bool, default=False,
+                    help="Whether to sort files by number of atoms in descending order before optimization")
+args = parser.parse_args()
+
+# Thread limits must be set before importing torch-backed modules.
+os.environ['OMP_NUM_THREADS'] = str(args.num_threads)
+os.environ['MKL_NUM_THREADS'] = str(args.num_threads)
+
+import pathlib
+import logging
+from batchopt import Scheduler, ensure_directory, run_baseline, count_atoms_cif
+logging.basicConfig(
+    level=logging.WARNING,
+    format='%(asctime)s - %(process)d - %(levelname)s - %(message)s',
+    datefmt='%H:%M:%S',
+    force=True
+)
+
+if __name__ == '__main__':
+    target_folder = pathlib.Path(args.target_folder)
+    files = [str(file) for file in target_folder.glob("*.cif")]
+    devices = [f"cuda:{i}" for i in range(args.gpu_offset, args.gpu_offset + args.n_gpus)]
+
+    logging.info("Starting batch optimization.")
+    logging.info(f"Using devices: {devices}")
+    logging.info(f"files: {files}")
+
+    output_path = args.output_path
+    if not os.path.isabs(output_path):
+        output_path = os.path.abspath(output_path)
+    logging.info(f"Output path: {output_path}")
+
+    for output_dir in ["cif_result_press", "cif_result_final", "json_result_press", "json_result_final", "worker_results", "log"]:
+        dir_path = os.path.join(output_path, output_dir)
+        ensure_directory(dir_path)
+
+    import time
+    start_time = time.perf_counter()
+
+    use_ordered_files = args.use_ordered_files
+    if use_ordered_files:
+        logging.info("Using ordered files.")
+        if files and files[0].endswith(".cif"):
+            files = sorted(files, key=count_atoms_cif, reverse=True)
+        else:
+            logging.error(f"Unsupported file type in {target_folder}.")
+        end_time = time.perf_counter()
+        logging.info(f"Atom-count sorting time: {end_time - start_time:.4f} seconds.")
+
+    if args.run_baseline:
+        run_baseline(files, args.num_workers, devices, args.max_steps,
+                     args.filter1, args.filter2, args.skip_second_stage,
+                     args.scalar_pressure, args.optimizer1, args.optimizer2,
+                     output_path=output_path)
+    else:
+        scheduler = Scheduler(files=files, num_workers=args.num_workers, devices=devices,
+                              batch_size=args.batch_size, max_steps=args.max_steps,
+                              filter1=args.filter1, filter2=args.filter2,
+                              skip_second_stage=args.skip_second_stage,
+                              scalar_pressure=args.scalar_pressure, optimizer1=args.optimizer1, optimizer2=args.optimizer2,
+                              compile_mode=args.compile_mode, profile=args.profile,
+                              num_threads=args.num_threads, bind_cores=args.bind_cores,
+                              cueq=args.cueq, molecule_single=args.molecule_single,
+                              output_path=output_path, model=args.model)
+        scheduler.run()
+
+    logging.info("Batch optimization completed.")
+
diff --git a/mace-bench/setup.py b/mace-bench/setup.py
index b2de1fc2d62b9f4e09c1e825602ba4ed1b6060c1..f7e1680e3ef2f79d760f763c12a62ef898ee6697 100644
--- a/mace-bench/setup.py
+++ b/mace-bench/setup.py
@@ -1,23 +1,23 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='BOMLIP-CSP',
-    version='0.1',
-    author='Chengxi Zhao, Zhaojia Ma, Dingrui Fan',
-    author_email='chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com',
-    
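The `mace_opt_batch.py` driver above is a thin CLI wrapper around `batchopt.Scheduler`. The same run can be set up programmatically; the sketch below mirrors the `Scheduler(...)` call in the script, with the target folder, device list, and output path as placeholder assumptions:

```python
# A minimal sketch of driving batched optimization from Python instead of the CLI.
# Keyword names mirror the Scheduler(...) call in mace_opt_batch.py above; the
# target folder, device list, and output path are placeholder assumptions.
import pathlib
from batchopt import Scheduler, ensure_directory, count_atoms_cif

files = [str(p) for p in pathlib.Path("data/perf_v2").glob("*.cif")]  # hypothetical folder
files.sort(key=count_atoms_cif, reverse=True)  # largest structures first, as with --use_ordered_files

out = "/abs/path/to/results"  # placeholder; the CLI makes relative paths absolute
for sub in ["cif_result_press", "cif_result_final", "json_result_press",
            "json_result_final", "worker_results", "log"]:
    ensure_directory(f"{out}/{sub}")  # the driver pre-creates these before running

scheduler = Scheduler(files=files, num_workers=4, devices=["cuda:0", "cuda:1"],
                      batch_size=2, max_steps=3000,
                      filter1="UnitCellFilter", filter2="UnitCellFilter",
                      skip_second_stage=False, scalar_pressure=0.0006,
                      optimizer1="BFGSFusedLS", optimizer2="BFGS",
                      compile_mode=None, profile="False",
                      num_threads=2, bind_cores=None,
                      cueq=True, molecule_single=46,
                      output_path=out, model="mace")
scheduler.run()
```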
description='Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction', - url='https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP', - license='MIT', - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.10', - 'Topic :: Scientific/Engineering :: Chemistry', - 'Topic :: Scientific/Engineering :: Physics', - ], - python_requires='>=3.10', - package_dir={'': 'src'}, - packages=find_packages('src'), +from setuptools import setup, find_packages + +setup( + name='BOMLIP-CSP', + version='0.1', + author='Chengxi Zhao, Zhaojia Ma, Dingrui Fan', + author_email='chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com', + description='Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction', + url='https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP', + license='MIT', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.10', + 'Topic :: Scientific/Engineering :: Chemistry', + 'Topic :: Scientific/Engineering :: Physics', + ], + python_requires='>=3.10', + package_dir={'': 'src'}, + packages=find_packages('src'), ) \ No newline at end of file diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO b/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO index 7be7ce716b243ccfce95d61843b901b769a091ee..d9cc8e1fbbc24c439bda37edbf59fc256a02ba50 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO +++ b/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO @@ -1,23 +1,23 @@ -Metadata-Version: 2.4 -Name: BOMLIP-CSP -Version: 0.1 -Summary: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction -Home-page: https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP -Author: Chengxi Zhao, Zhaojia Ma, Dingrui Fan -Author-email: chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com -License: MIT -Classifier: Development Status :: 3 - Alpha -Classifier: Intended Audience :: Science/Research -Classifier: License :: OSI Approved :: MIT License -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Topic :: Scientific/Engineering :: Chemistry -Classifier: Topic :: Scientific/Engineering :: Physics -Requires-Python: >=3.10 -Dynamic: author -Dynamic: author-email -Dynamic: classifier -Dynamic: home-page -Dynamic: license -Dynamic: requires-python -Dynamic: summary +Metadata-Version: 2.4 +Name: BOMLIP-CSP +Version: 0.1 +Summary: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction +Home-page: https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP +Author: Chengxi Zhao, Zhaojia Ma, Dingrui Fan +Author-email: chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com +License: MIT +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Topic :: Scientific/Engineering :: Chemistry +Classifier: Topic :: Scientific/Engineering :: Physics +Requires-Python: >=3.10 +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: home-page +Dynamic: 
license +Dynamic: requires-python +Dynamic: summary diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt b/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt index 0335aed147f04d8087bf003024ee3adcaf4d8f09..11e6ddc16cd77adb6fd964112b650ddba37ec67b 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt @@ -1,20 +1,20 @@ -setup.py -src/BOMLIP_CSP.egg-info/PKG-INFO -src/BOMLIP_CSP.egg-info/SOURCES.txt -src/BOMLIP_CSP.egg-info/dependency_links.txt -src/BOMLIP_CSP.egg-info/top_level.txt -src/batchopt/__init__.py -src/batchopt/atoms_to_graphs.py -src/batchopt/baseline.py -src/batchopt/pbc_graph.py -src/batchopt/pbc_graph_legacy.py -src/batchopt/relaxengine.py -src/batchopt/utils.py -src/batchopt/extensions/__init__.py -src/batchopt/extensions/cuda_ops/__init__.py -src/batchopt/relaxation/__init__.py -src/batchopt/relaxation/ase_utils.py -src/batchopt/relaxation/optimizable.py -src/batchopt/relaxation/optimizers/__init__.py -src/batchopt/relaxation/optimizers/bfgs_torch.py +setup.py +src/BOMLIP_CSP.egg-info/PKG-INFO +src/BOMLIP_CSP.egg-info/SOURCES.txt +src/BOMLIP_CSP.egg-info/dependency_links.txt +src/BOMLIP_CSP.egg-info/top_level.txt +src/batchopt/__init__.py +src/batchopt/atoms_to_graphs.py +src/batchopt/baseline.py +src/batchopt/pbc_graph.py +src/batchopt/pbc_graph_legacy.py +src/batchopt/relaxengine.py +src/batchopt/utils.py +src/batchopt/extensions/__init__.py +src/batchopt/extensions/cuda_ops/__init__.py +src/batchopt/relaxation/__init__.py +src/batchopt/relaxation/ase_utils.py +src/batchopt/relaxation/optimizable.py +src/batchopt/relaxation/optimizers/__init__.py +src/batchopt/relaxation/optimizers/bfgs_torch.py src/batchopt/relaxation/optimizers/bfgsfusedls.py \ No newline at end of file diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt b/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt index d3f5a12faa99758192ecc4ed3fc22c9249232e86..8b137891791fe96927ad78e64b0aad7bded08bdc 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt @@ -1 +1 @@ - + diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt b/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt index 0961265a1eb9dc52f6679e2d1e79107170bc012d..fe9f2991a0dd177d98c3ea1fff66d89e772ea40d 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt @@ -1 +1 @@ -batchopt +batchopt diff --git a/mace-bench/src/batchopt/__init__.py b/mace-bench/src/batchopt/__init__.py index c51e983e452bca4f2c9a7fc2da802776fd309f49..5d823fab2db674a109982a519689de42d8634e4b 100644 --- a/mace-bench/src/batchopt/__init__.py +++ b/mace-bench/src/batchopt/__init__.py @@ -1,30 +1,30 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from .relaxengine import Scheduler, Worker -from .baseline import ensure_directory, run_baseline -from .utils import count_atoms_cif -from .pbc_graph import radius_graph_pbc_cuda - -try: - from . import extensions - _extensions_available = True -except ImportError as e: - import warnings - warnings.warn(f"Extensions not available: {e}. 
Falling back to PyTorch implementations.") - extensions = None - _extensions_available = False - -__all__ = [ - "Scheduler", - "ensure_directory", - "run_baseline", - "count_atoms_cif", - "Worker", - "extensions", - "radius_graph_pbc_cuda", +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from .relaxengine import Scheduler, Worker +from .baseline import ensure_directory, run_baseline +from .utils import count_atoms_cif +from .pbc_graph import radius_graph_pbc_cuda + +try: + from . import extensions + _extensions_available = True +except ImportError as e: + import warnings + warnings.warn(f"Extensions not available: {e}. Falling back to PyTorch implementations.") + extensions = None + _extensions_available = False + +__all__ = [ + "Scheduler", + "ensure_directory", + "run_baseline", + "count_atoms_cif", + "Worker", + "extensions", + "radius_graph_pbc_cuda", ] \ No newline at end of file diff --git a/mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index dbd2e9aa6a41d35935da53ac0620a105f60c740c..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc deleted file mode 100644 index 45c13be10aa89b2fdb92573a28e27aab703058af..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/baseline.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/baseline.cpython-310.pyc deleted file mode 100644 index dedbff080993e6184f4ed5a42d01f7b58a0751cd..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/baseline.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/pbc_graph.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/pbc_graph.cpython-310.pyc deleted file mode 100644 index 6ee0f9e3b5dd7a8a2850e24d8f5baf0885a93374..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/pbc_graph.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc deleted file mode 100644 index 4f1e5876e66842f343ccd4017b84ae5ee6cbfdf6..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc deleted file mode 100644 index 778b24d1991e2b241a7d3482846fa64926fd60d5..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index caee9055718a6ce84c2eafe1816a83acd4ba4923..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git 
a/mace-bench/src/batchopt/atoms_to_graphs.py b/mace-bench/src/batchopt/atoms_to_graphs.py index ace2a731714f5119e0801194e41a04d6309b80b9..d50fc768fffdb2121ecf057624aec3223a164b61 100644 --- a/mace-bench/src/batchopt/atoms_to_graphs.py +++ b/mace-bench/src/batchopt/atoms_to_graphs.py @@ -1,309 +1,309 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import ase.db.sqlite -import ase.io.trajectory -import numpy as np -import torch -from ase.geometry import wrap_positions -from torch_geometric.data import Data - -from batchopt.utils import collate - -if TYPE_CHECKING: - from collections.abc import Sequence - -try: - from pymatgen.io.ase import AseAtomsAdaptor -except ImportError: - AseAtomsAdaptor = None - -from tqdm import tqdm - - -class AtomsToGraphs: - """A class to help convert periodic atomic structures to graphs. - - The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts - them into graph representations for use in PyTorch. The primary purpose of this class is to determine the - nearest neighbors within some radius around each individual atom, taking into account PBC, and set the - pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information - are put into a PyTorch geometric data object for use with PyTorch. - - Args: - max_neigh (int): Maximum number of neighbors to consider. - radius (int or float): Cutoff radius in Angstroms to search for neighbors. - r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. - r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. - r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. - r_distances (bool): Return the distances with other properties. - Default is False, so the distances will not be returned. - r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. - r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. - Default is True, so the fixed indices will be returned. - r_pbc (bool): Return the periodic boundary conditions with other properties. - Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other - properties. Default is None, so no data will be returned as properties. - - Attributes: - max_neigh (int): Maximum number of neighbors to consider. - radius (int or float): Cutoff radius in Angstoms to search for neighbors. - r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. - r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. - r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. - r_distances (bool): Return the distances with other properties. - Default is False, so the distances will not be returned. - r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. 
- r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. - Default is True, so the fixed indices will be returned. - r_pbc (bool): Return the periodic boundary conditions with other properties. - Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other - properties. Default is None, so no data will be returned as properties. - """ - - def __init__( - self, - max_neigh: int = 200, - radius: int = 6, - r_energy: bool = False, - r_forces: bool = False, - r_distances: bool = False, - r_edges: bool = True, - r_fixed: bool = True, - r_pbc: bool = False, - r_stress: bool = False, - r_data_keys: Sequence[str] | None = None, - ) -> None: - self.max_neigh = max_neigh - self.radius = radius - self.r_energy = r_energy - self.r_forces = r_forces - self.r_stress = r_stress - self.r_distances = r_distances - self.r_fixed = r_fixed - self.r_edges = r_edges - self.r_pbc = r_pbc - self.r_data_keys = r_data_keys - - def _get_neighbors_pymatgen(self, atoms: ase.Atoms): - """Preforms nearest neighbor search and returns edge index, distances, - and cell offsets""" - if AseAtomsAdaptor is None: - raise RuntimeError( - "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed." - ) - - struct = AseAtomsAdaptor.get_structure(atoms) - _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( - r=self.radius, numerical_tol=0, exclude_self=True - ) - _nonmax_idx = [] - for i in range(len(atoms)): - idx_i = (_c_index == i).nonzero()[0] - # sort neighbors by distance, remove edges larger than max_neighbors - idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh] - _nonmax_idx.append(idx_i[idx_sorted]) - _nonmax_idx = np.concatenate(_nonmax_idx) - - _c_index = _c_index[_nonmax_idx] - _n_index = _n_index[_nonmax_idx] - n_distance = n_distance[_nonmax_idx] - _offsets = _offsets[_nonmax_idx] - - return _c_index, _n_index, n_distance, _offsets - - def _reshape_features(self, c_index, n_index, n_distance, offsets): - """Stack center and neighbor index and reshapes distances, - takes in np.arrays and returns torch tensors""" - edge_index = torch.LongTensor(np.vstack((n_index, c_index))) - edge_distances = torch.FloatTensor(n_distance) - cell_offsets = torch.LongTensor(offsets) - - # remove distances smaller than a tolerance ~ 0. The small tolerance is - # needed to correct for pymatgen's neighbor_list returning self atoms - # in a few edge cases. - nonzero = torch.where(edge_distances >= 1e-8)[0] - edge_index = edge_index[:, nonzero] - edge_distances = edge_distances[nonzero] - cell_offsets = cell_offsets[nonzero] - - return edge_index, edge_distances, cell_offsets - - def get_edge_distance_vec( - self, - pos, - edge_index, - cell, - cell_offsets, - ): - row, col = edge_index - distance_vectors = pos[row] - pos[col] - - # correct for pbc - cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0) - offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) - distance_vectors += offsets - - return distance_vectors - - def convert(self, atoms: ase.Atoms, sid=None): - """Convert a single atomic structure to a graph. - - Args: - atoms (ase.atoms.Atoms): An ASE atoms object. - - sid (uniquely identifying object): An identifier that can be used to track the structure in downstream - tasks. Common sids used in OCP datasets include unique strings or integers. 
- - Returns: - data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags, - and optionally, energy, forces, distances, edges, and periodic boundary conditions. - Optional properties can included by setting r_property=True when constructing the class. - """ - - # set the atomic numbers, positions, and cell - positions = np.array(atoms.get_positions(), copy=True) - pbc = np.array(atoms.pbc, copy=True) - cell = np.array(atoms.get_cell(complete=True), copy=True) - # TODO: change this back &&& ^^^ - # positions = wrap_positions(positions, cell, pbc=pbc, eps=0) - - atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8) - positions = torch.from_numpy(positions).float() - cell = torch.from_numpy(cell).view(1, 3, 3).float() - natoms = positions.shape[0] - - # initialized to torch.zeros(natoms) if tags missing. - # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags - tags = torch.tensor(atoms.get_tags(), dtype=torch.int) - - # put the minimum data in torch geometric data object - data = Data( - cell=cell, - pos=positions, - atomic_numbers=atomic_numbers, - natoms=natoms, - tags=tags, - ) - - # Optionally add a systemid (sid) to the object - if sid is not None: - data.sid = sid - - # optionally include other properties - if self.r_edges: - # run internal functions to get padded indices and distances - atoms_copy = atoms.copy() - atoms_copy.set_positions(positions) - split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) - edge_index, edge_distances, cell_offsets = self._reshape_features( - *split_idx_dist - ) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets - data.edge_distance_vec = self.get_edge_distance_vec( - positions, edge_index, cell, cell_offsets - ) - - del atoms_copy - if self.r_energy: - energy = atoms.get_potential_energy(apply_constraint=False) - data.energy = energy - if self.r_forces: - forces = torch.tensor( - atoms.get_forces(apply_constraint=False), dtype=torch.float32 - ) - data.forces = forces - if self.r_stress: - stress = torch.tensor( - atoms.get_stress(apply_constraint=False, voigt=False), - dtype=torch.float32, - ) - data.stress = stress - if self.r_distances and self.r_edges: - data.distances = edge_distances - if self.r_fixed: - fixed_idx = torch.zeros(natoms, dtype=torch.int) - if hasattr(atoms, "constraints"): - from ase.constraints import FixAtoms - - for constraint in atoms.constraints: - if isinstance(constraint, FixAtoms): - fixed_idx[constraint.index] = 1 - data.fixed = fixed_idx - if self.r_pbc: - data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) - if self.r_data_keys is not None: - for data_key in self.r_data_keys: - data[data_key] = ( - atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float, str)) - else torch.tensor(atoms.info[data_key]) - ) - - return data - - def convert_all( - self, - atoms_collection, - processed_file_path: str | None = None, - collate_and_save=False, - disable_tqdm=False, - ): - """Convert all atoms objects in a list or in an ase.db to graphs. - - Args: - atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database): - Either a list of ASE atoms objects or an ASE database. - processed_file_path (str): - A string of the path to where the processed file will be written. Default is None. - collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file. 
- - Returns: - data_list (list of torch_geometric.data.Data): - A list of torch geometric data objects containing molecular graph info and properties. - """ - - # list for all data - data_list = [] - if isinstance(atoms_collection, list): - atoms_iter = atoms_collection - elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database): - atoms_iter = atoms_collection.select() - elif isinstance( - atoms_collection, - (ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader), - ): - atoms_iter = atoms_collection - else: - raise NotImplementedError - - for atoms in tqdm( - atoms_iter, - desc="converting ASE atoms collection to graphs", - total=len(atoms_collection), - unit=" systems", - disable=disable_tqdm, - ): - # check if atoms is an ASE Atoms object this for the ase.db case - data = self.convert( - atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms() - ) - data_list.append(data) - - if collate_and_save: - data, slices = collate(data_list) - torch.save((data, slices), processed_file_path) - - return data_list +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import ase.db.sqlite +import ase.io.trajectory +import numpy as np +import torch +from ase.geometry import wrap_positions +from torch_geometric.data import Data + +from batchopt.utils import collate + +if TYPE_CHECKING: + from collections.abc import Sequence + +try: + from pymatgen.io.ase import AseAtomsAdaptor +except ImportError: + AseAtomsAdaptor = None + +from tqdm import tqdm + + +class AtomsToGraphs: + """A class to help convert periodic atomic structures to graphs. + + The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts + them into graph representations for use in PyTorch. The primary purpose of this class is to determine the + nearest neighbors within some radius around each individual atom, taking into account PBC, and set the + pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information + are put into a PyTorch geometric data object for use with PyTorch. + + Args: + max_neigh (int): Maximum number of neighbors to consider. + radius (int or float): Cutoff radius in Angstroms to search for neighbors. + r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. + r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. + r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. + r_distances (bool): Return the distances with other properties. + Default is False, so the distances will not be returned. + r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. + r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. + Default is True, so the fixed indices will be returned. + r_pbc (bool): Return the periodic boundary conditions with other properties. + Default is False, so the periodic boundary conditions will not be returned. + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other + properties. Default is None, so no data will be returned as properties. 
+
+    Attributes:
+        max_neigh (int): Maximum number of neighbors to consider.
+        radius (int or float): Cutoff radius in Angstroms to search for neighbors.
+        r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
+        r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
+        r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
+        r_distances (bool): Return the distances with other properties.
+        Default is False, so the distances will not be returned.
+        r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
+        r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
+        Default is True, so the fixed indices will be returned.
+        r_pbc (bool): Return the periodic boundary conditions with other properties.
+        Default is False, so the periodic boundary conditions will not be returned.
+        r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
+        properties. Default is None, so no data will be returned as properties.
+    """
+
+    def __init__(
+        self,
+        max_neigh: int = 200,
+        radius: int = 6,
+        r_energy: bool = False,
+        r_forces: bool = False,
+        r_distances: bool = False,
+        r_edges: bool = True,
+        r_fixed: bool = True,
+        r_pbc: bool = False,
+        r_stress: bool = False,
+        r_data_keys: Sequence[str] | None = None,
+    ) -> None:
+        self.max_neigh = max_neigh
+        self.radius = radius
+        self.r_energy = r_energy
+        self.r_forces = r_forces
+        self.r_stress = r_stress
+        self.r_distances = r_distances
+        self.r_fixed = r_fixed
+        self.r_edges = r_edges
+        self.r_pbc = r_pbc
+        self.r_data_keys = r_data_keys
+
+    def _get_neighbors_pymatgen(self, atoms: ase.Atoms):
+        """Performs nearest neighbor search and returns edge index, distances,
+        and cell offsets"""
+        if AseAtomsAdaptor is None:
+            raise RuntimeError(
+                "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed."
+            )
+
+        struct = AseAtomsAdaptor.get_structure(atoms)
+        _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(
+            r=self.radius, numerical_tol=0, exclude_self=True
+        )
+        _nonmax_idx = []
+        for i in range(len(atoms)):
+            idx_i = (_c_index == i).nonzero()[0]
+            # sort neighbors by distance, remove edges larger than max_neighbors
+            idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh]
+            _nonmax_idx.append(idx_i[idx_sorted])
+        _nonmax_idx = np.concatenate(_nonmax_idx)
+
+        _c_index = _c_index[_nonmax_idx]
+        _n_index = _n_index[_nonmax_idx]
+        n_distance = n_distance[_nonmax_idx]
+        _offsets = _offsets[_nonmax_idx]
+
+        return _c_index, _n_index, n_distance, _offsets
+
+    def _reshape_features(self, c_index, n_index, n_distance, offsets):
+        """Stacks center and neighbor indices and reshapes distances;
+        takes in np.arrays and returns torch tensors"""
+        edge_index = torch.LongTensor(np.vstack((n_index, c_index)))
+        edge_distances = torch.FloatTensor(n_distance)
+        cell_offsets = torch.LongTensor(offsets)
+
+        # remove distances smaller than a tolerance ~ 0. The small tolerance is
+        # needed to correct for pymatgen's neighbor_list returning self atoms
+        # in a few edge cases. 
+        nonzero = torch.where(edge_distances >= 1e-8)[0]
+        edge_index = edge_index[:, nonzero]
+        edge_distances = edge_distances[nonzero]
+        cell_offsets = cell_offsets[nonzero]
+
+        return edge_index, edge_distances, cell_offsets
+
+    def get_edge_distance_vec(
+        self,
+        pos,
+        edge_index,
+        cell,
+        cell_offsets,
+    ):
+        row, col = edge_index
+        distance_vectors = pos[row] - pos[col]
+
+        # correct for pbc
+        cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0)
+        offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3)
+        distance_vectors += offsets
+
+        return distance_vectors
+
+    def convert(self, atoms: ase.Atoms, sid=None):
+        """Convert a single atomic structure to a graph.
+
+        Args:
+            atoms (ase.atoms.Atoms): An ASE atoms object.
+
+            sid (uniquely identifying object): An identifier that can be used to track the structure in downstream
+            tasks. Common sids used in OCP datasets include unique strings or integers.
+
+        Returns:
+            data (torch_geometric.data.Data): A torch geometric data object with positions, atomic_numbers, tags,
+            and optionally, energy, forces, distances, edges, and periodic boundary conditions.
+            Optional properties can be included by setting r_property=True when constructing the class.
+        """
+
+        # set the atomic numbers, positions, and cell
+        positions = np.array(atoms.get_positions(), copy=True)
+        pbc = np.array(atoms.pbc, copy=True)
+        cell = np.array(atoms.get_cell(complete=True), copy=True)
+        # TODO: change this back &&& ^^^
+        # positions = wrap_positions(positions, cell, pbc=pbc, eps=0)
+
+        atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8)
+        positions = torch.from_numpy(positions).float()
+        cell = torch.from_numpy(cell).view(1, 3, 3).float()
+        natoms = positions.shape[0]
+
+        # initialized to torch.zeros(natoms) if tags missing. 
+ # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags + tags = torch.tensor(atoms.get_tags(), dtype=torch.int) + + # put the minimum data in torch geometric data object + data = Data( + cell=cell, + pos=positions, + atomic_numbers=atomic_numbers, + natoms=natoms, + tags=tags, + ) + + # Optionally add a systemid (sid) to the object + if sid is not None: + data.sid = sid + + # optionally include other properties + if self.r_edges: + # run internal functions to get padded indices and distances + atoms_copy = atoms.copy() + atoms_copy.set_positions(positions) + split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) + edge_index, edge_distances, cell_offsets = self._reshape_features( + *split_idx_dist + ) + + data.edge_index = edge_index + data.cell_offsets = cell_offsets + data.edge_distance_vec = self.get_edge_distance_vec( + positions, edge_index, cell, cell_offsets + ) + + del atoms_copy + if self.r_energy: + energy = atoms.get_potential_energy(apply_constraint=False) + data.energy = energy + if self.r_forces: + forces = torch.tensor( + atoms.get_forces(apply_constraint=False), dtype=torch.float32 + ) + data.forces = forces + if self.r_stress: + stress = torch.tensor( + atoms.get_stress(apply_constraint=False, voigt=False), + dtype=torch.float32, + ) + data.stress = stress + if self.r_distances and self.r_edges: + data.distances = edge_distances + if self.r_fixed: + fixed_idx = torch.zeros(natoms, dtype=torch.int) + if hasattr(atoms, "constraints"): + from ase.constraints import FixAtoms + + for constraint in atoms.constraints: + if isinstance(constraint, FixAtoms): + fixed_idx[constraint.index] = 1 + data.fixed = fixed_idx + if self.r_pbc: + data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) + if self.r_data_keys is not None: + for data_key in self.r_data_keys: + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float, str)) + else torch.tensor(atoms.info[data_key]) + ) + + return data + + def convert_all( + self, + atoms_collection, + processed_file_path: str | None = None, + collate_and_save=False, + disable_tqdm=False, + ): + """Convert all atoms objects in a list or in an ase.db to graphs. + + Args: + atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database): + Either a list of ASE atoms objects or an ASE database. + processed_file_path (str): + A string of the path to where the processed file will be written. Default is None. + collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file. + + Returns: + data_list (list of torch_geometric.data.Data): + A list of torch geometric data objects containing molecular graph info and properties. 
+ """ + + # list for all data + data_list = [] + if isinstance(atoms_collection, list): + atoms_iter = atoms_collection + elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database): + atoms_iter = atoms_collection.select() + elif isinstance( + atoms_collection, + (ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader), + ): + atoms_iter = atoms_collection + else: + raise NotImplementedError + + for atoms in tqdm( + atoms_iter, + desc="converting ASE atoms collection to graphs", + total=len(atoms_collection), + unit=" systems", + disable=disable_tqdm, + ): + # check if atoms is an ASE Atoms object this for the ase.db case + data = self.convert( + atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms() + ) + data_list.append(data) + + if collate_and_save: + data, slices = collate(data_list) + torch.save((data, slices), processed_file_path) + + return data_list diff --git a/mace-bench/src/batchopt/baseline.py b/mace-bench/src/batchopt/baseline.py index 552660b406d5753eaf07de1c58b7df30f0d9d8ba..d23595dfd4cb6103ced8b9f9a6cea8ed6c80673c 100644 --- a/mace-bench/src/batchopt/baseline.py +++ b/mace-bench/src/batchopt/baseline.py @@ -1,171 +1,171 @@ -""" -Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from ase.io import read -import logging -from joblib import Parallel, delayed -from ase.optimize import LBFGS as ASE_LBFGS -from ase.optimize import QuasiNewton as ASE_QuasiNewton -from ase.optimize import BFGS as ASE_BFGS -import time -import csv -import os -try: - from mace.calculators import mace_off -except ImportError: - logging.warning("Failed to import MACE modules") - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - - -def ensure_directory(directory): - """Create directory if it doesn't exist.""" - if not os.path.exists(directory): - os.makedirs(directory) - logging.info(f"Created directory: {directory}") - -def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"): - """ - Runs the baseline optimization using LBFGS from ase.optimize. 
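`AtomsToGraphs.convert` above returns one `torch_geometric.data.Data` object per structure, and `convert_all` batches the conversion. A minimal usage sketch, assuming pymatgen and torch_geometric are installed and `crystal.cif` is a placeholder input:

```python
# A usage sketch for the AtomsToGraphs converter defined above; assumes pymatgen
# and torch_geometric are installed, and "crystal.cif" is a placeholder input.
from ase.io import read
from batchopt.atoms_to_graphs import AtomsToGraphs

a2g = AtomsToGraphs(max_neigh=200, radius=6, r_edges=True, r_pbc=True)

atoms = read("crystal.cif")                 # hypothetical structure file
data = a2g.convert(atoms, sid="crystal-0")  # one torch_geometric.data.Data object
print(data.natoms, data.edge_index.shape, data.cell.shape)

# Batch conversion; disable_tqdm=True silences the progress bar.
data_list = a2g.convert_all([atoms], disable_tqdm=True)
```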
- """ - os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1] - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - logging.info(f"Starting baseline optimization for file {file} on device {device}.") - - - start_time = time.perf_counter() - - crystal = read(file) - # calc = mace_off(model="small", device=device) - calc = mace_off(model="small", device="cuda") - crystal.calc = calc - - first_optimizer_class ={ - "LBFGS": ASE_LBFGS, - "QuasiNewton": ASE_QuasiNewton, - "BFGS": ASE_BFGS - }.get(first_optimizer, ASE_LBFGS) - - # First optimization stage - if filter1 == "UnitCellFilter": - from ase.filters import UnitCellFilter - atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure) - first_optimizer_instance = first_optimizer_class(atoms_with_filter) - elif filter1 == "FrechetCellFilter": - from ase.filters import FrechetCellFilter - atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure) - first_optimizer_instance = first_optimizer_class(atoms_with_filter) - else: - first_optimizer_instance = first_optimizer_class(crystal) - - start_time1 = time.perf_counter() - first_optimizer_instance.run(fmax=0.01, steps=max_steps) - end_time1 = time.perf_counter() - - # Save intermediate result - output_dir_press = "./cif_result_press" - output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif")) - crystal.write(output_file_press) - - elapsed_time1 = end_time1 - start_time1 - steps1 = first_optimizer_instance.nsteps - - if skip_second_stage: - - ret_result = { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": 0.0, - "stage2_steps": 0, - "total_time": elapsed_time1, - "total_steps": steps1 - } - else: - # Second optimization stage - crystal = read(output_file_press) - crystal.calc = calc - - second_optimizer_class = { - "LBFGS": ASE_LBFGS, - "QuasiNewton": ASE_QuasiNewton, - "BFGS": ASE_BFGS - }.get(second_optimizer, ASE_LBFGS) - - if filter2 == "UnitCellFilter": - from ase.filters import UnitCellFilter - atoms_with_filter2 = UnitCellFilter(crystal) - second_optimizer_instance = second_optimizer_class(atoms_with_filter2) - elif filter2 == "FrechetCellFilter": - from ase.filters import FrechetCellFilter - atoms_with_filter2 = FrechetCellFilter(crystal) - second_optimizer_instance = second_optimizer_class(atoms_with_filter2) - else: - second_optimizer_instance = second_optimizer_class(crystal) - - start_time2 = time.perf_counter() - second_optimizer_instance.run(fmax=0.01, steps=max_steps) - end_time2 = time.perf_counter() - - # Save final result - output_dir_final = "./cif_result_final" - output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif")) - crystal.write(output_file_final) - - # Collect metrics - elapsed_time2 = end_time2 - start_time2 - total_time = elapsed_time1 + elapsed_time2 - steps2 = second_optimizer_instance.nsteps - - ret_result = { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": elapsed_time2, - "stage2_steps": steps2, - "total_time": total_time, - "total_steps": steps1 + steps2 - } - - logging.info(f"Baseline optimization completed for file {file}.") - return ret_result - -def run_baseline(files, num_workers, devices, max_steps, - filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, - optimizer1=None, optimizer2=None): - """ - Runs the baseline optimization using LBFGS from ase.optimize. 
- """ - logging.info(f"Starting baseline optimization with {num_workers} workers.") - - start_time = time.perf_counter() - results = Parallel(n_jobs=num_workers)( - delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2) - for i, file in enumerate(files) - ) - end_time = time.perf_counter() - - csv_file = "results_baseline.csv" - with open(csv_file, mode='w', newline='') as file: - writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"]) - writer.writeheader() - for result in results: - writer.writerow(result) - - logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.") - final_elapsed_time = end_time - start_time - summary_csv_file = "summary_baseline.csv" - with open(summary_csv_file, mode='w', newline='') as file: - writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"]) - writer.writeheader() - writer.writerow({ - "elapsed_time": final_elapsed_time, - "num_workers": num_workers, - "batch_size": 1 - }) - +""" +Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from ase.io import read +import logging +from joblib import Parallel, delayed +from ase.optimize import LBFGS as ASE_LBFGS +from ase.optimize import QuasiNewton as ASE_QuasiNewton +from ase.optimize import BFGS as ASE_BFGS +import time +import csv +import os +try: + from mace.calculators import mace_off +except ImportError: + logging.warning("Failed to import MACE modules") + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +def ensure_directory(directory): + """Create directory if it doesn't exist.""" + if not os.path.exists(directory): + os.makedirs(directory) + logging.info(f"Created directory: {directory}") + +def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"): + """ + Runs the baseline optimization using LBFGS from ase.optimize. 
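`baseline_task` above is essentially the standard two-stage ASE relaxation: relax under a cell filter with a small scalar pressure, write the intermediate structure, then relax again without the pressure bias. A stripped-down single-structure sketch, using ASE's EMT calculator as a stand-in for `mace_off` and a hypothetical input file:

```python
# A stripped-down, single-structure sketch of the two-stage relaxation performed
# by baseline_task above. EMT stands in for the mace_off calculator, and
# "structure.cif" is a placeholder input.
from ase.io import read
from ase.filters import UnitCellFilter
from ase.optimize import LBFGS
from ase.calculators.emt import EMT  # stand-in calculator for illustration only

crystal = read("structure.cif")
crystal.calc = EMT()

# Stage 1: relax positions and cell under a small external pressure.
LBFGS(UnitCellFilter(crystal, scalar_pressure=0.0006)).run(fmax=0.01, steps=100)
crystal.write("structure_press.cif")

# Stage 2: relax again without the pressure bias.
LBFGS(UnitCellFilter(crystal)).run(fmax=0.01, steps=100)
crystal.write("structure_opt.cif")
```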
+ """ + os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1] + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + logging.info(f"Starting baseline optimization for file {file} on device {device}.") + + + start_time = time.perf_counter() + + crystal = read(file) + # calc = mace_off(model="small", device=device) + calc = mace_off(model="small", device="cuda") + crystal.calc = calc + + first_optimizer_class ={ + "LBFGS": ASE_LBFGS, + "QuasiNewton": ASE_QuasiNewton, + "BFGS": ASE_BFGS + }.get(first_optimizer, ASE_LBFGS) + + # First optimization stage + if filter1 == "UnitCellFilter": + from ase.filters import UnitCellFilter + atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure) + first_optimizer_instance = first_optimizer_class(atoms_with_filter) + elif filter1 == "FrechetCellFilter": + from ase.filters import FrechetCellFilter + atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure) + first_optimizer_instance = first_optimizer_class(atoms_with_filter) + else: + first_optimizer_instance = first_optimizer_class(crystal) + + start_time1 = time.perf_counter() + first_optimizer_instance.run(fmax=0.01, steps=max_steps) + end_time1 = time.perf_counter() + + # Save intermediate result + output_dir_press = "./cif_result_press" + output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif")) + crystal.write(output_file_press) + + elapsed_time1 = end_time1 - start_time1 + steps1 = first_optimizer_instance.nsteps + + if skip_second_stage: + + ret_result = { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": 0.0, + "stage2_steps": 0, + "total_time": elapsed_time1, + "total_steps": steps1 + } + else: + # Second optimization stage + crystal = read(output_file_press) + crystal.calc = calc + + second_optimizer_class = { + "LBFGS": ASE_LBFGS, + "QuasiNewton": ASE_QuasiNewton, + "BFGS": ASE_BFGS + }.get(second_optimizer, ASE_LBFGS) + + if filter2 == "UnitCellFilter": + from ase.filters import UnitCellFilter + atoms_with_filter2 = UnitCellFilter(crystal) + second_optimizer_instance = second_optimizer_class(atoms_with_filter2) + elif filter2 == "FrechetCellFilter": + from ase.filters import FrechetCellFilter + atoms_with_filter2 = FrechetCellFilter(crystal) + second_optimizer_instance = second_optimizer_class(atoms_with_filter2) + else: + second_optimizer_instance = second_optimizer_class(crystal) + + start_time2 = time.perf_counter() + second_optimizer_instance.run(fmax=0.01, steps=max_steps) + end_time2 = time.perf_counter() + + # Save final result + output_dir_final = "./cif_result_final" + output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif")) + crystal.write(output_file_final) + + # Collect metrics + elapsed_time2 = end_time2 - start_time2 + total_time = elapsed_time1 + elapsed_time2 + steps2 = second_optimizer_instance.nsteps + + ret_result = { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": elapsed_time2, + "stage2_steps": steps2, + "total_time": total_time, + "total_steps": steps1 + steps2 + } + + logging.info(f"Baseline optimization completed for file {file}.") + return ret_result + +def run_baseline(files, num_workers, devices, max_steps, + filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, + optimizer1=None, optimizer2=None): + """ + Runs the baseline optimization using LBFGS from ase.optimize. 
+ """ + logging.info(f"Starting baseline optimization with {num_workers} workers.") + + start_time = time.perf_counter() + results = Parallel(n_jobs=num_workers)( + delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2) + for i, file in enumerate(files) + ) + end_time = time.perf_counter() + + csv_file = "results_baseline.csv" + with open(csv_file, mode='w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"]) + writer.writeheader() + for result in results: + writer.writerow(result) + + logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.") + final_elapsed_time = end_time - start_time + summary_csv_file = "summary_baseline.csv" + with open(summary_csv_file, mode='w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"]) + writer.writeheader() + writer.writerow({ + "elapsed_time": final_elapsed_time, + "num_workers": num_workers, + "batch_size": 1 + }) + logging.info(f"Summary results written to {summary_csv_file}.") \ No newline at end of file diff --git a/mace-bench/src/batchopt/extensions/__init__.py b/mace-bench/src/batchopt/extensions/__init__.py index 70065418b41114ba89d51116fe1ad8531970e8b6..2e0a0f218186042a15fdccf7211b1155dcd74c16 100644 --- a/mace-bench/src/batchopt/extensions/__init__.py +++ b/mace-bench/src/batchopt/extensions/__init__.py @@ -1,12 +1,12 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations. - -This module provides optimized implementations of common operations using -torch.utils.cpp_extension for JIT compilation. -""" - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations. + +This module provides optimized implementations of common operations using +torch.utils.cpp_extension for JIT compilation. +""" + diff --git a/mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 05baaef36837c42cefd96f5e646b6c6ffacbd2f4..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py b/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py index 1e6bf7a569e5027c289693821b9370690b642bec..5ce1bd1ce4a9ffebe79277669a91032637a77b11 100644 --- a/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py +++ b/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py @@ -1,91 +1,91 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -CUDA Extension wrapper for vector addition and PBC graph operations. 
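`run_baseline` above fans `baseline_task` out over the input files with joblib and writes per-file results plus a summary CSV into the working directory. A minimal invocation matching the signature defined above; file names and devices are placeholders, and the result directories that `baseline_task` writes into must exist first:

```python
# A minimal invocation of run_baseline matching the signature defined above.
# File names and devices are placeholders; baseline_task writes into
# ./cif_result_press and ./cif_result_final, so create them first.
from batchopt import ensure_directory, run_baseline

for sub in ("cif_result_press", "cif_result_final"):
    ensure_directory(sub)

files = ["a.cif", "b.cif", "c.cif"]  # hypothetical inputs
run_baseline(files, num_workers=2, devices=["cuda:0", "cuda:1"], max_steps=300,
             filter1="UnitCellFilter", filter2="UnitCellFilter",
             skip_second_stage=False, scalar_pressure=0.0006,
             optimizer1="LBFGS", optimizer2="BFGS")
```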
-""" -import torch -from torch.utils.cpp_extension import load -import os - -def load_cuda_extension(): - """Load the CUDA extension for vector addition.""" - # Check if CUDA is available - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") - - # Get the directory of this file - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # Path to the CUDA source file - cuda_file = os.path.join(current_dir, "vector_add.cu") - - # Load the extension - return load( - name="vector_add_cuda", - sources=[cuda_file], - verbose=True, - extra_cflags=['-O3'], - extra_cuda_cflags=['-O3', '--use_fast_math'], - ) - -def load_pbc_graph_cuda_extension(): - """Load the CUDA extension for PBC graph operations.""" - # Check if CUDA is available - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") - - # Get the directory of this file - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # Path to the CUDA source file - cuda_file = os.path.join(current_dir, "pbc_graph.cu") - - # Load the extension - return load( - name="pbc_graph_cuda", - sources=[cuda_file], - verbose=True, - extra_cflags=['-O3'], - extra_cuda_cflags=['-O3', '--use_fast_math'], - ) - -# Global variable to store loaded extension -_cuda_extension = None -_pbc_graph_cuda_extension = None - -def get_cuda_extension(): - """Get or load the CUDA extension.""" - global _cuda_extension - if _cuda_extension is None: - _cuda_extension = load_cuda_extension() - return _cuda_extension - -def get_pbc_graph_cuda_extension(): - """Get or load the PBC graph CUDA extension.""" - global _pbc_graph_cuda_extension - if _pbc_graph_cuda_extension is None: - _pbc_graph_cuda_extension = load_pbc_graph_cuda_extension() - return _pbc_graph_cuda_extension - -def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """ - Perform vector addition using CUDA implementation. - - Args: - a: First input tensor (must be on CUDA device) - b: Second input tensor (must be on CUDA device) - - Returns: - Result tensor of element-wise addition - """ - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available.") - - if not (a.is_cuda and b.is_cuda): - raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.") - - extension = get_cuda_extension() - return extension.vector_add(a.float(), b.float()) +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +CUDA Extension wrapper for vector addition and PBC graph operations. +""" +import torch +from torch.utils.cpp_extension import load +import os + +def load_cuda_extension(): + """Load the CUDA extension for vector addition.""" + # Check if CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") + + # Get the directory of this file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the CUDA source file + cuda_file = os.path.join(current_dir, "vector_add.cu") + + # Load the extension + return load( + name="vector_add_cuda", + sources=[cuda_file], + verbose=True, + extra_cflags=['-O3'], + extra_cuda_cflags=['-O3', '--use_fast_math'], + ) + +def load_pbc_graph_cuda_extension(): + """Load the CUDA extension for PBC graph operations.""" + # Check if CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. 
Cannot load CUDA extension.") + + # Get the directory of this file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the CUDA source file + cuda_file = os.path.join(current_dir, "pbc_graph.cu") + + # Load the extension + return load( + name="pbc_graph_cuda", + sources=[cuda_file], + verbose=True, + extra_cflags=['-O3'], + extra_cuda_cflags=['-O3', '--use_fast_math'], + ) + +# Global variable to store loaded extension +_cuda_extension = None +_pbc_graph_cuda_extension = None + +def get_cuda_extension(): + """Get or load the CUDA extension.""" + global _cuda_extension + if _cuda_extension is None: + _cuda_extension = load_cuda_extension() + return _cuda_extension + +def get_pbc_graph_cuda_extension(): + """Get or load the PBC graph CUDA extension.""" + global _pbc_graph_cuda_extension + if _pbc_graph_cuda_extension is None: + _pbc_graph_cuda_extension = load_pbc_graph_cuda_extension() + return _pbc_graph_cuda_extension + +def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Perform vector addition using CUDA implementation. + + Args: + a: First input tensor (must be on CUDA device) + b: Second input tensor (must be on CUDA device) + + Returns: + Result tensor of element-wise addition + """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available.") + + if not (a.is_cuda and b.is_cuda): + raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.") + + extension = get_cuda_extension() + return extension.vector_add(a.float(), b.float()) diff --git a/mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 358ed12c90314b7bbb2b36f457f9a66592fece28..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/pbc_graph.py b/mace-bench/src/batchopt/pbc_graph.py index d9cb744e8692e2268686037e69bdb7ea7298d82a..998b0477a7a2a125aaf19f0afbaf6a9eb80677a6 100644 --- a/mace-bench/src/batchopt/pbc_graph.py +++ b/mace-bench/src/batchopt/pbc_graph.py @@ -1,158 +1,158 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -CUDA-accelerated PBC graph operations for atomic systems. -""" -import torch -from typing import Optional, List -from .pbc_graph_legacy import get_max_neighbors_mask -from .extensions.cuda_ops import get_pbc_graph_cuda_extension - -def radius_graph_pbc_cuda( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - """ - Memory-efficient CUDA-accelerated version of radius_graph_pbc. - - This implementation follows the memory-efficient approach with triple loops - but accelerates the distance computation using CUDA kernels. - """ - if pbc is None: - pbc = [True, True, True] - - device = data.pos.device - batch_size = len(data.natoms) - - # Handle PBC settings - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations." 
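Both loaders above JIT-compile their kernels with `torch.utils.cpp_extension.load`, so the first call pays a one-time build cost and needs a CUDA toolchain on the host. A quick sanity check of the vector-add entry point:

```python
# Sanity check for the JIT-compiled vector-add kernel exposed above. The first
# call triggers compilation of vector_add.cu, so it needs nvcc and a GPU.
import torch
from batchopt.extensions.cuda_ops import vector_add_cuda

a = torch.randn(1 << 20, device="cuda")
b = torch.randn(1 << 20, device="cuda")
c = vector_add_cuda(a, b)
assert torch.allclose(c, a + b, atol=1e-6)
```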
- ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute atom pair indices - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction for PBC - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity - max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] - - # Pre-transpose data_cell for efficiency - data_cell = torch.transpose(data.cell, 1, 2) - - # Use CUDA kernel for the triple loop computation - # try: - pbc_graph_cuda = get_pbc_graph_cuda_extension() - - # Call the CUDA implementation - valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda( - pos1, pos2, data_cell, - num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device - ) - - # Map back to original index1 and index2 - if len(valid_pair_indices) > 0: - index1 = index1.index_select(0, valid_pair_indices.long()) - index2 = index2.index_select(0, valid_pair_indices.long()) - else: - index1 = torch.empty(0, dtype=torch.long, device=device) - index2 = torch.empty(0, dtype=torch.long, device=device) - unit_cell = torch.empty(0, 3, dtype=dtype, device=device) - atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) - - # Sort index1 in ascending order and rearrange other arrays correspondingly - if len(index1) > 0: - sort_indices = torch.argsort(index1) - index1 = index1[sort_indices] - index2 = index2[sort_indices] - unit_cell = unit_cell[sort_indices] - atom_distance_sqr = atom_distance_sqr[sort_indices] - - 
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +CUDA-accelerated PBC graph operations for atomic systems. +""" +import torch +from typing import Optional, List +from .pbc_graph_legacy import get_max_neighbors_mask +from .extensions.cuda_ops import get_pbc_graph_cuda_extension + +def radius_graph_pbc_cuda( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + """ + Memory-efficient CUDA-accelerated version of radius_graph_pbc. + + This implementation follows the memory-efficient approach with triple loops + but accelerates the distance computation using CUDA kernels. + """ + if pbc is None: + pbc = [True, True, True] + + device = data.pos.device + batch_size = len(data.natoms) + + # Handle PBC settings + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations." 
+ ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute atom pair indices + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction for PBC + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity + max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] + + # Pre-transpose data_cell for efficiency + data_cell = torch.transpose(data.cell, 1, 2) + + # Use CUDA kernel for the triple loop computation + # try: + pbc_graph_cuda = get_pbc_graph_cuda_extension() + + # Call the CUDA implementation + valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda( + pos1, pos2, data_cell, + num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device + ) + + # Map back to original index1 and index2 + if len(valid_pair_indices) > 0: + index1 = index1.index_select(0, valid_pair_indices.long()) + index2 = index2.index_select(0, valid_pair_indices.long()) + else: + index1 = torch.empty(0, dtype=torch.long, device=device) + index2 = torch.empty(0, dtype=torch.long, device=device) + unit_cell = torch.empty(0, 3, dtype=dtype, device=device) + atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) + + # Sort index1 in ascending order and rearrange other arrays correspondingly + if len(index1) > 0: + sort_indices = torch.argsort(index1) + index1 = index1[sort_indices] + index2 = index2[sort_indices] + unit_cell = unit_cell[sort_indices] + atom_distance_sqr = atom_distance_sqr[sort_indices] + + 
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + return edge_index, unit_cell, num_neighbors_image \ No newline at end of file diff --git a/mace-bench/src/batchopt/pbc_graph_legacy.py b/mace-bench/src/batchopt/pbc_graph_legacy.py index 484ed0e3290a04ebb293eaf48b668b51a9b16d26..ca75c5368e84e36ca33a08cd1f0755948d36288f 100644 --- a/mace-bench/src/batchopt/pbc_graph_legacy.py +++ b/mace-bench/src/batchopt/pbc_graph_legacy.py @@ -1,563 +1,563 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import numpy as np -import torch -import torch.nn as nn -import torch_geometric -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas -from matplotlib.figure import Figure -from torch_geometric.data import Data -from torch_geometric.utils import remove_self_loops -from torch_scatter import scatter, segment_coo, segment_csr - - -if TYPE_CHECKING: - from collections.abc import Mapping - - from torch.nn.modules.module import _IncompatibleKeys - - -DEFAULT_ENV_VARS = { - # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes). - # see https://pytorch.org/docs/stable/notes/cuda.html. 
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", -} - - -def get_pbc_distances( - pos, - edge_index, - cell, - cell_offsets, - neighbors, - return_offsets: bool = False, - return_distance_vec: bool = False, -): - row, col = edge_index - - distance_vectors = pos[row] - pos[col] - - # correct for pbc - neighbors = neighbors.to(cell.device) - cell = torch.repeat_interleave(cell, neighbors, dim=0) - offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) - distance_vectors += offsets - - # compute distances - distances = distance_vectors.norm(dim=-1) - - # redundancy: remove zero distances - nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0] - edge_index = edge_index[:, nonzero_idx] - distances = distances[nonzero_idx] - - out = { - "edge_index": edge_index, - "distances": distances, - } - - if return_distance_vec: - out["distance_vec"] = distance_vectors[nonzero_idx] - - if return_offsets: - out["offsets"] = offsets[nonzero_idx] - - return out - -def radius_graph_pbc_mem_effi( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - if pbc is None: - pbc = [True, True, True] - device = data.pos.device - batch_size = len(data.natoms) - - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations. This is not currently supported." - ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image - # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement - # the following (but 10x faster since it removes the for loop) - # for batch_idx in range(batch_size): - # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - # If the systems get too large this apporach could run into numerical precision issues - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction. 
- # Smallest distance between planes separated by a1 is - # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. - # Note that the unit cell volume V = a1 * (a2 x a3) and that - # (a2 x a3) / V is also the reciprocal primitive vector - # (crystallographer's definition). - - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity. This is essentially padding. - # Note that this can significantly increase the number of computed distances - # if the required repetitions are very different between images - # (which they usually are). Changing this to sparse (scatter) operations - # might be worth the effort if this function becomes a bottleneck. - max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] - - # Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once - # This reduces memory usage by avoiding the creation of large tensor products - all_index1 = [] - all_index2 = [] - all_unit_cell = [] - all_atom_distance_sqr = [] - - # Pre-transpose data_cell for efficiency - data_cell = torch.transpose(data.cell, 1, 2) - - # Iterate over each unit cell offset combination - for i in range(-max_rep[0], max_rep[0] + 1): - for j in range(-max_rep[1], max_rep[1] + 1): - for k in range(-max_rep[2], max_rep[2] + 1): - # Create unit cell offset - unit_cell_offset = torch.tensor([i, j, k], device=device, dtype=dtype) - - # Compute the x, y, z positional offsets for this specific cell in each image - # unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size) - unit_cell_offset_batch = unit_cell_offset.view(1,3,1).expand(batch_size, -1, -1) - pbc_offsets = torch.bmm(data_cell, unit_cell_offset_batch).squeeze(-1) - pbc_offsets_per_atom = torch.repeat_interleave( - pbc_offsets, num_atoms_per_image_sqr, dim=0 - ) - - # Apply PBC offsets to the second atom positions - pos2_offset = pos2 + pbc_offsets_per_atom - - # Compute the squared distance between atoms - atom_distance_sqr = torch.sum((pos1 - pos2_offset) ** 2, dim=1) - - # Remove pairs that are too far apart - mask_within_radius = torch.le(atom_distance_sqr, radius * radius) - # Remove pairs with the same atoms (distance = 0.0) - mask_not_same = torch.gt(atom_distance_sqr, 0.0001) - mask = torch.logical_and(mask_within_radius, mask_not_same) - - # Only keep valid pairs for this unit cell offset - if torch.any(mask): - valid_index1 = torch.masked_select(index1, mask) - valid_index2 = torch.masked_select(index2, mask) - valid_distances = torch.masked_select(atom_distance_sqr, mask) - valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat(valid_index1.shape[0], 1) - - all_index1.append(valid_index1) - all_index2.append(valid_index2) - all_unit_cell.append(valid_unit_cell) - 
all_atom_distance_sqr.append(valid_distances) - - # Concatenate all results - if len(all_index1) > 0: - index1 = torch.cat(all_index1) - index2 = torch.cat(all_index2) - unit_cell = torch.cat(all_unit_cell) - atom_distance_sqr = torch.cat(all_atom_distance_sqr) - - # Sort index1 in ascending order and rearrange other arrays correspondingly - sort_indices = torch.argsort(index1) - index1 = index1[sort_indices] - index2 = index2[sort_indices] - unit_cell = unit_cell[sort_indices] - atom_distance_sqr = atom_distance_sqr[sort_indices] - - else: - # No valid pairs found - index1 = torch.empty(0, dtype=torch.long, device=device) - index2 = torch.empty(0, dtype=torch.long, device=device) - unit_cell = torch.empty(0, 3, dtype=dtype, device=device) - atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - - return edge_index, unit_cell, num_neighbors_image - - -def radius_graph_pbc( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - if pbc is None: - pbc = [True, True, True] - device = data.pos.device - batch_size = len(data.natoms) - - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations. This is not currently supported." - ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image - # that is used to compute indices for the pairs of atoms. 
This is a very convoluted way to implement - # the following (but 10x faster since it removes the for loop) - # for batch_idx in range(batch_size): - # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - # If the systems get too large this apporach could run into numerical precision issues - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction. - # Smallest distance between planes separated by a1 is - # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. - # Note that the unit cell volume V = a1 * (a2 x a3) and that - # (a2 x a3) / V is also the reciprocal primitive vector - # (crystallographer's definition). - - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity. This is essentially padding. - # Note that this can significantly increase the number of computed distances - # if the required repetitions are very different between images - # (which they usually are). Changing this to sparse (scatter) operations - # might be worth the effort if this function becomes a bottleneck. 
- max_rep = [2*rep_a1.max(), 2*rep_a2.max(), 2*rep_a3.max()] - # max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] - # max_rep = [torch.tensor(1, device=device)] * 3 - # logging.info(f"&&& max_rep: {max_rep}") - - # Tensor of unit cells - cells_per_dim = [ - torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=dtype) - for rep in max_rep - ] - unit_cell = torch.cartesian_prod(*cells_per_dim) - num_cells = len(unit_cell) - unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1) - unit_cell = torch.transpose(unit_cell, 0, 1) - unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) - - # Compute the x, y, z positional offsets for each cell in each image - # data_cell = torch.transpose(data.cell, 1, 2) - data_cell = torch.transpose(data.cell, 1, 2) - pbc_offsets = torch.bmm(data_cell, unit_cell_batch) - pbc_offsets_per_atom = torch.repeat_interleave( - pbc_offsets, num_atoms_per_image_sqr, dim=0 - ) - - # Expand the positions and indices for the 9 cells - pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) - pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) - index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) - index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) - # Add the PBC offsets for the second atom - pos2 = pos2 + pbc_offsets_per_atom - - # Compute the squared distance between atoms - atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) - atom_distance_sqr = atom_distance_sqr.view(-1) - - # Remove pairs that are too far apart - mask_within_radius = torch.le(atom_distance_sqr, radius * radius) - # Remove pairs with the same atoms (distance = 0.0) - mask_not_same = torch.gt(atom_distance_sqr, 0.0001) - mask = torch.logical_and(mask_within_radius, mask_not_same) - index1 = torch.masked_select(index1, mask) - index2 = torch.masked_select(index2, mask) - unit_cell = torch.masked_select( - unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - - return edge_index, unit_cell, num_neighbors_image - - -@torch.compiler.disable -def get_max_neighbors_mask( - natoms, - index, - atom_distance, - max_num_neighbors_threshold, - degeneracy_tolerance: float = 0.01, - enforce_max_strictly: bool = False, -): - """ - Give a mask that filters out edges so that each atom has at most - `max_num_neighbors_threshold` neighbors. - Assumes that `index` is sorted. - - Enforcing the max strictly can force the arbitrary choice between - degenerate edges. This can lead to undesired behaviors; for - example, bulk formation energies which are not invariant to - unit cell choice. 
- - A degeneracy tolerance can help prevent sudden changes in edge - existence from small changes in atom position, for example, - rounding errors, slab relaxation, temperature, etc. - """ - - device = natoms.device - num_atoms = natoms.sum() - - # Get number of neighbors - # segment_coo assumes sorted index - ones = index.new_ones(1).expand_as(index) - num_neighbors = segment_coo(ones, index, dim_size=num_atoms) - max_num_neighbors = num_neighbors.max() - num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) - - # Get number of (thresholded) neighbors per image - image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) - image_indptr[1:] = torch.cumsum(natoms, dim=0) - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) - - # If max_num_neighbors is below the threshold, return early - if ( - max_num_neighbors <= max_num_neighbors_threshold - or max_num_neighbors_threshold <= 0 - ): - mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( - index - ) - return mask_num_neighbors, num_neighbors_image - - # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. - # Fill with infinity so we can easily remove unused distances later. - distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) - - # Create an index map to map distances from atom_distance to distance_sort - # index_sort_map assumes index to be sorted - index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors - index_neighbor_offset_expand = torch.repeat_interleave( - index_neighbor_offset, num_neighbors - ) - index_sort_map = ( - index * max_num_neighbors - + torch.arange(len(index), device=device) - - index_neighbor_offset_expand - ) - distance_sort.index_copy_(0, index_sort_map, atom_distance) - distance_sort = distance_sort.view(num_atoms, max_num_neighbors) - - # Sort neighboring atoms based on distance - distance_sort, index_sort = torch.sort(distance_sort, dim=1) - - # Select the max_num_neighbors_threshold neighbors that are closest - if enforce_max_strictly: - distance_sort = distance_sort[:, :max_num_neighbors_threshold] - index_sort = index_sort[:, :max_num_neighbors_threshold] - max_num_included = max_num_neighbors_threshold - - else: - effective_cutoff = ( - distance_sort[:, max_num_neighbors_threshold] + degeneracy_tolerance - ) - is_included = torch.le(distance_sort.T, effective_cutoff) - - # Set all undesired edges to infinite length to be removed later - distance_sort[~is_included.T] = np.inf - - # Subselect tensors for efficiency - num_included_per_atom = torch.sum(is_included, dim=0) - max_num_included = torch.max(num_included_per_atom) - distance_sort = distance_sort[:, :max_num_included] - index_sort = index_sort[:, :max_num_included] - - # Recompute the number of neighbors - num_neighbors_thresholded = num_neighbors.clamp(max=num_included_per_atom) - - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) - - # Offset index_sort so that it indexes into index - index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( - -1, max_num_included - ) - # Remove "unused pairs" with infinite distances - mask_finite = torch.isfinite(distance_sort) - index_sort = torch.masked_select(index_sort, mask_finite) - - # At this point index_sort contains the index into index of the - # closest max_num_neighbors_threshold neighbors per atom - # Create a mask to remove all pairs not in index_sort - mask_num_neighbors = 
torch.zeros(len(index), device=device, dtype=bool) - mask_num_neighbors.index_fill_(0, index_sort, True) - - return mask_num_neighbors, num_neighbors_image - - -def get_pruned_edge_idx( - edge_index, num_atoms: int, max_neigh: float = 1e9 -) -> torch.Tensor: - assert num_atoms is not None # TODO: Shouldn't be necessary - - # removes neighbors > max_neigh - # assumes neighbors are sorted in increasing distance - _nonmax_idx_list = [] - for i in range(num_atoms): - idx_i = torch.arange(len(edge_index[1]))[(edge_index[1] == i)][:max_neigh] - _nonmax_idx_list.append(idx_i) - return torch.cat(_nonmax_idx_list) +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from torch_geometric.data import Data +from torch_geometric.utils import remove_self_loops +from torch_scatter import scatter, segment_coo, segment_csr + + +if TYPE_CHECKING: + from collections.abc import Mapping + + from torch.nn.modules.module import _IncompatibleKeys + + +DEFAULT_ENV_VARS = { + # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes). + # see https://pytorch.org/docs/stable/notes/cuda.html. + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", +} + + +def get_pbc_distances( + pos, + edge_index, + cell, + cell_offsets, + neighbors, + return_offsets: bool = False, + return_distance_vec: bool = False, +): + row, col = edge_index + + distance_vectors = pos[row] - pos[col] + + # correct for pbc + neighbors = neighbors.to(cell.device) + cell = torch.repeat_interleave(cell, neighbors, dim=0) + offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) + distance_vectors += offsets + + # compute distances + distances = distance_vectors.norm(dim=-1) + + # redundancy: remove zero distances + nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0] + edge_index = edge_index[:, nonzero_idx] + distances = distances[nonzero_idx] + + out = { + "edge_index": edge_index, + "distances": distances, + } + + if return_distance_vec: + out["distance_vec"] = distance_vectors[nonzero_idx] + + if return_offsets: + out["offsets"] = offsets[nonzero_idx] + + return out + +def radius_graph_pbc_mem_effi( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + if pbc is None: + pbc = [True, True, True] + device = data.pos.device + batch_size = len(data.natoms) + + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations. This is not currently supported." 
+ ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image + # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement + # the following (but 10x faster since it removes the for loop) + # for batch_idx in range(batch_size): + # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + # If the systems get too large this apporach could run into numerical precision issues + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction. + # Smallest distance between planes separated by a1 is + # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. + # Note that the unit cell volume V = a1 * (a2 x a3) and that + # (a2 x a3) / V is also the reciprocal primitive vector + # (crystallographer's definition). + + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity. This is essentially padding. + # Note that this can significantly increase the number of computed distances + # if the required repetitions are very different between images + # (which they usually are). Changing this to sparse (scatter) operations + # might be worth the effort if this function becomes a bottleneck. 
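+    # Illustrative sketch (assumed values, not taken from the code): for a
+    # cubic cell with lattice constant 5.0 A and radius = 6.0 A, each
+    # inv_min_dist is 1/5 = 0.2, so rep = ceil(6.0 * 0.2) = 2 and the doubled
+    # max_rep below becomes 4, i.e. the triple loop that follows scans cell
+    # offsets in [-4, 4] along each axis, (2*4 + 1)**3 = 729 images in total.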
+ max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] + + # Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once + # This reduces memory usage by avoiding the creation of large tensor products + all_index1 = [] + all_index2 = [] + all_unit_cell = [] + all_atom_distance_sqr = [] + + # Pre-transpose data_cell for efficiency + data_cell = torch.transpose(data.cell, 1, 2) + + # Iterate over each unit cell offset combination + for i in range(-max_rep[0], max_rep[0] + 1): + for j in range(-max_rep[1], max_rep[1] + 1): + for k in range(-max_rep[2], max_rep[2] + 1): + # Create unit cell offset + unit_cell_offset = torch.tensor([i, j, k], device=device, dtype=dtype) + + # Compute the x, y, z positional offsets for this specific cell in each image + # unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size) + unit_cell_offset_batch = unit_cell_offset.view(1,3,1).expand(batch_size, -1, -1) + pbc_offsets = torch.bmm(data_cell, unit_cell_offset_batch).squeeze(-1) + pbc_offsets_per_atom = torch.repeat_interleave( + pbc_offsets, num_atoms_per_image_sqr, dim=0 + ) + + # Apply PBC offsets to the second atom positions + pos2_offset = pos2 + pbc_offsets_per_atom + + # Compute the squared distance between atoms + atom_distance_sqr = torch.sum((pos1 - pos2_offset) ** 2, dim=1) + + # Remove pairs that are too far apart + mask_within_radius = torch.le(atom_distance_sqr, radius * radius) + # Remove pairs with the same atoms (distance = 0.0) + mask_not_same = torch.gt(atom_distance_sqr, 0.0001) + mask = torch.logical_and(mask_within_radius, mask_not_same) + + # Only keep valid pairs for this unit cell offset + if torch.any(mask): + valid_index1 = torch.masked_select(index1, mask) + valid_index2 = torch.masked_select(index2, mask) + valid_distances = torch.masked_select(atom_distance_sqr, mask) + valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat(valid_index1.shape[0], 1) + + all_index1.append(valid_index1) + all_index2.append(valid_index2) + all_unit_cell.append(valid_unit_cell) + all_atom_distance_sqr.append(valid_distances) + + # Concatenate all results + if len(all_index1) > 0: + index1 = torch.cat(all_index1) + index2 = torch.cat(all_index2) + unit_cell = torch.cat(all_unit_cell) + atom_distance_sqr = torch.cat(all_atom_distance_sqr) + + # Sort index1 in ascending order and rearrange other arrays correspondingly + sort_indices = torch.argsort(index1) + index1 = index1[sort_indices] + index2 = index2[sort_indices] + unit_cell = unit_cell[sort_indices] + atom_distance_sqr = atom_distance_sqr[sort_indices] + + else: + # No valid pairs found + index1 = torch.empty(0, dtype=torch.long, device=device) + index2 = torch.empty(0, dtype=torch.long, device=device) + unit_cell = torch.empty(0, 3, dtype=dtype, device=device) + atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = 
unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + + return edge_index, unit_cell, num_neighbors_image + + +def radius_graph_pbc( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + if pbc is None: + pbc = [True, True, True] + device = data.pos.device + batch_size = len(data.natoms) + + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations. This is not currently supported." + ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image + # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement + # the following (but 10x faster since it removes the for loop) + # for batch_idx in range(batch_size): + # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + # If the systems get too large this apporach could run into numerical precision issues + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction. + # Smallest distance between planes separated by a1 is + # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. + # Note that the unit cell volume V = a1 * (a2 x a3) and that + # (a2 x a3) / V is also the reciprocal primitive vector + # (crystallographer's definition). 
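+    # Worked check of the plane-spacing formula (assumed orthorhombic cell,
+    # illustration only): a1 = (4,0,0), a2 = (0,5,0), a3 = (0,0,6) gives
+    # a2 x a3 = (30,0,0) and V = 120, so ||(a2 x a3) / V|| = 0.25 and the
+    # spacing is 1/0.25 = 4 = |a1|, as expected for an orthogonal cell.
+    # With radius = 6 this yields rep_a1 = ceil(6 * 0.25) = 2 below.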
+ + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity. This is essentially padding. + # Note that this can significantly increase the number of computed distances + # if the required repetitions are very different between images + # (which they usually are). Changing this to sparse (scatter) operations + # might be worth the effort if this function becomes a bottleneck. + max_rep = [2*rep_a1.max(), 2*rep_a2.max(), 2*rep_a3.max()] + # max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] + # max_rep = [torch.tensor(1, device=device)] * 3 + # logging.info(f"&&& max_rep: {max_rep}") + + # Tensor of unit cells + cells_per_dim = [ + torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=dtype) + for rep in max_rep + ] + unit_cell = torch.cartesian_prod(*cells_per_dim) + num_cells = len(unit_cell) + unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1) + unit_cell = torch.transpose(unit_cell, 0, 1) + unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) + + # Compute the x, y, z positional offsets for each cell in each image + # data_cell = torch.transpose(data.cell, 1, 2) + data_cell = torch.transpose(data.cell, 1, 2) + pbc_offsets = torch.bmm(data_cell, unit_cell_batch) + pbc_offsets_per_atom = torch.repeat_interleave( + pbc_offsets, num_atoms_per_image_sqr, dim=0 + ) + + # Expand the positions and indices for the 9 cells + pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) + pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) + index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) + index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) + # Add the PBC offsets for the second atom + pos2 = pos2 + pbc_offsets_per_atom + + # Compute the squared distance between atoms + atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) + atom_distance_sqr = atom_distance_sqr.view(-1) + + # Remove pairs that are too far apart + mask_within_radius = torch.le(atom_distance_sqr, radius * radius) + # Remove pairs with the same atoms (distance = 0.0) + mask_not_same = torch.gt(atom_distance_sqr, 0.0001) + mask = torch.logical_and(mask_within_radius, mask_not_same) + index1 = torch.masked_select(index1, mask) + index2 = torch.masked_select(index2, mask) + unit_cell = torch.masked_select( + unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out 
the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + + return edge_index, unit_cell, num_neighbors_image + + +@torch.compiler.disable +def get_max_neighbors_mask( + natoms, + index, + atom_distance, + max_num_neighbors_threshold, + degeneracy_tolerance: float = 0.01, + enforce_max_strictly: bool = False, +): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + Assumes that `index` is sorted. + + Enforcing the max strictly can force the arbitrary choice between + degenerate edges. This can lead to undesired behaviors; for + example, bulk formation energies which are not invariant to + unit cell choice. + + A degeneracy tolerance can help prevent sudden changes in edge + existence from small changes in atom position, for example, + rounding errors, slab relaxation, temperature, etc. + """ + + device = natoms.device + num_atoms = natoms.sum() + + # Get number of neighbors + # segment_coo assumes sorted index + ones = index.new_ones(1).expand_as(index) + num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + max_num_neighbors = num_neighbors.max() + num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) + + # Get number of (thresholded) neighbors per image + image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) + image_indptr[1:] = torch.cumsum(natoms, dim=0) + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # If max_num_neighbors is below the threshold, return early + if ( + max_num_neighbors <= max_num_neighbors_threshold + or max_num_neighbors_threshold <= 0 + ): + mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( + index + ) + return mask_num_neighbors, num_neighbors_image + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. 
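+    # Layout sketch (hypothetical numbers): with num_neighbors = [2, 3] and
+    # max_num_neighbors = 3, the flat buffer holds 2 * 3 = 6 slots; the two
+    # distances of atom 0 land in slots 0-1 (slot 2 stays inf) and the three
+    # of atom 1 in slots 3-5, so after the view below each row of the
+    # [num_atoms, max_num_neighbors] tensor can be sorted independently.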
+ distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + # index_sort_map assumes index to be sorted + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index * max_num_neighbors + + torch.arange(len(index), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + + # Select the max_num_neighbors_threshold neighbors that are closest + if enforce_max_strictly: + distance_sort = distance_sort[:, :max_num_neighbors_threshold] + index_sort = index_sort[:, :max_num_neighbors_threshold] + max_num_included = max_num_neighbors_threshold + + else: + effective_cutoff = ( + distance_sort[:, max_num_neighbors_threshold] + degeneracy_tolerance + ) + is_included = torch.le(distance_sort.T, effective_cutoff) + + # Set all undesired edges to infinite length to be removed later + distance_sort[~is_included.T] = np.inf + + # Subselect tensors for efficiency + num_included_per_atom = torch.sum(is_included, dim=0) + max_num_included = torch.max(num_included_per_atom) + distance_sort = distance_sort[:, :max_num_included] + index_sort = index_sort[:, :max_num_included] + + # Recompute the number of neighbors + num_neighbors_thresholded = num_neighbors.clamp(max=num_included_per_atom) + + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # Offset index_sort so that it indexes into index + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_included + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # At this point index_sort contains the index into index of the + # closest max_num_neighbors_threshold neighbors per atom + # Create a mask to remove all pairs not in index_sort + mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) + mask_num_neighbors.index_fill_(0, index_sort, True) + + return mask_num_neighbors, num_neighbors_image + + +def get_pruned_edge_idx( + edge_index, num_atoms: int, max_neigh: float = 1e9 +) -> torch.Tensor: + assert num_atoms is not None # TODO: Shouldn't be necessary + + # removes neighbors > max_neigh + # assumes neighbors are sorted in increasing distance + _nonmax_idx_list = [] + for i in range(num_atoms): + idx_i = torch.arange(len(edge_index[1]))[(edge_index[1] == i)][:max_neigh] + _nonmax_idx_list.append(idx_i) + return torch.cat(_nonmax_idx_list) diff --git a/mace-bench/src/batchopt/relaxation/__init__.py b/mace-bench/src/batchopt/relaxation/__init__.py index 05c842ea2f299b14d22c4d5ea84d8c9bcfb601a3..b2facb46343dd01b9921119b3d3ad060612c2f45 100644 --- a/mace-bench/src/batchopt/relaxation/__init__.py +++ b/mace-bench/src/batchopt/relaxation/__init__.py @@ -1,11 +1,11 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
-""" - -from __future__ import annotations -from .optimizable import OptimizableBatch, OptimizableUnitCellBatch - -__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch + +__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 97873291234c625f16591881a2a0d01c6b6f37cc..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc deleted file mode 100644 index 00c49d5df97e8cdf23fbb1800cb92170912d122d..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc deleted file mode 100644 index 824d09ee31525297386353934b8828c591fcb4e6..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/ase_utils.py b/mace-bench/src/batchopt/relaxation/ase_utils.py index a827e2361d92823afb5aa11d9f3e4e8b912ce4db..b2bd4e967d442e970bf349991f2c721b419953f1 100644 --- a/mace-bench/src/batchopt/relaxation/ase_utils.py +++ b/mace-bench/src/batchopt/relaxation/ase_utils.py @@ -1,95 +1,95 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - - - -Utilities to interface OCP models/trainers with the Atomic Simulation -Environment (ASE) -""" - -from __future__ import annotations - -from types import MappingProxyType -from typing import TYPE_CHECKING - -import torch -from ase import Atoms -from ase.calculators.singlepoint import SinglePointCalculator -from ase.constraints import FixAtoms - -if TYPE_CHECKING: - from torch_geometric.data import Batch - - -# system level model predictions have different shapes than expected by ASE -ASE_PROP_RESHAPE = MappingProxyType( - {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} -) - - -def batch_to_atoms( - batch: Batch, - results: dict[str, torch.Tensor] | None = None, - wrap_pos: bool = True, - eps: float = 1e-7, -) -> list[Atoms]: - """Convert a data batch to ase Atoms - - Args: - batch: data batch - results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results - are given no calculator will be added to the atoms objects. - wrap_pos: wrap positions back into the cell. - eps: Small number to prevent slightly negative coordinates from being wrapped. 
- - Returns: - list of Atoms - """ - n_systems = batch.natoms.shape[0] - natoms = batch.natoms.tolist() - numbers = torch.split(batch.atomic_numbers, natoms) - fixed = torch.split(batch.fixed.to(torch.bool), natoms) - if results is not None: - results = { - key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() - if len(val) == len(batch) - else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] - for key, val in results.items() - } - - positions = torch.split(batch.pos, natoms) - tags = torch.split(batch.tags, natoms) - cells = batch.cell - - atoms_objects = [] - for idx in range(n_systems): - pos = positions[idx].cpu().detach().numpy() - cell = cells[idx].cpu().detach().numpy() - - # TODO take pbc from data - # TODO: &&& ^^^ change this back !!! - # if wrap_pos: - # pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) - - atoms = Atoms( - numbers=numbers[idx].tolist(), - cell=cell, - positions=pos, - tags=tags[idx].tolist(), - constraint=FixAtoms(mask=fixed[idx].tolist()), - pbc=[True, True, True], - ) - - if results is not None: - calc = SinglePointCalculator( - atoms=atoms, **{key: val[idx] for key, val in results.items()} - ) - atoms.set_calculator(calc) - - atoms_objects.append(atoms) - - return atoms_objects - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + + + +Utilities to interface OCP models/trainers with the Atomic Simulation +Environment (ASE) +""" + +from __future__ import annotations + +from types import MappingProxyType +from typing import TYPE_CHECKING + +import torch +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator +from ase.constraints import FixAtoms + +if TYPE_CHECKING: + from torch_geometric.data import Batch + + +# system level model predictions have different shapes than expected by ASE +ASE_PROP_RESHAPE = MappingProxyType( + {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} +) + + +def batch_to_atoms( + batch: Batch, + results: dict[str, torch.Tensor] | None = None, + wrap_pos: bool = True, + eps: float = 1e-7, +) -> list[Atoms]: + """Convert a data batch to ase Atoms + + Args: + batch: data batch + results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results + are given no calculator will be added to the atoms objects. + wrap_pos: wrap positions back into the cell. + eps: Small number to prevent slightly negative coordinates from being wrapped. + + Returns: + list of Atoms + """ + n_systems = batch.natoms.shape[0] + natoms = batch.natoms.tolist() + numbers = torch.split(batch.atomic_numbers, natoms) + fixed = torch.split(batch.fixed.to(torch.bool), natoms) + if results is not None: + results = { + key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() + if len(val) == len(batch) + else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] + for key, val in results.items() + } + + positions = torch.split(batch.pos, natoms) + tags = torch.split(batch.tags, natoms) + cells = batch.cell + + atoms_objects = [] + for idx in range(n_systems): + pos = positions[idx].cpu().detach().numpy() + cell = cells[idx].cpu().detach().numpy() + + # TODO take pbc from data + # TODO: &&& ^^^ change this back !!! 
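+        # Note on re-enabling the wrap below (a sketch, not the current
+        # behavior): wrap_positions would come from ase.geometry and would
+        # need an import at module level; its eps argument keeps coordinates
+        # that are only marginally negative from being folded to the far
+        # side of the cell.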
+ # if wrap_pos: + # pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) + + atoms = Atoms( + numbers=numbers[idx].tolist(), + cell=cell, + positions=pos, + tags=tags[idx].tolist(), + constraint=FixAtoms(mask=fixed[idx].tolist()), + pbc=[True, True, True], + ) + + if results is not None: + calc = SinglePointCalculator( + atoms=atoms, **{key: val[idx] for key, val in results.items()} + ) + atoms.set_calculator(calc) + + atoms_objects.append(atoms) + + return atoms_objects + diff --git a/mace-bench/src/batchopt/relaxation/optimizable.py b/mace-bench/src/batchopt/relaxation/optimizable.py index 31c8469e9aef00ae1e838677cb5dc8117255e8d2..3dc9623af4ad3a1f3a7e169703af10ee4c4a97a1 100644 --- a/mace-bench/src/batchopt/relaxation/optimizable.py +++ b/mace-bench/src/batchopt/relaxation/optimizable.py @@ -1,791 +1,791 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -Modified from original Meta implementation. -""" - -from __future__ import annotations - -from functools import cached_property -from types import SimpleNamespace -from typing import TYPE_CHECKING, ClassVar, Any, Generator -import numpy as np -import torch -import logging -from ase.calculators.calculator import PropertyNotImplementedError -from ase.stress import voigt_6_to_full_3x3_stress -from torch_scatter import scatter - -from batchopt.relaxation.ase_utils import batch_to_atoms - - -# Define dummy classes for when imports fail -class _DummyCalculator: - pass - - -try: - from mace.calculators import MACECalculator -except ImportError: - logging.warning("Unable to import MACECalculator.") - MACECalculator = _DummyCalculator - -try: - from chgnet.model.dynamics import CHGNetCalculator -except ImportError: - logging.warning("Unable to import CHGNetCalculator.") - CHGNetCalculator = _DummyCalculator - -try: - from sevenn.calculator import ( - SevenNetCalculator, - SevenNetD3Calculator, - D3Calculator, - ) -except ImportError: - logging.warning("Unable to import SevenNetCalculator.") - SevenNetCalculator = _DummyCalculator - SevenNetD3Calculator = _DummyCalculator - D3Calculator = _DummyCalculator - -try: - from fairchem.core import pretrained_mlip, FAIRChemCalculator -except ImportError: - logging.warning("Unable to import FAIRChemCalculator.") - FAIRChemCalculator = _DummyCalculator - - -# this can be removed after pinning ASE dependency >= 3.23 -try: - from ase.optimize.optimize import Optimizable -except ImportError: - - class Optimizable: - pass - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from ase import Atoms - from numpy.typing import NDArray - from torch_geometric.data import Batch - - -ALL_CHANGES: set[str] = { - "pos", - "atomic_numbers", - "cell", - "pbc", -} - - -# @torch.compile -def compare_batches( - batch1: Batch | None, - batch2: Batch, - tol: float = 1e-6, - excluded_properties: set[str] | None = None, -) -> list[str]: - """Compare properties between two batches - - Args: - batch1: atoms batch - batch2: atoms batch - tol: tolerance used to compare equility of floating point properties - excluded_properties: list of properties to exclude from comparison - - Returns: - list of system changes, property names that are differente between batch1 and batch2 - """ - system_changes = [] - - if batch1 is None: - system_changes = ALL_CHANGES - else: - properties_to_check = set(ALL_CHANGES) - if excluded_properties: - properties_to_check -= 
set(excluded_properties) - - # Check properties that aren't - for prop in ALL_CHANGES: - if prop in properties_to_check: - properties_to_check.remove(prop) - if not torch.allclose( - getattr(batch1, prop), getattr(batch2, prop), atol=tol - ): - system_changes.append(prop) - - return system_changes - - -class OptimizableBatch(Optimizable): - """A Batch version of ase Optimizable Atoms - - This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation - or in ase relaxations classes, i.e. ase.optimize.lbfgs - """ - - ignored_changes: ClassVar[set[str]] = set() - - def __init__( - self, - batch: Batch, - trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator) - transform: torch.nn.Module | None = None, - mask_converged: bool = True, - numpy: bool = False, - masked_eps: float = 1e-8, - compute_stress: bool = False, - use_fast_predict: bool = True, - dtype: torch.dtype = torch.float64, - ): - """Initialize Optimizable Batch - - Args: - batch: A batch of atoms graph data - model: An instance of a BaseTrainer derived class - transform: graph transform - mask_converged: if true will mask systems in batch that are already converged - numpy: whether to cast results to numpy arrays - masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero - from zero differences in masked positions at future steps, we add a small number to prevent this. - compute_stress: whether to compute stress during prediction - use_fast_predict: use fast prediction method when available - dtype: data type for tensor operations (torch.float32 or torch.float64) - """ - self.batch = batch.to(trainer.device) - self.trainer = trainer - self.transform = transform - self.numpy = numpy - self.mask_converged = mask_converged - self._cached_batch = None - self._update_mask = None - self.torch_results = {} - self.results = {} - self._eps = masked_eps - self.dtype = dtype - - self.otf_graph = True # trainer._unwrapped_model.otf_graph - if not self.otf_graph and "edge_index" not in self.batch: - self.update_graph() - - self.batch.pos = self.batch.pos.to(dtype=self.dtype) - self.batch.cell = self.batch.cell.to(dtype=self.dtype) - self.compute_stress = compute_stress - self.use_fast_predict = use_fast_predict - - # Determine calculator type once during initialization for efficiency - self._calculator_type = self._determine_calculator_type() - logging.info( - f"OptimizableBatch initialized with calculator type: {self._calculator_type}" - ) - - def _determine_calculator_type(self) -> str: - """Determine the type of calculator to avoid repeated isinstance checks.""" - # Check against actual imported classes, not dummy classes - trainer_class_name = type(self.trainer).__name__ - trainer_module = type(self.trainer).__module__ - - if ( - "mace" in trainer_module.lower() - or trainer_class_name == "MACECalculator" - ): - return "mace" - elif ( - "chgnet" in trainer_module.lower() - or trainer_class_name == "CHGNetCalculator" - ): - return "chgnet" - elif "sevenn" in trainer_module.lower() or trainer_class_name in [ - "SevenNetCalculator", - "SevenNetD3Calculator", - "D3Calculator", - ]: - return "sevennet" - elif ( - "fairchem" in trainer_module.lower() - or trainer_class_name == "FAIRChemCalculator" - ): - return "fairchem" - else: - return "default" - - @property - def device(self): - return self.trainer.device - - @property - def batch_indices(self): - """Get the batch indices specifying which position/force corresponds to 
which batch.""" - return self.batch.batch - - @property - def converged_mask(self): - if self._update_mask is not None: - return torch.logical_not(self._update_mask) - return None - - @property - def update_mask(self): - if self._update_mask is None: - return torch.ones(len(self.batch), dtype=bool) - return self._update_mask - - @property - def converge_indices_list(self): - return torch.where(~self.update_mask)[0].tolist() - - @property - def elem_per_group(self): - # This return value actually represents the number of elements - # in a group within a batch. Each group corresponds to batch_indices. - # It will count the number of CELL elements in each group. - return torch.bincount(self.batch_indices) - - @property - def batch_size(self): - return len(torch.unique(self.batch_indices)) - - def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: - """Check for any system changes since last calculation.""" - return compare_batches( - self._cached_batch, - batch, - tol=tol, - excluded_properties=set(self.ignored_changes), - ) - - def _predict(self) -> None: - """Run prediction if batch has any changes.""" - # TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented. - system_changes = self.check_state(self.batch) - if len(system_changes) > 0: - if self._calculator_type == "mace": - # FIXME: &&& - # for key, val in self.batch.to_dict().items(): - # print(f'&&& key: {key}, val: {val}') - # self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress) - # self.torch_results = self.trainer.predict(self.config_batch) - if self.use_fast_predict: - self.torch_results = self.trainer.fast_predict( - self.batch, compute_stress=self.compute_stress - ) - self.batch.pos = self.batch.pos.to(self.dtype) - self.batch.cell = self.batch.cell.to(self.dtype) - else: - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict( - atoms_list, compute_stress=self.compute_stress - ) - elif self._calculator_type == "fairchem": - # TODO: FAIRChemCalculator does not support batch prediction yet - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict(atoms_list=atoms_list) - elif self._calculator_type == "chgnet": - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - model_prediction = self.trainer.predict( - atoms_list=atoms_list, task="efs" - ) - results = { - "energy": torch.tensor( - [pred["e"].item() for pred in model_prediction], - device=self.device, - dtype=self.dtype, - ), - "forces": torch.vstack( - [ - torch.from_numpy(pred["f"]).to( - device=self.device, dtype=self.dtype - ) - for pred in model_prediction - ] - ), - "stress": torch.vstack( - [ - torch.from_numpy(pred["s"]).to( - device=self.device, dtype=self.dtype - ) - for pred in model_prediction - ] - ).view(-1, 3, 3), - } - self.torch_results = results - elif self._calculator_type == "sevennet": - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict(atoms_list=atoms_list) - else: # default case - self.torch_results = self.trainer.predict( - self.batch, per_image=False, disable_tqdm=True - ) - # save only subset of props in simple namespace instead of cloning the whole batch to save memory - changes = ALL_CHANGES - set(self.ignored_changes) - self._cached_batch = SimpleNamespace( - **{prop: 
self.batch[prop].clone() for prop in changes} - ) - - def get_property( - self, name, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get a predicted property by name.""" - self._predict() - if self.numpy: - self.results = { - key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() - for key, pred in self.torch_results.items() - } - else: - self.results = self.torch_results - - if name not in self.results: - raise PropertyNotImplementedError( - f"{name} not present in this calculation" - ) - - return ( - self.results[name] - if no_numpy is False - else self.torch_results[name] - ) - - def get_positions(self) -> torch.Tensor | NDArray: - """Get the batch positions""" - pos = self.batch.pos.clone() - if self.numpy: - if self.mask_converged: - pos[~self.update_mask[self.batch.batch]] = self._eps - pos = pos.cpu().numpy() - - return pos - - def set_positions(self, positions: torch.Tensor | NDArray) -> None: - """Set the atom positions in the batch.""" - if isinstance(positions, np.ndarray): - positions = torch.tensor( - positions, dtype=self.dtype, device=self.device - ) - else: - positions = positions.to(dtype=self.dtype, device=self.device) - - if self.mask_converged and self._update_mask is not None: - mask = self.update_mask[self.batch.batch] - self.batch.pos[mask] = positions[mask] - else: - self.batch.pos = positions - - if not self.otf_graph: - self.update_graph() - - def get_forces( - self, apply_constraint: bool = False, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get predicted batch forces.""" - forces = self.get_property("forces", no_numpy=no_numpy) - if apply_constraint: - fixed_idx = torch.where(self.batch.fixed == 1)[0] - if isinstance(forces, np.ndarray): - fixed_idx = fixed_idx.tolist() - forces[fixed_idx] = 0.0 - return forces.view(-1, 3) - - def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: - """Get predicted energy as the sum of all batch energies.""" - # ASE 3.22.1 expects a check for force_consistent calculations - if kwargs.get("force_consistent", False) is True: - raise PropertyNotImplementedError( - "force_consistent calculations are not implemented" - ) - if ( - len(self.batch) == 1 - ): # unfortunately batch size 1 returns a float, not a tensor - return self.get_property("energy") - return self.get_property("energy").sum() - - def get_potential_energies(self) -> torch.Tensor | NDArray: - """Get the predicted energy for each system in batch.""" - return self.get_property("energy") - - def get_cells(self) -> torch.Tensor: - """Get batch crystallographic cells.""" - return self.batch.cell - - def set_cells( - self, cells: torch.Tensor | NDArray, scale_atoms=False - ) -> None: - """Set batch cells.""" - assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" - if isinstance(cells, np.ndarray): - cells = torch.tensor(cells, dtype=self.dtype, device=self.device) - cells = cells.to(dtype=self.dtype, device=self.device) - if scale_atoms: - from ase.geometry.cell import complete_cell - - # M = torch.linalg.solve( - # self.batch.cell.view(-1, 3, 3), - # cells.view(-1, 3, 3), - # ) - # TODO: need to implement a sparse version. 
- # tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3)) - for i in range(self.batch_size): - if not self.update_mask[i]: - continue - M = np.linalg.solve( - complete_cell(self.batch.cell[i].cpu().detach().numpy()), - complete_cell(cells[i].cpu().detach().numpy()), - ) - pos_update_mask = self.batch.batch == i - self.batch.pos[pos_update_mask] = torch.matmul( - self.batch.pos[pos_update_mask], - torch.from_numpy(M).to(self.device).reshape(-1, 3), - ) - self.batch.cell[self.update_mask] = cells[self.update_mask] - - def get_volumes(self) -> torch.Tensor: - """Get a tensor of volumes for each cell in batch""" - cells = self.get_cells() - return torch.linalg.det(cells) - - def iterimages(self) -> Generator[Batch, None, None]: - # XXX document purpose of iterimages - this is just needed to work with ASE optimizers - yield self.batch - - def get_max_forces( - self, forces: torch.Tensor | None = None, apply_constraint: bool = False - ) -> torch.Tensor: - """Get the maximum forces per structure in batch""" - if forces is None: - forces = self.get_forces( - apply_constraint=apply_constraint, no_numpy=True - ) - return scatter( - (forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max" - ) - - def converged( - self, - forces: torch.Tensor | NDArray | None, - fmax: float, - max_forces: torch.Tensor | None = None, - f_upper_limit: float = 1e20, - ) -> bool: - """Check if norm of all predicted forces are below fmax""" - if forces is not None: - if isinstance(forces, np.ndarray): - forces = torch.tensor( - forces, device=self.device, dtype=self.dtype - ) - max_forces = self.get_max_forces(forces) - elif max_forces is None: - max_forces = self.get_max_forces() - - # Update mask is True for forces that are greater than fmax AND less than f_upper_limit - update_mask = torch.logical_and( - max_forces.ge(fmax), max_forces.le(f_upper_limit) - ) - # update cached mask - if self.mask_converged: - if self._update_mask is None: - self._update_mask = update_mask - else: - # some models can have random noise in their predictions, so the mask is updated by - # keeping all previously converged structures masked even if new force predictions - # push it slightly above threshold - self._update_mask = torch.logical_and( - self._update_mask, update_mask - ) - update_mask = self._update_mask - - return not torch.any(update_mask).item() - - def get_atoms_list(self) -> list[Atoms]: - """Get ase Atoms objects corresponding to the batch""" - self._predict() # in case no predictions have been run - return batch_to_atoms(self.batch, results=self.torch_results) - - def update_graph(self): - """Update the graph if model does not use otf_graph.""" - graph = self.trainer._unwrapped_model.generate_graph(self.batch) - self.batch.edge_index = graph.edge_index - self.batch.cell_offsets = graph.cell_offsets - self.batch.neighbors = graph.neighbors - if self.transform is not None: - self.batch = self.transform(self.batch) - - def __len__(self) -> int: - # TODO: this might be changed in ASE to be 3 * len(self.atoms) - return len(self.batch.pos) - - -class OptimizableUnitCellBatch(OptimizableBatch): - """Modify the supercell and the atom positions in relaxations. 
- - Based on ase UnitCellFilter to work on data batches - """ - - def __init__( - self, - batch: Batch, - trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator) - transform: torch.nn.Module | None = None, - numpy: bool = False, - mask_converged: bool = True, - mask: Sequence[bool] | None = None, - cell_factor: float | torch.Tensor | None = None, - hydrostatic_strain: bool = False, - constant_volume: bool = False, - scalar_pressure: float = 0.0, - masked_eps: float = 1e-8, - use_fast_predict: bool = True, - dtype: torch.dtype = torch.float64, - ): - """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization. - - For full details see: - E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras, - Phys. Rev. B 59, 235 (1999) - - Args: - batch: A batch of atoms graph data - model: An instance of a BaseTrainer derived class - transform: graph transform - numpy: whether to cast results to numpy arrays - mask_converged: if true will mask systems in batch that are already converged - mask: a boolean mask specifying which strain components are allowed to relax - cell_factor: - Factor by which deformation gradient is multiplied to put - it on the same scale as the positions when assembling - the combined position/cell vector. The stress contribution to - the forces is scaled down by the same factor. This can be thought - of as a very simple preconditioner. Default is number of atoms - which gives approximately the correct scaling. - hydrostatic_strain: - Constrain the cell by only allowing hydrostatic deformation. - The virial tensor is replaced by np.diag([np.trace(virial)]*3). - constant_volume: - Project out the diagonal elements of the virial tensor to allow - relaxations at constant volume, e.g. for mapping out an - energy-volume curve. Note: this only approximately conserves - the volume and breaks energy/force consistency so can only be - used with optimizers that do require a line minimisation - (e.g. FIRE). - scalar_pressure: - Applied pressure to use for enthalpy pV term. As above, this - breaks energy/force consistency. - masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero - from zero differences in masked positions at future steps, we add a small number to prevent this. 
- dtype: data type for tensor operations (torch.float32 or torch.float64) - """ - super().__init__( - batch=batch, - trainer=trainer, - transform=transform, - numpy=numpy, - mask_converged=mask_converged, - masked_eps=masked_eps, - compute_stress=True, - use_fast_predict=use_fast_predict, - dtype=dtype, - ) - - self.orig_cells = self.get_cells().clone() - self.stress = None - - if mask is None: - # mask = torch.eye(3, device=self.device) - mask = torch.ones(6, device=self.device) - - # TODO make sure mask is on GPU - if mask.shape == (6,): - self.mask = torch.tensor( - voigt_6_to_full_3x3_stress(mask.detach().cpu()), - device=self.device, - ) - elif mask.shape == (3, 3): - self.mask = mask - else: - raise ValueError("shape of mask should be (3,3) or (6,)") - - if isinstance(cell_factor, float): - cell_factor = cell_factor * torch.ones( - (3 * len(batch), 1), requires_grad=False - ) - if cell_factor is None: - cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze( - dim=1 - ) - - self.hydrostatic_strain = hydrostatic_strain - self.constant_volume = constant_volume - self.pressure = scalar_pressure * torch.eye(3, device=self.device) - self.cell_factor = cell_factor - self.stress = None - self._batch_trace = torch.vmap(torch.trace) - self._batch_diag = torch.vmap( - lambda x: x * torch.eye(3, device=x.device) - ) - - @cached_property - def batch_indices(self): - """Get the batch indices specifying which position/force corresponds to which batch. - - We augment this to specify the batch indices for augmented positions and forces. - """ - augmented_batch = torch.repeat_interleave( - torch.arange( - len(self.batch), - dtype=self.batch.batch.dtype, - device=self.device, - ), - 3, - ) - return torch.cat([self.batch.batch, augmented_batch]) - - def deform_grad(self): - """Get the cell deformation matrix""" - return torch.transpose( - torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 - ) - - def get_positions(self): - """Get positions and cell deformation gradient.""" - cur_deform_grad = self.deform_grad() - natoms = self.batch.num_nodes - pos = torch.zeros( - (natoms + 3 * len(self.get_cells()), 3), - dtype=self.batch.pos.dtype, - device=self.device, - ) - - # Augmented positions are the self.atoms.positions but without the applied deformation gradient - pos[:natoms] = torch.linalg.solve( - cur_deform_grad[self.batch.batch, :, :], - self.batch.pos.view(-1, 3, 1), - ).view(-1, 3) - # cell DOFs are the deformation gradient times a scaling factor - pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) - return pos.cpu().numpy() if self.numpy else pos - - def set_positions(self, positions: torch.Tensor | NDArray) -> None: - """Set positions and cell. - - positions has shape (natoms + ncells * 3, 3). - the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor - for each cell. - """ - if isinstance(positions, np.ndarray): - positions = torch.tensor( - positions, dtype=self.dtype, device=self.device - ) - else: - positions = positions.to(dtype=self.dtype, device=self.device) - - natoms = self.batch.num_nodes - new_atom_positions = positions[:natoms] - new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) - - # TODO check that in fact symmetry is preserved setting cells and positions - # Set the new cell from the original cell and the new deformation gradient. Both current and final structures - # should preserve symmetry. 
- new_cells = torch.bmm( - self.orig_cells, torch.transpose(new_deform_grad, 1, 2) - ) - self.set_cells(new_cells) - - # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new - # deformation gradient. This should also preserve symmetry - new_atom_positions = torch.bmm( - new_atom_positions.view(-1, 1, 3), - torch.transpose( - new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 - ), - ) - super().set_positions(new_atom_positions.view(-1, 3)) - - def get_potential_energy(self, **kwargs): - """ - returns potential energy including enthalpy PV term. - """ - atoms_energy = super().get_potential_energy(**kwargs) - return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() - - def get_forces( - self, apply_constraint: bool = False, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get forces and unit cell stress.""" - stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) - atom_forces = self.get_property("forces", no_numpy=True) - - if apply_constraint: - fixed_idx = torch.where(self.batch.fixed == 1)[0] - atom_forces[fixed_idx] = 0.0 - - volumes = self.get_volumes().view(-1, 1, 1) - # virial = -volumes * stress + self.pressure.view(-1, 3, 3) - virial = -volumes * (stress + self.pressure.view(-1, 3, 3)) - # print(f'&&& virial0: {virial}') - cur_deform_grad = self.deform_grad() - atom_forces = torch.bmm( - atom_forces.view(-1, 1, 3), - cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), - ) - virial = torch.linalg.solve( - cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) - ) - virial = torch.transpose(virial, dim0=1, dim1=2) - - # print(f'&&& virial1: {virial}') - - # TODO this does not work yet! maybe _batch_trace gives an issue - if self.hydrostatic_strain: - virial = self._batch_diag(self._batch_trace(virial) / 3.0) - - # Zero out components corresponding to fixed lattice elements - if (self.mask != 1.0).any(): - virial *= self.mask.view(-1, 3, 3) - - if self.constant_volume: - virial[:, range(3), range(3)] -= ( - self._batch_trace(virial).view(3, -1) / 3.0 - ) - - natoms = self.batch.num_nodes - augmented_forces = torch.zeros( - (natoms + 3 * len(self.get_cells()), 3), - device=self.device, - dtype=atom_forces.dtype, - ) - # print(f'&&& atom_forces: {atom_forces}') - # print(f'&&& virial2: {virial}') - augmented_forces[:natoms] = atom_forces.view(-1, 3) - augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor - - self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) - - if self.numpy and not no_numpy: - augmented_forces = augmented_forces.cpu().numpy() - - # print(f'&&& augmented_forces: {augmented_forces}') - - return augmented_forces - - def __len__(self): - return len(self.batch.pos) + 3 * len(self.batch) - - def get_potential_energies(self) -> torch.Tensor: - """Get the predicted energy for each system in batch.""" - return ( - self.get_property("energy").view(-1) - + self.pressure[0, 0] * self.get_volumes() - ) +""" +Copyright (c) Meta, Inc. and its affiliates. +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Modified from original Meta implementation. 
+""" + +from __future__ import annotations + +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, ClassVar, Any, Generator +import numpy as np +import torch +import logging +from ase.calculators.calculator import PropertyNotImplementedError +from ase.stress import voigt_6_to_full_3x3_stress +from torch_scatter import scatter + +from batchopt.relaxation.ase_utils import batch_to_atoms + + +# Define dummy classes for when imports fail +class _DummyCalculator: + pass + + +try: + from mace.calculators import MACECalculator +except ImportError: + logging.warning("Unable to import MACECalculator.") + MACECalculator = _DummyCalculator + +try: + from chgnet.model.dynamics import CHGNetCalculator +except ImportError: + logging.warning("Unable to import CHGNetCalculator.") + CHGNetCalculator = _DummyCalculator + +try: + from sevenn.calculator import ( + SevenNetCalculator, + SevenNetD3Calculator, + D3Calculator, + ) +except ImportError: + logging.warning("Unable to import SevenNetCalculator.") + SevenNetCalculator = _DummyCalculator + SevenNetD3Calculator = _DummyCalculator + D3Calculator = _DummyCalculator + +try: + from fairchem.core import pretrained_mlip, FAIRChemCalculator +except ImportError: + logging.warning("Unable to import FAIRChemCalculator.") + FAIRChemCalculator = _DummyCalculator + + +# this can be removed after pinning ASE dependency >= 3.23 +try: + from ase.optimize.optimize import Optimizable +except ImportError: + + class Optimizable: + pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ase import Atoms + from numpy.typing import NDArray + from torch_geometric.data import Batch + + +ALL_CHANGES: set[str] = { + "pos", + "atomic_numbers", + "cell", + "pbc", +} + + +# @torch.compile +def compare_batches( + batch1: Batch | None, + batch2: Batch, + tol: float = 1e-6, + excluded_properties: set[str] | None = None, +) -> list[str]: + """Compare properties between two batches + + Args: + batch1: atoms batch + batch2: atoms batch + tol: tolerance used to compare equility of floating point properties + excluded_properties: list of properties to exclude from comparison + + Returns: + list of system changes, property names that are differente between batch1 and batch2 + """ + system_changes = [] + + if batch1 is None: + system_changes = ALL_CHANGES + else: + properties_to_check = set(ALL_CHANGES) + if excluded_properties: + properties_to_check -= set(excluded_properties) + + # Check properties that aren't + for prop in ALL_CHANGES: + if prop in properties_to_check: + properties_to_check.remove(prop) + if not torch.allclose( + getattr(batch1, prop), getattr(batch2, prop), atol=tol + ): + system_changes.append(prop) + + return system_changes + + +class OptimizableBatch(Optimizable): + """A Batch version of ase Optimizable Atoms + + This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation + or in ase relaxations classes, i.e. 
ase.optimize.lbfgs + """ + + ignored_changes: ClassVar[set[str]] = set() + + def __init__( + self, + batch: Batch, + trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator) + transform: torch.nn.Module | None = None, + mask_converged: bool = True, + numpy: bool = False, + masked_eps: float = 1e-8, + compute_stress: bool = False, + use_fast_predict: bool = True, + dtype: torch.dtype = torch.float64, + ): + """Initialize Optimizable Batch + + Args: + batch: A batch of atoms graph data + model: An instance of a BaseTrainer derived class + transform: graph transform + mask_converged: if true will mask systems in batch that are already converged + numpy: whether to cast results to numpy arrays + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps, we add a small number to prevent this. + compute_stress: whether to compute stress during prediction + use_fast_predict: use fast prediction method when available + dtype: data type for tensor operations (torch.float32 or torch.float64) + """ + self.batch = batch.to(trainer.device) + self.trainer = trainer + self.transform = transform + self.numpy = numpy + self.mask_converged = mask_converged + self._cached_batch = None + self._update_mask = None + self.torch_results = {} + self.results = {} + self._eps = masked_eps + self.dtype = dtype + + self.otf_graph = True # trainer._unwrapped_model.otf_graph + if not self.otf_graph and "edge_index" not in self.batch: + self.update_graph() + + self.batch.pos = self.batch.pos.to(dtype=self.dtype) + self.batch.cell = self.batch.cell.to(dtype=self.dtype) + self.compute_stress = compute_stress + self.use_fast_predict = use_fast_predict + + # Determine calculator type once during initialization for efficiency + self._calculator_type = self._determine_calculator_type() + logging.info( + f"OptimizableBatch initialized with calculator type: {self._calculator_type}" + ) + + def _determine_calculator_type(self) -> str: + """Determine the type of calculator to avoid repeated isinstance checks.""" + # Check against actual imported classes, not dummy classes + trainer_class_name = type(self.trainer).__name__ + trainer_module = type(self.trainer).__module__ + + if ( + "mace" in trainer_module.lower() + or trainer_class_name == "MACECalculator" + ): + return "mace" + elif ( + "chgnet" in trainer_module.lower() + or trainer_class_name == "CHGNetCalculator" + ): + return "chgnet" + elif "sevenn" in trainer_module.lower() or trainer_class_name in [ + "SevenNetCalculator", + "SevenNetD3Calculator", + "D3Calculator", + ]: + return "sevennet" + elif ( + "fairchem" in trainer_module.lower() + or trainer_class_name == "FAIRChemCalculator" + ): + return "fairchem" + else: + return "default" + + @property + def device(self): + return self.trainer.device + + @property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch.""" + return self.batch.batch + + @property + def converged_mask(self): + if self._update_mask is not None: + return torch.logical_not(self._update_mask) + return None + + @property + def update_mask(self): + if self._update_mask is None: + return torch.ones(len(self.batch), dtype=bool) + return self._update_mask + + @property + def converge_indices_list(self): + return torch.where(~self.update_mask)[0].tolist() + + @property + def elem_per_group(self): + # This return value actually represents 
the number of elements + # in a group within a batch. Each group corresponds to batch_indices. + # It will count the number of CELL elements in each group. + return torch.bincount(self.batch_indices) + + @property + def batch_size(self): + return len(torch.unique(self.batch_indices)) + + def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: + """Check for any system changes since last calculation.""" + return compare_batches( + self._cached_batch, + batch, + tol=tol, + excluded_properties=set(self.ignored_changes), + ) + + def _predict(self) -> None: + """Run prediction if batch has any changes.""" + # TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented. + system_changes = self.check_state(self.batch) + if len(system_changes) > 0: + if self._calculator_type == "mace": + # FIXME: &&& + # for key, val in self.batch.to_dict().items(): + # print(f'&&& key: {key}, val: {val}') + # self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress) + # self.torch_results = self.trainer.predict(self.config_batch) + if self.use_fast_predict: + self.torch_results = self.trainer.fast_predict( + self.batch, compute_stress=self.compute_stress + ) + self.batch.pos = self.batch.pos.to(self.dtype) + self.batch.cell = self.batch.cell.to(self.dtype) + else: + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict( + atoms_list, compute_stress=self.compute_stress + ) + elif self._calculator_type == "fairchem": + # TODO: FAIRChemCalculator does not support batch prediction yet + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict(atoms_list=atoms_list) + elif self._calculator_type == "chgnet": + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + model_prediction = self.trainer.predict( + atoms_list=atoms_list, task="efs" + ) + results = { + "energy": torch.tensor( + [pred["e"].item() for pred in model_prediction], + device=self.device, + dtype=self.dtype, + ), + "forces": torch.vstack( + [ + torch.from_numpy(pred["f"]).to( + device=self.device, dtype=self.dtype + ) + for pred in model_prediction + ] + ), + "stress": torch.vstack( + [ + torch.from_numpy(pred["s"]).to( + device=self.device, dtype=self.dtype + ) + for pred in model_prediction + ] + ).view(-1, 3, 3), + } + self.torch_results = results + elif self._calculator_type == "sevennet": + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict(atoms_list=atoms_list) + else: # default case + self.torch_results = self.trainer.predict( + self.batch, per_image=False, disable_tqdm=True + ) + # save only subset of props in simple namespace instead of cloning the whole batch to save memory + changes = ALL_CHANGES - set(self.ignored_changes) + self._cached_batch = SimpleNamespace( + **{prop: self.batch[prop].clone() for prop in changes} + ) + + def get_property( + self, name, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get a predicted property by name.""" + self._predict() + if self.numpy: + self.results = { + key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() + for key, pred in self.torch_results.items() + } + else: + self.results = self.torch_results + + if name not in self.results: + raise PropertyNotImplementedError( + f"{name} not present in this calculation" + ) 
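+
+        # Note that no_numpy=True bypasses the numpy cast above: the raw
+        # torch tensor from torch_results is returned even when self.numpy
+        # is set.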
+ + return ( + self.results[name] + if no_numpy is False + else self.torch_results[name] + ) + + def get_positions(self) -> torch.Tensor | NDArray: + """Get the batch positions""" + pos = self.batch.pos.clone() + if self.numpy: + if self.mask_converged: + pos[~self.update_mask[self.batch.batch]] = self._eps + pos = pos.cpu().numpy() + + return pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set the atom positions in the batch.""" + if isinstance(positions, np.ndarray): + positions = torch.tensor( + positions, dtype=self.dtype, device=self.device + ) + else: + positions = positions.to(dtype=self.dtype, device=self.device) + + if self.mask_converged and self._update_mask is not None: + mask = self.update_mask[self.batch.batch] + self.batch.pos[mask] = positions[mask] + else: + self.batch.pos = positions + + if not self.otf_graph: + self.update_graph() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get predicted batch forces.""" + forces = self.get_property("forces", no_numpy=no_numpy) + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + if isinstance(forces, np.ndarray): + fixed_idx = fixed_idx.tolist() + forces[fixed_idx] = 0.0 + return forces.view(-1, 3) + + def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: + """Get predicted energy as the sum of all batch energies.""" + # ASE 3.22.1 expects a check for force_consistent calculations + if kwargs.get("force_consistent", False) is True: + raise PropertyNotImplementedError( + "force_consistent calculations are not implemented" + ) + if ( + len(self.batch) == 1 + ): # unfortunately batch size 1 returns a float, not a tensor + return self.get_property("energy") + return self.get_property("energy").sum() + + def get_potential_energies(self) -> torch.Tensor | NDArray: + """Get the predicted energy for each system in batch.""" + return self.get_property("energy") + + def get_cells(self) -> torch.Tensor: + """Get batch crystallographic cells.""" + return self.batch.cell + + def set_cells( + self, cells: torch.Tensor | NDArray, scale_atoms=False + ) -> None: + """Set batch cells.""" + assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" + if isinstance(cells, np.ndarray): + cells = torch.tensor(cells, dtype=self.dtype, device=self.device) + cells = cells.to(dtype=self.dtype, device=self.device) + if scale_atoms: + from ase.geometry.cell import complete_cell + + # M = torch.linalg.solve( + # self.batch.cell.view(-1, 3, 3), + # cells.view(-1, 3, 3), + # ) + # TODO: need to implement a sparse version. 
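+            # Until a batched solve is available, transform each system
+            # separately: M solves old_cell @ M = new_cell (row-vector cell
+            # convention), so right-multiplying positions by M rescales the
+            # Cartesian coordinates while keeping fractional coordinates fixed.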
+ # tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3)) + for i in range(self.batch_size): + if not self.update_mask[i]: + continue + M = np.linalg.solve( + complete_cell(self.batch.cell[i].cpu().detach().numpy()), + complete_cell(cells[i].cpu().detach().numpy()), + ) + pos_update_mask = self.batch.batch == i + self.batch.pos[pos_update_mask] = torch.matmul( + self.batch.pos[pos_update_mask], + torch.from_numpy(M).to(self.device).reshape(-1, 3), + ) + self.batch.cell[self.update_mask] = cells[self.update_mask] + + def get_volumes(self) -> torch.Tensor: + """Get a tensor of volumes for each cell in batch""" + cells = self.get_cells() + return torch.linalg.det(cells) + + def iterimages(self) -> Generator[Batch, None, None]: + # XXX document purpose of iterimages - this is just needed to work with ASE optimizers + yield self.batch + + def get_max_forces( + self, forces: torch.Tensor | None = None, apply_constraint: bool = False + ) -> torch.Tensor: + """Get the maximum forces per structure in batch""" + if forces is None: + forces = self.get_forces( + apply_constraint=apply_constraint, no_numpy=True + ) + return scatter( + (forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max" + ) + + def converged( + self, + forces: torch.Tensor | NDArray | None, + fmax: float, + max_forces: torch.Tensor | None = None, + f_upper_limit: float = 1e20, + ) -> bool: + """Check if norm of all predicted forces are below fmax""" + if forces is not None: + if isinstance(forces, np.ndarray): + forces = torch.tensor( + forces, device=self.device, dtype=self.dtype + ) + max_forces = self.get_max_forces(forces) + elif max_forces is None: + max_forces = self.get_max_forces() + + # Update mask is True for forces that are greater than fmax AND less than f_upper_limit + update_mask = torch.logical_and( + max_forces.ge(fmax), max_forces.le(f_upper_limit) + ) + # update cached mask + if self.mask_converged: + if self._update_mask is None: + self._update_mask = update_mask + else: + # some models can have random noise in their predictions, so the mask is updated by + # keeping all previously converged structures masked even if new force predictions + # push it slightly above threshold + self._update_mask = torch.logical_and( + self._update_mask, update_mask + ) + update_mask = self._update_mask + + return not torch.any(update_mask).item() + + def get_atoms_list(self) -> list[Atoms]: + """Get ase Atoms objects corresponding to the batch""" + self._predict() # in case no predictions have been run + return batch_to_atoms(self.batch, results=self.torch_results) + + def update_graph(self): + """Update the graph if model does not use otf_graph.""" + graph = self.trainer._unwrapped_model.generate_graph(self.batch) + self.batch.edge_index = graph.edge_index + self.batch.cell_offsets = graph.cell_offsets + self.batch.neighbors = graph.neighbors + if self.transform is not None: + self.batch = self.transform(self.batch) + + def __len__(self) -> int: + # TODO: this might be changed in ASE to be 3 * len(self.atoms) + return len(self.batch.pos) + + +class OptimizableUnitCellBatch(OptimizableBatch): + """Modify the supercell and the atom positions in relaxations. 
+
+    Based on ase UnitCellFilter to work on data batches
+    """
+
+    def __init__(
+        self,
+        batch: Batch,
+        trainer: Any,  # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator)
+        transform: torch.nn.Module | None = None,
+        numpy: bool = False,
+        mask_converged: bool = True,
+        mask: Sequence[bool] | None = None,
+        cell_factor: float | torch.Tensor | None = None,
+        hydrostatic_strain: bool = False,
+        constant_volume: bool = False,
+        scalar_pressure: float = 0.0,
+        masked_eps: float = 1e-8,
+        use_fast_predict: bool = True,
+        dtype: torch.dtype = torch.float64,
+    ):
+        """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization.
+
+        For full details see:
+            E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras,
+            Phys. Rev. B 59, 235 (1999)
+
+        Args:
+            batch: A batch of atoms graph data
+            trainer: An instance of a supported MLIP calculator (see the type comment above)
+            transform: graph transform
+            numpy: whether to cast results to numpy arrays
+            mask_converged: if true will mask systems in batch that are already converged
+            mask: a boolean mask specifying which strain components are allowed to relax
+            cell_factor:
+                Factor by which deformation gradient is multiplied to put
+                it on the same scale as the positions when assembling
+                the combined position/cell vector. The stress contribution to
+                the forces is scaled down by the same factor. This can be thought
+                of as a very simple preconditioner. Default is number of atoms
+                which gives approximately the correct scaling.
+            hydrostatic_strain:
+                Constrain the cell by only allowing hydrostatic deformation.
+                The virial tensor is replaced by np.diag([np.trace(virial)]*3).
+            constant_volume:
+                Project out the diagonal elements of the virial tensor to allow
+                relaxations at constant volume, e.g. for mapping out an
+                energy-volume curve. Note: this only approximately conserves
+                the volume and breaks energy/force consistency so can only be
+                used with optimizers that do not require a line minimisation
+                (e.g. FIRE).
+            scalar_pressure:
+                Applied pressure to use for enthalpy pV term. As above, this
+                breaks energy/force consistency.
+            masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero
+                from zero differences in masked positions at future steps; we add a small number to prevent this.
+ dtype: data type for tensor operations (torch.float32 or torch.float64) + """ + super().__init__( + batch=batch, + trainer=trainer, + transform=transform, + numpy=numpy, + mask_converged=mask_converged, + masked_eps=masked_eps, + compute_stress=True, + use_fast_predict=use_fast_predict, + dtype=dtype, + ) + + self.orig_cells = self.get_cells().clone() + self.stress = None + + if mask is None: + # mask = torch.eye(3, device=self.device) + mask = torch.ones(6, device=self.device) + + # TODO make sure mask is on GPU + if mask.shape == (6,): + self.mask = torch.tensor( + voigt_6_to_full_3x3_stress(mask.detach().cpu()), + device=self.device, + ) + elif mask.shape == (3, 3): + self.mask = mask + else: + raise ValueError("shape of mask should be (3,3) or (6,)") + + if isinstance(cell_factor, float): + cell_factor = cell_factor * torch.ones( + (3 * len(batch), 1), requires_grad=False + ) + if cell_factor is None: + cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze( + dim=1 + ) + + self.hydrostatic_strain = hydrostatic_strain + self.constant_volume = constant_volume + self.pressure = scalar_pressure * torch.eye(3, device=self.device) + self.cell_factor = cell_factor + self.stress = None + self._batch_trace = torch.vmap(torch.trace) + self._batch_diag = torch.vmap( + lambda x: x * torch.eye(3, device=x.device) + ) + + @cached_property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch. + + We augment this to specify the batch indices for augmented positions and forces. + """ + augmented_batch = torch.repeat_interleave( + torch.arange( + len(self.batch), + dtype=self.batch.batch.dtype, + device=self.device, + ), + 3, + ) + return torch.cat([self.batch.batch, augmented_batch]) + + def deform_grad(self): + """Get the cell deformation matrix""" + return torch.transpose( + torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 + ) + + def get_positions(self): + """Get positions and cell deformation gradient.""" + cur_deform_grad = self.deform_grad() + natoms = self.batch.num_nodes + pos = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + dtype=self.batch.pos.dtype, + device=self.device, + ) + + # Augmented positions are the self.atoms.positions but without the applied deformation gradient + pos[:natoms] = torch.linalg.solve( + cur_deform_grad[self.batch.batch, :, :], + self.batch.pos.view(-1, 3, 1), + ).view(-1, 3) + # cell DOFs are the deformation gradient times a scaling factor + pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) + return pos.cpu().numpy() if self.numpy else pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set positions and cell. + + positions has shape (natoms + ncells * 3, 3). + the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor + for each cell. + """ + if isinstance(positions, np.ndarray): + positions = torch.tensor( + positions, dtype=self.dtype, device=self.device + ) + else: + positions = positions.to(dtype=self.dtype, device=self.device) + + natoms = self.batch.num_nodes + new_atom_positions = positions[:natoms] + new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) + + # TODO check that in fact symmetry is preserved setting cells and positions + # Set the new cell from the original cell and the new deformation gradient. Both current and final structures + # should preserve symmetry. 
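+        # Row-vector cell convention: cell_new = cell_orig @ F^T, where the
+        # deformation gradient F was recovered above from the augmented
+        # positions after dividing out cell_factor.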
+ new_cells = torch.bmm( + self.orig_cells, torch.transpose(new_deform_grad, 1, 2) + ) + self.set_cells(new_cells) + + # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new + # deformation gradient. This should also preserve symmetry + new_atom_positions = torch.bmm( + new_atom_positions.view(-1, 1, 3), + torch.transpose( + new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 + ), + ) + super().set_positions(new_atom_positions.view(-1, 3)) + + def get_potential_energy(self, **kwargs): + """ + returns potential energy including enthalpy PV term. + """ + atoms_energy = super().get_potential_energy(**kwargs) + return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get forces and unit cell stress.""" + stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) + atom_forces = self.get_property("forces", no_numpy=True) + + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + atom_forces[fixed_idx] = 0.0 + + volumes = self.get_volumes().view(-1, 1, 1) + # virial = -volumes * stress + self.pressure.view(-1, 3, 3) + virial = -volumes * (stress + self.pressure.view(-1, 3, 3)) + # print(f'&&& virial0: {virial}') + cur_deform_grad = self.deform_grad() + atom_forces = torch.bmm( + atom_forces.view(-1, 1, 3), + cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), + ) + virial = torch.linalg.solve( + cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) + ) + virial = torch.transpose(virial, dim0=1, dim1=2) + + # print(f'&&& virial1: {virial}') + + # TODO this does not work yet! maybe _batch_trace gives an issue + if self.hydrostatic_strain: + virial = self._batch_diag(self._batch_trace(virial) / 3.0) + + # Zero out components corresponding to fixed lattice elements + if (self.mask != 1.0).any(): + virial *= self.mask.view(-1, 3, 3) + + if self.constant_volume: + virial[:, range(3), range(3)] -= ( + self._batch_trace(virial).view(3, -1) / 3.0 + ) + + natoms = self.batch.num_nodes + augmented_forces = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + device=self.device, + dtype=atom_forces.dtype, + ) + # print(f'&&& atom_forces: {atom_forces}') + # print(f'&&& virial2: {virial}') + augmented_forces[:natoms] = atom_forces.view(-1, 3) + augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor + + self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) + + if self.numpy and not no_numpy: + augmented_forces = augmented_forces.cpu().numpy() + + # print(f'&&& augmented_forces: {augmented_forces}') + + return augmented_forces + + def __len__(self): + return len(self.batch.pos) + 3 * len(self.batch) + + def get_potential_energies(self) -> torch.Tensor: + """Get the predicted energy for each system in batch.""" + return ( + self.get_property("energy").view(-1) + + self.pressure[0, 0] * self.get_volumes() + ) diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__init__.py b/mace-bench/src/batchopt/relaxation/optimizers/__init__.py index 47cba81c13cd74eb35eefeb06496243eb6bc8f16..71d1fff41092d06b798a16349d8ace7636b04a4e 100644 --- a/mace-bench/src/batchopt/relaxation/optimizers/__init__.py +++ b/mace-bench/src/batchopt/relaxation/optimizers/__init__.py @@ -1,13 +1,13 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
-""" - -from __future__ import annotations - -from .bfgs_torch import BFGS -from .bfgsfusedls import BFGSFusedLS - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from .bfgs_torch import BFGS +from .bfgsfusedls import BFGSFusedLS + __all__ = ["BFGS", "BFGSFusedLS"] \ No newline at end of file diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 55f3f1d40668d6ad2146482b50fb484eccfadd16..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgs_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgs_torch.cpython-310.pyc deleted file mode 100644 index 9026c7afa91c7bb0212f141077a20bb5de9dcc16..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgs_torch.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgsfusedls.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgsfusedls.cpython-310.pyc deleted file mode 100644 index 3a44638b72ddcc91a51e60fe752b77f824bf654f..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgsfusedls.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgslinesearch_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgslinesearch_torch.cpython-310.pyc deleted file mode 100644 index 84b5ad212524202b4811ea5adbb0c79dbc7266fa..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgslinesearch_torch.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc deleted file mode 100644 index 57128a45333701e99b6ae1b2f563b728722f0ad6..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc deleted file mode 100644 index 595a73fe601e540bfd2e4aee3e9349c7ef6b4ad7..0000000000000000000000000000000000000000 Binary files a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc and /dev/null differ diff --git a/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py b/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py index 2f8fb89591e192ca3c1efd6ffb41ea2971fc52ce..62409eb23d0a460c716fd545665c3c070dbde4f2 100644 --- a/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py +++ b/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py @@ -1,286 +1,286 @@ - -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root 
+
+"""
+Copyright (c) 2025 Ma Zhaojia
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
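+
+Batched BFGS for structure relaxation: every structure in the batch
+keeps its own dense Hessian, search directions come from an
+eigendecomposition of that Hessian, and step lengths are clipped per
+structure to `maxstep`.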
+""" + +from __future__ import annotations + +import logging +import torch +from torch_scatter import scatter + +from ..optimizable import OptimizableBatch + +class BFGS: + def __init__( + self, + optimizable_batch: OptimizableBatch, + maxstep: float = 0.2, + alpha: float = 70.0, + early_stop = False, + ) -> None: + """ + Args: + """ + self.optimizable = optimizable_batch + self.maxstep = maxstep + self.alpha = alpha + # self.H0 = 1.0 / self.alpha + self.trajectories = None + self.device=self.optimizable.device + + self.fmax = None + self.steps = None + + self.initialize() + self.early_stop = early_stop + + + def initialize(self): + # initial hessian + self.H0 = [ + torch.eye(3 * size, device=self.optimizable.device, dtype=torch.float64) * self.alpha + for size in self.optimizable.elem_per_group + ] + + self.H = [None] * self.optimizable.batch_size + self.pos0 = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) + self.forces0 = torch.zeros_like(self.pos0, device=self.device, dtype=torch.float64) + + def restart_from_earlystop(self, restart_indices, old_batch_indices): + H_new = [] + pos0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) + forces0_new = torch.zeros_like(pos0_new, device=self.device, dtype=torch.float64) + + # collect the preserved historical data by old_batch_indices + for i, idx in enumerate(restart_indices): + mask_old = (idx==old_batch_indices.repeat_interleave(3)) + mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) + H_new.append(self.H[idx]) + pos0_new[mask] = self.pos0[mask_old] + forces0_new[mask] = self.forces0[mask_old] + + # append new info for the new batch + for i in range(len(H_new), self.optimizable.batch_size): + H_new.append(None) + + self.H = H_new + self.pos0 = pos0_new + self.forces0 = forces0_new + + + def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): + logging.info("Enter bfgs's main program.") + self.fmax = fmax + self.max_iter = maxstep + + if is_restart_earlystop: + self.restart_from_earlystop(restart_indices, old_batch_indices) + + iteration = 0 + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + logging.info("Step Fmax(eV/A)") + + while iteration < self.max_iter and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, + ): + if self.early_stop and iteration > 0: + converge_indices = self.optimizable.converge_indices_list + if len(converge_indices) > 0: + logging.info(f"Early stopping at iteration {iteration}") + break + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + self.step() + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + iteration += 1 + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + # GPU memory usage as per nvidia-smi seems to gradually build up as + # batches are processed. This releases unoccupied cached memory. 
+        torch.cuda.empty_cache()
+
+        # set predicted values to batch
+        for name, value in self.optimizable.results.items():
+            setattr(self.optimizable.batch, name, value)
+
+        self.nsteps = iteration
+
+        if self.early_stop:
+            converge_indices_list = self.optimizable.converge_indices_list
+            return converge_indices_list
+        else:
+            return self.optimizable.converged(
+                forces=None, fmax=self.fmax, max_forces=max_forces
+            )
+
+
+    def step(self):
+        forces = self.optimizable.get_forces(apply_constraint=True).to(
+            dtype=torch.float64
+        )
+        pos = self.optimizable.get_positions().to(dtype=torch.float64)
+        dpos, steplengths = self.prepare_step(pos, forces)
+        dpos = self.determine_step(dpos, steplengths)
+        self.optimizable.set_positions(pos+dpos)
+
+
+    def prepare_step(self, pos, forces):
+        forces = forces.reshape(-1)
+        pos = pos.view(-1)
+        self.update(pos, forces, self.pos0, self.forces0)
+
+        dpos_list = []
+        cur_indices = self.optimizable.batch_indices.repeat_interleave(3)
+        # Pre-initialize the result list
+        dpos_list = [None] * len(self.H)
+
+        # Split the work: create CUDA streams only for the Hessians that need updating
+        calc_indices = [i for i, need_update in enumerate(self.optimizable.update_mask) if need_update]
+        streams = [torch.cuda.Stream() for _ in calc_indices]
+
+        # Run the actual computations in parallel
+        for i, stream in zip(calc_indices, streams):
+            with torch.cuda.stream(stream):
+                omega, V = torch.linalg.eigh(self.H[i])
+                dpos_list[i] = (V @ (forces[cur_indices==i].t() @ V / torch.abs(omega)).t())
+
+        # Synchronize all computation streams
+        torch.cuda.current_stream().synchronize()
+
+        # Handle the zero tensors on the main thread
+        for i in range(len(self.H)):
+            if not self.optimizable.update_mask[i]:
+                dpos_list[i] = torch.zeros_like(forces[cur_indices==i])
+
+        # Synchronize all streams
+        for stream in streams:
+            stream.synchronize()
+
+        # dpos = torch.vstack(dpos_list)
+        dpos = torch.zeros_like(forces)
+        for i in torch.unique(cur_indices):
+            mask = (cur_indices == i)
+            dpos[mask] = dpos_list[i]
+        dpos = dpos.reshape(-1, 3)
+
+        steplengths = (dpos ** 2).sum(dim=-1).sqrt()
+        self.pos0 = pos
+        self.forces0 = forces
+
+        return dpos, steplengths
+
+
+    def determine_step(self, dpos, steplengths):
+        longest_steps = scatter(
+            steplengths, self.optimizable.batch_indices, reduce="max"
+        )
+        longest_steps = longest_steps[self.optimizable.batch_indices]
+        maxstep = longest_steps.new_tensor(self.maxstep)
+        scale = (longest_steps).reciprocal() * torch.min(longest_steps, maxstep)
+        dpos *= scale.unsqueeze(1)
+        return dpos
+
+    def update(self, pos, forces, pos0, forces0):
+        if self.H is None:
+            self.H = self.H0
+            return
+        dpos = pos - pos0
+        dforces = forces - forces0
+        batch_indices_flatten = self.optimizable.batch_indices.repeat_interleave(3)
+        dg = torch.zeros_like(dforces)
+        all_size = self.optimizable.elem_per_group
+
+        for i in range(self.optimizable.batch_size):
+            if self.H[i] is None:
+                continue
+            mask = (i==batch_indices_flatten)
+            if torch.abs(dpos[mask]).max() < 1e-7:
+                continue
+
+            dg[mask] = self.H[i] @ dpos[mask]
+
+        a = self._batched_dot_1d(dforces, dpos)
+        b = self._batched_dot_1d(dpos, dg)
+
+        for i in range(self.optimizable.batch_size):
+            if self.H[i] is None:
+                self.H[i] = torch.eye(3*all_size[i], device=self.device, dtype=torch.float64) * self.alpha
+                continue
+            mask = (i==batch_indices_flatten)
+            if not self.optimizable.update_mask[i]:
+                continue
+            if torch.abs(dpos[mask]).max() < 1e-7:
+                continue
+
+            outer_force = torch.outer(dforces[mask], dforces[mask])
+            outer_dg = torch.outer(dg[mask], dg[mask])
+            self.H[i] -= outer_force / a[i] + outer_dg / b[i]
+
+
+
+    def update_parallel(self, pos, forces, pos0, forces0):
+        if self.H is None:
+            self.H = self.H0
+            return
+
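+        # update_parallel is a stream-parallel variant of update(); the active
+        # code path in prepare_step() above calls the serial update().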
+        dpos = pos - pos0
+
+        if torch.abs(dpos).max() < 1e-7:
+            return
+
+        dforces = forces - forces0
+        cur_indices = self.optimizable.batch_indices.repeat_interleave(3)
+        a = self._batched_dot_1d(dforces, dpos)
+        # DONE: There is a bug using hstack.
+        # dg = torch.hstack([self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))])
+        # DONE: parallel this part
+        # dg_list = [self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))]
+        dg_list = [None] * len(self.H)
+        streams = [torch.cuda.Stream() for _ in dg_list]
+        for i, stream in zip(range(len(dg_list)), streams):
+            with torch.cuda.stream(stream):
+                dg_list[i] = self.H[i] @ dpos[cur_indices == i]
+
+        torch.cuda.current_stream().synchronize()
+        for stream in streams:
+            stream.synchronize()
+
+        dg = torch.zeros_like(dforces)
+        for i in torch.unique(cur_indices):
+            mask = (cur_indices == i)
+            dg[mask] = dg_list[i]
+        b = self._batched_dot_1d(dpos, dg)
+
+        # DONE: parallel this part
+        for i, stream in zip(range(len(self.H)), streams):
+            if not self.optimizable.update_mask[i]:
+                continue
+            with torch.cuda.stream(stream):
+                outer_force = torch.outer(dforces[cur_indices==i], dforces[cur_indices==i])
+                outer_dg = torch.outer(dg[cur_indices==i], dg[cur_indices==i])
+                self.H[i] -= outer_force / a[i] + outer_dg / b[i]
+
+        torch.cuda.current_stream().synchronize()
+        for stream in streams:
+            stream.synchronize()
+
+
+    def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor):
+        return scatter(
+            (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum"
+        )
+
+    def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor):
+        return scatter(
+            (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum"
         )
\ No newline at end of file
diff --git a/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py b/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py
index 522de6d9d182781d4776a13af45feb1d4db8dd90..9dc221b64021c1607a8a80b06592c3aa5683d737 100644
--- a/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py
+++ b/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py
@@ -1,993 +1,993 @@
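+# Batched BFGS with a fused line search: one Hessian and one line-search
+# state per structure; each line-search iteration evaluates energies and
+# forces for the whole batch in a single call.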
- """ - def __init__( - self, - optimizable_batch: OptimizableBatch, - maxstep: float = 0.2, - c1: float = 0.23, - c2: float = 0.46, - alpha: float = 10.0, - stpmax: float = 50.0, - device = 'cpu', - early_stop: bool = False, - use_profiler: bool = False, - profiler_log_dir: str = './log', - profiler_schedule_config: dict = None, - dtype: torch.dtype = torch.float64, - ): - self.optimizable = optimizable_batch - self.maxstep = maxstep - self.c1 = c1 - self.c2 = c2 - self.alpha = alpha - self.stpmax = stpmax - self.nsteps = 0 - self.device = device - self.force_calls = 0 - self.early_stop = early_stop - self.use_profiler = use_profiler - self.profiler_log_dir = profiler_log_dir - self.profiler_schedule_config = profiler_schedule_config or {"wait": 48, "warmup": 1, "active": 1, "repeat": 8} - self.dtype = dtype - - self.converge_indices_list = None - - # The information from the previous round is useful for the current round's calculations. - ## These variables need to be update accroding to new input when eary stop is triggered. - self.Hs = None - self.r0 = None - self.g0 = None - self.p_list = [None] * self.optimizable.batch_size - self.no_update_list = [False] * self.optimizable.batch_size - self.ls_completed = [True] * self.optimizable.batch_size - self.ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu", dtype=self.dtype) - ## need to be recalculate when early stop is triggered - self.forces = None - self.energies = None - - def restart_from_earlystop(self, restart_indices, old_batch_indices): - Hs_new = [] - r0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device) - g0_new = torch.zeros_like(r0_new, device=self.device) - p_list_new = [] - no_update_list_new = [] - ls_completed_new = [] - - # collect the preserved historical info by old_indices - for i, idx in enumerate(restart_indices): - mask_old = (idx==old_batch_indices.repeat_interleave(3)) - mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) - Hs_new.append(self.Hs[idx]) - p_list_new.append(self.p_list[idx]) - no_update_list_new.append(self.no_update_list[idx]) - ls_completed_new.append(self.ls_completed[idx]) - r0_new[mask] = self.r0[mask_old] - g0_new[mask] = self.g0[mask_old] - - # append new info for new element in batch - for i in range(len(Hs_new), self.optimizable.batch_size): - # Hs_new.append(torch.eye(3 * self.optimizable.elem_per_group[i], device=self.device, dtype=torch.float64)) - Hs_new.append(None) - p_list_new.append(None) - no_update_list_new.append(False) - ls_completed_new.append(True) - - self.Hs = Hs_new - self.r0 = r0_new - self.g0 = g0_new - self.p_list = p_list_new - self.no_update_list = no_update_list_new - self.ls_completed = ls_completed_new - self.forces = None - self.energies = None - self.ls_batch.restart_from_earlystop(restart_indices=restart_indices, batch_indices_new=self.optimizable.batch_indices) - - def step(self): - optimizable = self.optimizable - if self.forces is None: - self.forces = optimizable.get_forces().to(self.device) - r = optimizable.get_positions().reshape(-1).to(self.device) - g = -self.forces.reshape(-1) / self.alpha - p0_list = self.p_list - self.update(r, g, self.r0, self.g0, p0_list) - if self.energies is None: - self.energies = self.func(r) - - for i in range(self.optimizable.batch_size): - if self.ls_completed[i]: - p = -torch.matmul(self.Hs[i], g[i==self.optimizable.batch_indices.repeat_interleave(3)]) - - # Implement scaling for numerical stability with simpler calculation - p_size = 
torch.sqrt((p**2).sum()) - min_size = torch.sqrt(self.optimizable.elem_per_group[i] * 1e-10) - if p_size <= min_size: - p = p * (min_size / p_size) - - self.p_list[i] = p - - # ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu") - continue_search = [not elem for elem in self.ls_completed] - self.alpha_k_list, self.e_list, self.e0_list, self.no_update_list, self.ls_completed = self.ls_batch._linesearch_batch( - self.func, self.fprime, r, self.p_list, g, self.energies, None, - maxstep=self.maxstep, c1=self.c1, c2=self.c2, stpmax=self.stpmax, continue_search=continue_search - ) - - # reset device for linesearch result - for i in range(self.optimizable.batch_size): - if self.ls_completed[i]: - self.alpha_k_list[i] = self.alpha_k_list[i].to(self.device) - self.p_list[i] = self.p_list[i].to(self.device) - - dr_tensor = torch.zeros_like(r) - - - for i in range(self.optimizable.batch_size): - # if check_cache: - # mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - # dr_tensor_all[mask] = self.alpha_k_list[i].to(self.device) * self.p_list[i].to(self.device) - - if not self.ls_completed[i]: - continue - if self.alpha_k_list[i] is None: - raise RuntimeError("LineSearch failed!") - - mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - dr_tensor[mask] = self.alpha_k_list[i] * self.p_list[i] - - # if check_cache: - # cached_pos = optimizable.get_positions().reshape(-1).to(self.device) - # update_pos = r + dr_tensor_all - # assert torch.allclose(update_pos, cached_pos), "dr_tensor_cached should be equal to dr_tensor" - - - # TODO: get_forces/get_potential_energies will trigger compare_batch which is time-consuming - forces_cache = optimizable.get_forces() - energies_cache = self.optimizable.get_potential_energies() / self.alpha - - # update self.forces - for i in range(self.optimizable.batch_size): - if not self.ls_completed[i]: - continue - mask = (i == self.optimizable.batch_indices) - self.forces[mask] = forces_cache[mask] - self.energies[i] = energies_cache[i] - - optimizable.set_positions((r + dr_tensor).reshape(-1, 3)) - - self.r0 = r - self.g0 = g - - # @torch.compile - def update(self, r, g, r0, g0, p0_list): - all_sizes = self.optimizable.elem_per_group - - if self.Hs is None: - self.Hs = [ - torch.eye(3 * sz, device=self.device, dtype=self.dtype) - for sz in all_sizes - ] - return - - dr = r - r0 - dg = g - g0 - - for i in range(self.optimizable.batch_size): - if self.Hs[i] is None: - self.Hs[i] = torch.eye(3 * all_sizes[i], device=self.optimizable.device, dtype=self.dtype) - continue - if not self.ls_completed[i]: - continue - if self.no_update_list[i] is True: - print('skip update') - continue - - cur_mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - cur_g = g[cur_mask] - cur_p0 = p0_list[i] - cur_g0 = g0[cur_mask] - cur_dg = dg[cur_mask] - cur_dr = dr[cur_mask] - - if not (((self.alpha_k_list[i] or 0) > 0 and - abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): - continue - - try: - rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) - except: - rhok = 1000.0 - print("Divide-by-zero encountered: rhok assumed large") - if torch.isinf(rhok): - rhok = 1000.0 - print("Divide-by-zero encountered: rhok assumed large") - I = torch.eye(all_sizes[i]*3, device=self.device, dtype=self.dtype) - A1 = I - cur_dr[:, None] * cur_dg[None, :] * rhok - A2 = I - cur_dg[:, None] * cur_dr[None, :] * rhok - self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + - rhok * cur_dr[:, None] * cur_dr[None, :]) - - - # def 
update(self, r, g, r0, g0, p0_list): - # self.Is = [ - # torch.eye(sz * 3, dtype=torch.float64, device=self.device) - # for sz in self.optimizable.elem_per_group - # ] - - # # TODO: BFGS for loop 是不是在被打断之后需要重建这个 self.Hs? - # # TODO: 并且我们保存的上一次的r,g,r0,g0也被丢弃了 - # if self.Hs is None: - # self.Hs = [ - # torch.eye(3 * sz, device=self.optimizable.device, dtype=torch.float64) - # for sz in self.optimizable.elem_per_group - # ] - # return - # else: - # dr = r - r0 - # dg = g - g0 - - # for i in range(self.optimizable.batch_size): - # if not self.ls_completed[i]: - # continue - # cur_mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) - # cur_g = g[cur_mask] - # cur_p0 = p0_list[i] - # cur_g0 = g0[cur_mask] - # cur_dg = dg[cur_mask] - # cur_dr = dr[cur_mask] - - # if not (((self.alpha_k_list[i] or 0) > 0 and - # abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): - # break - - # if self.no_update_list[i] is True: - # print('skip update') - # break - - # try: - # rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) - # except: - # rhok = 1000.0 - # print("Divide-by-zero encountered: rhok assumed large") - # if torch.isinf(rhok): - # rhok = 1000.0 - # print("Divide-by-zero encountered: rhok assumed large") - # A1 = self.Is[i] - cur_dr[:, None] * cur_dg[None, :] * rhok - # A2 = self.Is[i] - cur_dg[:, None] * cur_dr[None, :] * rhok - # self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + - # rhok * cur_dr[:, None] * cur_dr[None, :]) - - - - def func(self, x): - self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) - return self.optimizable.get_potential_energies() / self.alpha - - def fprime(self, x): - self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) - - self.force_calls += 1 - forces = self.optimizable.get_forces().reshape(-1) - return - forces / self.alpha - - def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): - logging.info("Enter bfgsfusedlinesearch's main program.") - self.fmax = fmax - self.max_iter = maxstep - - if is_restart_earlystop: - self.restart_from_earlystop(restart_indices, old_batch_indices) - - iteration = 0 - max_forces = self.optimizable.get_max_forces(apply_constraint=True) - logging.info("Step Fmax(eV/A)") - - # Run with profiler if enabled - if self.use_profiler: - activities = [ProfilerActivity.CPU] - if torch.cuda.is_available(): - activities.append(ProfilerActivity.CUDA) - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - pid = os.getpid() - with torch.profiler.profile( - activities=activities, - schedule=torch.profiler.schedule( - wait=self.profiler_schedule_config["wait"], - warmup=self.profiler_schedule_config["warmup"], - active=self.profiler_schedule_config["active"], - repeat=self.profiler_schedule_config["repeat"] - ), - on_trace_ready=tensorboard_trace_handler(self.profiler_log_dir, worker_name=f"BFGSLS_{pid}"), - with_stack=True, - profile_memory=True, - ) as prof: - # Main optimization loop with profiling - while iteration < self.max_iter and not self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, - ): - if self.early_stop and iteration > 0: - self.converge_indices_list = self.optimizable.converge_indices_list - if len(self.converge_indices_list) > 0: - logging.info(f"Early stopping at iteration {iteration}") - break - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - self.step() - max_forces = 
self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) - iteration += 1 - - # Step the profiler in each iteration - prof.step() - - else: - # Original optimization loop without profiling - while iteration < self.max_iter and not self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, - ): - if self.early_stop and iteration > 0: - self.converge_indices_list = self.optimizable.converge_indices_list - if len(self.converge_indices_list) > 0: - logging.info(f"Early stopping at iteration {iteration}") - break - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - self.step() - max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) - iteration += 1 - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - # GPU memory usage as per nvidia-smi seems to gradually build up as - # batches are processed. This releases unoccupied cached memory. - torch.cuda.empty_cache() - gc.collect() - - # set predicted values to batch - for name, value in self.optimizable.results.items(): - setattr(self.optimizable.batch, name, value) - - self.nsteps = iteration - - if self.early_stop: - self.converge_indices_list = self.optimizable.converge_indices_list - return self.converge_indices_list - else: - return self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces - ) - - def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" - ) - - def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" - ) - -# flake8: noqa -import math -import torch -import logging - -pymin = min -pymax = max - - -class LineSearch: - def __init__(self, xtol=1e-14, device='cpu', dtype=torch.float64): - self.xtol = xtol - self.task = 'START' - self.device = device - self.dtype = dtype - self.isave = torch.zeros(2, dtype=torch.int64, device=self.device) - self.dsave = torch.zeros(13, dtype=self.dtype, device=self.device) - self.fc = 0 - self.gc = 0 - self.case = 0 - self.old_stp = 0 - - def initialize(self, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8): - # Scalar parameters can stay as Python scalars - self.stpmin = stpmin - self.stpmax = stpmax - self.xtrapl = xtrapl - self.xtrapu = xtrapu - self.maxstep = maxstep - - # Move tensors to the device - self.pk = pk.to(self.device) - xk = xk.to(self.device) - gfk = gfk.to(self.device) - - phi0 = old_fval - - - # This dot product needs tensors - derphi0 = torch.dot(gfk, self.pk).item() - - # Use Python math for scalar calculations - self.dim = len(pk) - self.gms = math.sqrt(self.dim) * maxstep - - alpha1 = 1.0 - self.no_update = False - self.gradient = True - - self.steps = [] - return alpha1, phi0, derphi0 - - def prologue(self, fval, gval, pk_tensor, alpha1): - phi0 = fval - derphi0 = torch.dot(gval, pk_tensor) - self.old_stp = alpha1 - # TODO: self.no_update == True: break is needed to reimplemented. 
- - return phi0, derphi0 - - def epilogue(self): - pass - - def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8, args=()): - self.stpmin = stpmin - self.pk = pk.to(self.device) - self.stpmax = stpmax - self.xtrapl = xtrapl - self.xtrapu = xtrapu - self.maxstep = maxstep - - xk = xk.to(self.device) - - # Convert inputs to torch tensors if they're not already - if not isinstance(old_fval, torch.Tensor): - phi0 = torch.tensor(old_fval, dtype=self.dtype, device=self.device) - else: - phi0 = old_fval.to(self.device) - - # Ensure pk and gfk are torch tensors - pk_tensor = torch.tensor(pk, dtype=self.dtype, device=self.device) if not isinstance(pk, torch.Tensor) else pk.to(self.device) - gfk_tensor = torch.tensor(gfk, dtype=self.dtype, device=self.device) if not isinstance(gfk, torch.Tensor) else gfk.to(self.device) - - derphi0 = torch.dot(gfk_tensor, pk_tensor) - self.dim = len(pk) - self.gms = torch.sqrt(torch.tensor(self.dim, dtype=self.dtype, device=self.device)) * maxstep - alpha1 = 1. - self.no_update = False - - if isinstance(myfprime, tuple): - fprime = myfprime[0] - gradient = False - else: - fprime = myfprime - newargs = args - gradient = True - - fval = phi0 - gval = gfk_tensor - self.steps = [] - - while True: - stp = self.step(alpha1, phi0, derphi0, c1, c2, - self.xtol, - self.isave, self.dsave) - - if self.task[:2] == 'FG': - alpha1 = stp - - # Get function value and gradient - x_new = xk + stp * pk_tensor - fval = func(x_new).to(self.device) - self.fc += 1 - - gval = fprime(x_new).to(self.device) - if gradient: - self.gc += 1 - else: - self.fc += len(xk) + 1 - - phi0 = fval - derphi0 = torch.dot(gval, pk_tensor) - self.old_stp = alpha1 - - if self.no_update == True: - break - else: - break - - if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN': - stp = None # failed - - return stp, fval.item(), old_fval.item() if isinstance(old_fval, torch.Tensor) else old_fval, self.no_update - - def step(self, stp, f, g, c1, c2, xtol, isave, dsave): - if self.task[:5] == 'START': - # Check the input arguments for errors. - if stp < self.stpmin: - self.task = 'ERROR: STP .LT. minstep' - if stp > self.stpmax: - self.task = 'ERROR: STP .GT. maxstep' - if g >= 0: - self.task = 'ERROR: INITIAL G >= 0' - if c1 < 0: - self.task = 'ERROR: c1 .LT. 0' - if c2 < 0: - self.task = 'ERROR: c2 .LT. 0' - if xtol < 0: - self.task = 'ERROR: XTOL .LT. 0' - if self.stpmin < 0: - self.task = 'ERROR: minstep .LT. 0' - if self.stpmax < self.stpmin: - self.task = 'ERROR: maxstep .LT. minstep' - if self.task[:5] == 'ERROR': - return stp - - # Initialize local variables. - self.bracket = False - stage = 1 - finit = f - ginit = g - gtest = c1 * ginit - width = self.stpmax - self.stpmin - width1 = width / .5 - - # The variables stx, fx, gx contain the values of the step, - # function, and derivative at the best step. - # The variables sty, fy, gy contain the values of the step, - # function, and derivative at sty. - # The variables stp, f, g contain the values of the step, - # function, and derivative at stp. 
- stx = 0.0 - fx = finit - gx = ginit - sty = 0.0 - fy = finit - gy = ginit - stmin = 0.0 - stmax = stp + self.xtrapu * stp - self.task = 'FG' - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - stp = self.determine_step(stp) - return stp - else: - if self.isave[0] == 1: - self.bracket = True - else: - self.bracket = False - stage = self.isave[1] - (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, - width, width1) = self.dsave - - # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the - # algorithm enters the second stage. - ftest = finit + stp * gtest - if stage == 1 and f < ftest and g >= 0.: - stage = 2 - - # Test for warnings. - if self.bracket and (stp <= stmin or stp >= stmax): - self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS' - if self.bracket and stmax - stmin <= self.xtol * stmax: - self.task = 'WARNING: XTOL TEST SATISFIED' - if stp == self.stpmax and f <= ftest and g <= gtest: - self.task = 'WARNING: STP = maxstep' - if stp == self.stpmin and (f > ftest or g >= gtest): - self.task = 'WARNING: STP = minstep' - - # Test for convergence. - # if f <= ftest and abs(g) <= c2 * (- ginit): - # self.task = 'CONVERGENCE' - if (f < ftest or math.isclose(f, ftest, rel_tol=1e-6, abs_tol=1e-5)) and (abs(g) < c2 * (- ginit) or math.isclose(abs(g), c2 * (- ginit), rel_tol=1e-6, abs_tol=1e-5)): - self.task = 'CONVERGENCE' - - # Test for termination. - if self.task[:4] == 'WARN' or self.task[:4] == 'CONV': - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - return stp - - stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty, - fy, gy, stp, f, g, - stmin, stmax) - - # Decide if a bisection step is needed. - if self.bracket: - if abs(sty - stx) >= .66 * width1: - stp = stx + .5 * (sty - stx) - width1 = width - width = abs(sty - stx) - - # Set the minimum and maximum steps allowed for stp. - if self.bracket: - stmin = min(stx, sty) - stmax = max(stx, sty) - else: - stmin = stp + self.xtrapl * (stp - stx) - stmax = stp + self.xtrapu * (stp - stx) - - # Force the step to be within the bounds maxstep and minstep. - stp = max(stp, self.stpmin) - stp = min(stp, self.stpmax) - - if (stx == stp and stp == self.stpmax and stmin > self.stpmax): - self.no_update = True - - # If further progress is not possible, let stp be the best - # point obtained during the search. - if (self.bracket and stp < stmin or stp >= stmax) \ - or (self.bracket and stmax - stmin < self.xtol * stmax): - stp = stx - - # Obtain another function and derivative. - self.task = 'FG' - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - return stp - - def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp, - stpmin, stpmax): - sign = gp * (gx / abs(gx)) - - # First case: A higher function value. The minimum is bracketed. - # If the cubic step is closer to stx than the quadratic step, the - # cubic step is taken, otherwise the average of the cubic and - # quadratic steps is taken. - if fp > fx: # case1 - self.case = 1 - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2. - (gx / s) * (gp / s)) - if stp < stx: - gamma = -gamma - p = (gamma - gx) + theta - q = ((gamma - gx) + gamma) + gp - r = p / q - stpc = stx + r * (stp - stx) - stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) 
\ - * (stp - stx) - if (abs(stpc - stx) < abs(stpq - stx)): - stpf = stpc - else: - stpf = stpc + (stpq - stpc) / 2. - - self.bracket = True - - # Second case: A lower function value and derivatives of opposite - # sign. The minimum is bracketed. If the cubic step is farther from - # stp than the secant step, the cubic step is taken, otherwise the - # secant step is taken. - elif sign < 0: # case2 - self.case = 2 - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2 - (gx / s) * (gp / s)) - if stp > stx: - gamma = -gamma - p = (gamma - gp) + theta - q = ((gamma - gp) + gamma) + gx - r = p / q - stpc = stp + r * (stx - stp) - stpq = stp + (gp / (gp - gx)) * (stx - stp) - if (abs(stpc - stp) > abs(stpq - stp)): - stpf = stpc - else: - stpf = stpq - self.bracket = True - - # Third case: A lower function value, derivatives of the same sign, - # and the magnitude of the derivative decreases. - elif abs(gp) < abs(gx): # case3 - self.case = 3 - # The cubic step is computed only if the cubic tends to infinity - # in the direction of the step or if the minimum of the cubic - # is beyond stp. Otherwise the cubic step is defined to be the - # secant step. - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - - # The case gamma = 0 only arises if the cubic does not tend - # to infinity in the direction of the step. - gamma = s * math.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s))) - if stp > stx: - gamma = -gamma - p = (gamma - gp) + theta - q = (gamma + (gx - gp)) + gamma - r = p / q - if r < 0. and gamma != 0: - stpc = stp + r * (stx - stp) - elif stp > stx: - stpc = stpmax - else: - stpc = stpmin - stpq = stp + (gp / (gp - gx)) * (stx - stp) - - if self.bracket: - # A minimizer has been bracketed. If the cubic step is - # closer to stp than the secant step, the cubic step is - # taken, otherwise the secant step is taken. - if abs(stpc - stp) < abs(stpq - stp): - stpf = stpc - else: - stpf = stpq - if stp > stx: - stpf = min(stp + .66 * (sty - stp), stpf) - else: - stpf = max(stp + .66 * (sty - stp), stpf) - else: - # A minimizer has not been bracketed. If the cubic step is - # farther from stp than the secant step, the cubic step is - # taken, otherwise the secant step is taken. - if abs(stpc - stp) > abs(stpq - stp): - stpf = stpc - else: - stpf = stpq - stpf = min(stpmax, stpf) - stpf = max(stpmin, stpf) - - # Fourth case: A lower function value, derivatives of the same sign, - # and the magnitude of the derivative does not decrease. If the - # minimum is not bracketed, the step is either minstep or maxstep, - # otherwise the cubic step is taken. - else: # case4 - self.case = 4 - if self.bracket: - theta = 3. * (fp - fy) / (sty - stp) + gy + gp - s = max(max(abs(theta), abs(gy)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2 - (gy / s) * (gp / s)) - if stp > sty: - gamma = -gamma - p = (gamma - gp) + theta - q = ((gamma - gp) + gamma) + gy - r = p / q - stpc = stp + r * (sty - stp) - stpf = stpc - elif stp > stx: - stpf = stpmax - else: - stpf = stpmin - - # Update the interval which contains a minimizer. - if fp > fx: - sty = stp - fy = fp - gy = gp - else: - if sign < 0: - sty = stx - fy = fx - gy = gx - stx = stp - fx = fp - gx = gp - - # Compute the new step. 
- stp = self.determine_step(stpf) - - return stx, sty, stp, gx, fx, gy, fy - - def determine_step(self, stp): - dr = stp - self.old_stp - x = torch.reshape(self.pk.to(self.device), (-1, 3)) - steplengths = ((dr * x)**2).sum(1)**0.5 - maxsteplength = max(steplengths) - if maxsteplength >= self.maxstep: - dr *= self.maxstep / maxsteplength - stp = self.old_stp + dr - return stp - - def save(self, data): - if self.bracket: - self.isave[0] = 1 - else: - self.isave[0] = 0 - self.isave[1] = data[0] - self.dsave = data[1:] - -class LineSearchBatch: - - def __init__(self, batch_indices, device='cpu', dtype=torch.float64): - self.device = device - self.dtype = dtype - self.batch_indices = batch_indices.to(self.device) - self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) - self.batch_size = len(torch.unique(batch_indices)) - self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)] - self.steps = [1.] * self.batch_size - self.phi0_values = [None] * self.batch_size - self.derphi0_values = [None] * self.batch_size - - def restart_from_earlystop(self, restart_indices, batch_indices_new): - self.batch_indices = batch_indices_new.to(self.device) - self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) - self.batch_size = len(torch.unique(batch_indices_new)) - - linesearch_list_new = [] - steps_new = [] - phi0_values_new = [] - derphi0_values_new = [] - - for i, idx in enumerate(restart_indices): - linesearch_list_new.append(self.linesearch_list[idx]) - steps_new.append(self.steps[idx]) - phi0_values_new.append(self.phi0_values[idx]) - derphi0_values_new.append(self.derphi0_values[idx]) - - for i in range(len(restart_indices), self.batch_size): - linesearch_list_new.append(LineSearch(device=self.device)) - steps_new.append(1.) 
- phi0_values_new.append(None) - derphi0_values_new.append(None) - - self.linesearch_list = linesearch_list_new - self.steps = steps_new - self.phi0_values = phi0_values_new - self.derphi0_values = derphi0_values_new - - - - def _linesearch_batch(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8, continue_search=None, max_iter=15): - if continue_search is None: - self.linesearch_list = [LineSearch(device=self.device) for _ in range(self.batch_size)] - else: - assert len(continue_search) == self.batch_size - for i in range(len(continue_search)): - if not continue_search[i]: - self.linesearch_list[i] = LineSearch(device=self.device) - - if isinstance(xk, torch.Tensor): - xk = xk.to(self.device) - for i in range(len(pk)): - pk[i] = pk[i].to(self.device) - if isinstance(gfk, torch.Tensor): - gfk = gfk.to(self.device) - if isinstance(old_fval, torch.Tensor): - old_fval = old_fval.to(self.device) - if isinstance(old_old_fval, torch.Tensor): - old_old_fval = old_old_fval.to(self.device) - - - # results for each batch element - alpha_results = [] - e_result = [] - e0_result = [] - no_update_result = [] - - # Initialize step sizes and line search state for each batch element - completed = [False] * self.batch_size - - # Initialize iteration counter - iter_count = 0 - - # Initialize all line searches using the initialize method - for i in range(self.batch_size): - if continue_search[i]: - continue - - ls = self.linesearch_list[i] - mask = (i == self.batch_indices_flatten) - - # Use the initialize method to set up line search parameters - alpha1, phi0, derphi0 = ls.initialize( - xk[mask], pk[i], gfk[mask], old_fval[i], old_old_fval, - maxstep, c1, c2, xtrapl, xtrapu, stpmax, stpmin - ) - - # Store the initialization values - self.steps[i] = alpha1 - self.phi0_values[i] = phi0 - self.derphi0_values[i] = derphi0 - - # Main optimization loop - while True: - # 1. step forward - # logging.info(f"step's input: alpha1: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}") - for i in range(self.batch_size): - if completed[i]: - continue - ls = self.linesearch_list[i] - if ls.fc > max_iter: - completed[i] = True - logging.warning(f"LineSearchBatch[{i}] reached max_iter: {max_iter}") - continue - stp = ls.step(self.steps[i], self.phi0_values[i], self.derphi0_values[i], - c1, c2, ls.xtol, ls.isave, ls.dsave) - if ls.task[:2] == 'FG': - self.steps[i] = stp - else: - completed[i] = True - - # 2. calculate new function value and gradient - x_new_batch = torch.zeros_like(xk) - for i in range(self.batch_size): - mask = (i == self.batch_indices_flatten) - x_new_batch[mask] = xk[mask] + self.steps[i] * pk[i] - f_batch = func(x_new_batch).to(self.device) - g_batch = myfprime(x_new_batch).to(self.device) - - # 3. update function value and gradient - for i in range(self.batch_size): - ls = self.linesearch_list[i] - mask = (i == self.batch_indices_flatten) - if ls.task[:2] == 'FG': - # Update function value and gradient - f_val = f_batch[i:i+1] - g_val = g_batch[mask] - ls.fc += 1 - phi0, derphi0 = ls.prologue(f_val, g_val, pk[i], self.steps[i]) - # logging.info(f"phi0, derphi0: {phi0}, {derphi0}") - self.phi0_values[i] = phi0 - self.derphi0_values[i] = derphi0 # TODO: why we put the derphi0 here instead of set it inside the LineSearch class? 
+"""
+Copyright (c) 2025 Ma Zhaojia
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+import logging
+import torch
+from torch_scatter import scatter
+# from .linesearch_torch import LineSearchBatch
+from ..optimizable import OptimizableBatch
+from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler
+from datetime import datetime
+import os
+import math
+import gc
+
+class BFGSFusedLS:
+    """
+    Port of BFGSLineSearch from bfgslinesearch.py, adapted to PyTorch
+    and batched operations, mirroring lbfgs_torch.py structure.
+    """
+    def __init__(
+        self,
+        optimizable_batch: OptimizableBatch,
+        maxstep: float = 0.2,
+        c1: float = 0.23,
+        c2: float = 0.46,
+        alpha: float = 10.0,
+        stpmax: float = 50.0,
+        device = 'cpu',
+        early_stop: bool = False,
+        use_profiler: bool = False,
+        profiler_log_dir: str = './log',
+        profiler_schedule_config: dict = None,
+        dtype: torch.dtype = torch.float64,
+    ):
+        self.optimizable = optimizable_batch
+        self.maxstep = maxstep
+        self.c1 = c1
+        self.c2 = c2
+        self.alpha = alpha
+        self.stpmax = stpmax
+        self.nsteps = 0
+        self.device = device
+        self.force_calls = 0
+        self.early_stop = early_stop
+        self.use_profiler = use_profiler
+        self.profiler_log_dir = profiler_log_dir
+        self.profiler_schedule_config = profiler_schedule_config or {"wait": 48, "warmup": 1, "active": 1, "repeat": 8}
+        self.dtype = dtype
+
+        self.converge_indices_list = None
+
+        # The information from the previous round is useful for the current round's calculations.
+        ## These variables need to be updated according to new input when early stop is triggered.
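+        ## (carried over: the Hessians Hs, the previous positions r0 and gradients g0,
+        ##  the search directions p_list, and the per-structure line-search state)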
+        self.Hs = None
+        self.r0 = None
+        self.g0 = None
+        self.p_list = [None] * self.optimizable.batch_size
+        self.no_update_list = [False] * self.optimizable.batch_size
+        self.ls_completed = [True] * self.optimizable.batch_size
+        self.ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu", dtype=self.dtype)
+        ## need to be recalculated when early stop is triggered
+        self.forces = None
+        self.energies = None
+
+    def restart_from_earlystop(self, restart_indices, old_batch_indices):
+        Hs_new = []
+        r0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device)
+        g0_new = torch.zeros_like(r0_new, device=self.device)
+        p_list_new = []
+        no_update_list_new = []
+        ls_completed_new = []
+
+        # collect the preserved historical info using old_batch_indices
+        for i, idx in enumerate(restart_indices):
+            mask_old = (idx==old_batch_indices.repeat_interleave(3))
+            mask = (i==self.optimizable.batch_indices.repeat_interleave(3))
+            Hs_new.append(self.Hs[idx])
+            p_list_new.append(self.p_list[idx])
+            no_update_list_new.append(self.no_update_list[idx])
+            ls_completed_new.append(self.ls_completed[idx])
+            r0_new[mask] = self.r0[mask_old]
+            g0_new[mask] = self.g0[mask_old]
+
+        # append new info for new element in batch
+        for i in range(len(Hs_new), self.optimizable.batch_size):
+            # Hs_new.append(torch.eye(3 * self.optimizable.elem_per_group[i], device=self.device, dtype=torch.float64))
+            Hs_new.append(None)
+            p_list_new.append(None)
+            no_update_list_new.append(False)
+            ls_completed_new.append(True)
+
+        self.Hs = Hs_new
+        self.r0 = r0_new
+        self.g0 = g0_new
+        self.p_list = p_list_new
+        self.no_update_list = no_update_list_new
+        self.ls_completed = ls_completed_new
+        self.forces = None
+        self.energies = None
+        self.ls_batch.restart_from_earlystop(restart_indices=restart_indices, batch_indices_new=self.optimizable.batch_indices)
+
+    def step(self):
+        optimizable = self.optimizable
+        if self.forces is None:
+            self.forces = optimizable.get_forces().to(self.device)
+        r = optimizable.get_positions().reshape(-1).to(self.device)
+        g = -self.forces.reshape(-1) / self.alpha
+        p0_list = self.p_list
+        self.update(r, g, self.r0, self.g0, p0_list)
+        if self.energies is None:
+            self.energies = self.func(r)
+
+        for i in range(self.optimizable.batch_size):
+            if self.ls_completed[i]:
+                p = -torch.matmul(self.Hs[i], g[i==self.optimizable.batch_indices.repeat_interleave(3)])
+
+                # Scale the direction for numerical stability (simplified calculation)
+                p_size = torch.sqrt((p**2).sum())
+                min_size = torch.sqrt(self.optimizable.elem_per_group[i] * 1e-10)
+                if p_size <= min_size:
+                    p = p * (min_size / p_size)
+
+                self.p_list[i] = p
+
+        # ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu")
+        continue_search = [not elem for elem in self.ls_completed]
+        self.alpha_k_list, self.e_list, self.e0_list, self.no_update_list, self.ls_completed = self.ls_batch._linesearch_batch(
+            self.func, self.fprime, r, self.p_list, g, self.energies, None,
+            maxstep=self.maxstep, c1=self.c1, c2=self.c2, stpmax=self.stpmax, continue_search=continue_search
+        )
+
+        # reset device for linesearch result
+        for i in range(self.optimizable.batch_size):
+            if self.ls_completed[i]:
+                self.alpha_k_list[i] = self.alpha_k_list[i].to(self.device)
+                self.p_list[i] = self.p_list[i].to(self.device)
+
+        dr_tensor = torch.zeros_like(r)
+
+
+        for i in range(self.optimizable.batch_size):
+            # if check_cache:
+            #     mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
+            #     dr_tensor_all[mask] = self.alpha_k_list[i].to(self.device) * self.p_list[i].to(self.device)
+
+            if not self.ls_completed[i]:
+                continue
+            if self.alpha_k_list[i] is None:
+                raise RuntimeError("LineSearch failed!")
+
+            mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
+            dr_tensor[mask] = self.alpha_k_list[i] * self.p_list[i]
+
+        # if check_cache:
+        #     cached_pos = optimizable.get_positions().reshape(-1).to(self.device)
+        #     update_pos = r + dr_tensor_all
+        #     assert torch.allclose(update_pos, cached_pos), "dr_tensor_cached should be equal to dr_tensor"
+
+
+        # TODO: get_forces/get_potential_energies will trigger compare_batch which is time-consuming
+        forces_cache = optimizable.get_forces()
+        energies_cache = self.optimizable.get_potential_energies() / self.alpha
+
+        # update self.forces
+        for i in range(self.optimizable.batch_size):
+            if not self.ls_completed[i]:
+                continue
+            mask = (i == self.optimizable.batch_indices)
+            self.forces[mask] = forces_cache[mask]
+            self.energies[i] = energies_cache[i]
+
+        optimizable.set_positions((r + dr_tensor).reshape(-1, 3))
+
+        self.r0 = r
+        self.g0 = g
+
+    # @torch.compile
+    def update(self, r, g, r0, g0, p0_list):
+        all_sizes = self.optimizable.elem_per_group
+
+        if self.Hs is None:
+            self.Hs = [
+                torch.eye(3 * sz, device=self.device, dtype=self.dtype)
+                for sz in all_sizes
+            ]
+            return
+
+        dr = r - r0
+        dg = g - g0
+
+        for i in range(self.optimizable.batch_size):
+            if self.Hs[i] is None:
+                self.Hs[i] = torch.eye(3 * all_sizes[i], device=self.optimizable.device, dtype=self.dtype)
+                continue
+            if not self.ls_completed[i]:
+                continue
+            if self.no_update_list[i] is True:
+                print('skip update')
+                continue
+
+            cur_mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
+            cur_g = g[cur_mask]
+            cur_p0 = p0_list[i]
+            cur_g0 = g0[cur_mask]
+            cur_dg = dg[cur_mask]
+            cur_dr = dr[cur_mask]
+
+            if not (((self.alpha_k_list[i] or 0) > 0 and
+                     abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False):
+                continue
+
+            try:
+                rhok = 1.0 / (torch.dot(cur_dg, cur_dr))
+            except:
+                rhok = 1000.0
+                print("Divide-by-zero encountered: rhok assumed large")
+            if torch.isinf(rhok):
+                rhok = 1000.0
+                print("Divide-by-zero encountered: rhok assumed large")
+            I = torch.eye(all_sizes[i]*3, device=self.device, dtype=self.dtype)
+            A1 = I - cur_dr[:, None] * cur_dg[None, :] * rhok
+            A2 = I - cur_dg[:, None] * cur_dr[None, :] * rhok
+            self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) +
+                          rhok * cur_dr[:, None] * cur_dr[None, :])
+
+
+    # def update(self, r, g, r0, g0, p0_list):
+    #     self.Is = [
+    #         torch.eye(sz * 3, dtype=torch.float64, device=self.device)
+    #         for sz in self.optimizable.elem_per_group
+    #     ]
+
+    #     # TODO: after being interrupted, does the BFGS for-loop need to rebuild this self.Hs?
+    #     # TODO: and the r, g, r0, g0 saved from the previous round are discarded as well
+    #     if self.Hs is None:
+    #         self.Hs = [
+    #             torch.eye(3 * sz, device=self.optimizable.device, dtype=torch.float64)
+    #             for sz in self.optimizable.elem_per_group
+    #         ]
+    #         return
+    #     else:
+    #         dr = r - r0
+    #         dg = g - g0
+
+    #     for i in range(self.optimizable.batch_size):
+    #         if not self.ls_completed[i]:
+    #             continue
+    #         cur_mask = (i==self.optimizable.batch_indices.repeat_interleave(3))
+    #         cur_g = g[cur_mask]
+    #         cur_p0 = p0_list[i]
+    #         cur_g0 = g0[cur_mask]
+    #         cur_dg = dg[cur_mask]
+    #         cur_dr = dr[cur_mask]
+
+    #         if not (((self.alpha_k_list[i] or 0) > 0 and
+    #                  abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False):
+    #             break
+
+    #         if self.no_update_list[i] is True:
+    #             print('skip update')
+    #             break
+
+    #         try:
+    #             rhok = 1.0 / (torch.dot(cur_dg, cur_dr))
+    #         except:
+    #             rhok = 1000.0
+    #             print("Divide-by-zero encountered: rhok assumed large")
+    #         if torch.isinf(rhok):
+    #             rhok = 1000.0
+    #             print("Divide-by-zero encountered: rhok assumed large")
+    #         A1 = self.Is[i] - cur_dr[:, None] * cur_dg[None, :] * rhok
+    #         A2 = self.Is[i] - cur_dg[:, None] * cur_dr[None, :] * rhok
+    #         self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) +
+    #                       rhok * cur_dr[:, None] * cur_dr[None, :])
+
+
+
+    def func(self, x):
+        self.optimizable.set_positions(x.reshape(-1, 3).to(self.device))
+        return self.optimizable.get_potential_energies() / self.alpha
+
+    def fprime(self, x):
+        self.optimizable.set_positions(x.reshape(-1, 3).to(self.device))
+
+        self.force_calls += 1
+        forces = self.optimizable.get_forces().reshape(-1)
+        return - forces / self.alpha
+
+    def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None):
+        logging.info("Enter bfgsfusedlinesearch's main program.")
+        self.fmax = fmax
+        self.max_iter = maxstep
+
+        if is_restart_earlystop:
+            self.restart_from_earlystop(restart_indices, old_batch_indices)
+
+        iteration = 0
+        max_forces = self.optimizable.get_max_forces(apply_constraint=True)
+        logging.info("Step Fmax(eV/A)")
+
+        # Run with profiler if enabled
+        if self.use_profiler:
+            activities = [ProfilerActivity.CPU]
+            if torch.cuda.is_available():
+                activities.append(ProfilerActivity.CUDA)
+
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            pid = os.getpid()
+            with torch.profiler.profile(
+                activities=activities,
+                schedule=torch.profiler.schedule(
+                    wait=self.profiler_schedule_config["wait"],
+                    warmup=self.profiler_schedule_config["warmup"],
+                    active=self.profiler_schedule_config["active"],
+                    repeat=self.profiler_schedule_config["repeat"]
+                ),
+                on_trace_ready=tensorboard_trace_handler(self.profiler_log_dir, worker_name=f"BFGSLS_{pid}"),
+                with_stack=True,
+                profile_memory=True,
+            ) as prof:
+                # Main optimization loop with profiling
+                while iteration < self.max_iter and not self.optimizable.converged(
+                    forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25,
+                ):
+                    if self.early_stop and iteration > 0:
+                        self.converge_indices_list = self.optimizable.converge_indices_list
+                        if len(self.converge_indices_list) > 0:
+                            logging.info(f"Early stopping at iteration {iteration}")
+                            break
+
+                    logging.info(
+                        f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
+                    )
+
+                    self.step()
+                    max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces)
+                    iteration += 1
+
+                    # Step the profiler in each iteration
+                    prof.step()
+
+        else:
+            # Original optimization loop without profiling
+            while iteration < self.max_iter and not
self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, + ): + if self.early_stop and iteration > 0: + self.converge_indices_list = self.optimizable.converge_indices_list + if len(self.converge_indices_list) > 0: + logging.info(f"Early stopping at iteration {iteration}") + break + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + self.step() + max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) + iteration += 1 + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + # GPU memory usage as per nvidia-smi seems to gradually build up as + # batches are processed. This releases unoccupied cached memory. + torch.cuda.empty_cache() + gc.collect() + + # set predicted values to batch + for name, value in self.optimizable.results.items(): + setattr(self.optimizable.batch, name, value) + + self.nsteps = iteration + + if self.early_stop: + self.converge_indices_list = self.optimizable.converge_indices_list + return self.converge_indices_list + else: + return self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ) + + def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" + ) + + def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" + ) + +# flake8: noqa +import math +import torch +import logging + +pymin = min +pymax = max + + +class LineSearch: + def __init__(self, xtol=1e-14, device='cpu', dtype=torch.float64): + self.xtol = xtol + self.task = 'START' + self.device = device + self.dtype = dtype + self.isave = torch.zeros(2, dtype=torch.int64, device=self.device) + self.dsave = torch.zeros(13, dtype=self.dtype, device=self.device) + self.fc = 0 + self.gc = 0 + self.case = 0 + self.old_stp = 0 + + def initialize(self, xk, pk, gfk, old_fval, old_old_fval, + maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., + stpmax=50., stpmin=1e-8): + # Scalar parameters can stay as Python scalars + self.stpmin = stpmin + self.stpmax = stpmax + self.xtrapl = xtrapl + self.xtrapu = xtrapu + self.maxstep = maxstep + + # Move tensors to the device + self.pk = pk.to(self.device) + xk = xk.to(self.device) + gfk = gfk.to(self.device) + + phi0 = old_fval + + + # This dot product needs tensors + derphi0 = torch.dot(gfk, self.pk).item() + + # Use Python math for scalar calculations + self.dim = len(pk) + self.gms = math.sqrt(self.dim) * maxstep + + alpha1 = 1.0 + self.no_update = False + self.gradient = True + + self.steps = [] + return alpha1, phi0, derphi0 + + def prologue(self, fval, gval, pk_tensor, alpha1): + phi0 = fval + derphi0 = torch.dot(gval, pk_tensor) + self.old_stp = alpha1 + # TODO: self.no_update == True: break is needed to reimplemented. 
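+        # (in the batched driver, _linesearch_batch marks a member as completed
+        #  once no_update is set)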
+
+        return phi0, derphi0
+
+    def epilogue(self):
+        pass
+
+    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
+                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
+                     stpmax=50., stpmin=1e-8, args=()):
+        self.stpmin = stpmin
+        self.pk = pk.to(self.device)
+        self.stpmax = stpmax
+        self.xtrapl = xtrapl
+        self.xtrapu = xtrapu
+        self.maxstep = maxstep
+
+        xk = xk.to(self.device)
+
+        # Convert inputs to torch tensors if they're not already
+        if not isinstance(old_fval, torch.Tensor):
+            phi0 = torch.tensor(old_fval, dtype=self.dtype, device=self.device)
+        else:
+            phi0 = old_fval.to(self.device)
+
+        # Ensure pk and gfk are torch tensors
+        pk_tensor = torch.tensor(pk, dtype=self.dtype, device=self.device) if not isinstance(pk, torch.Tensor) else pk.to(self.device)
+        gfk_tensor = torch.tensor(gfk, dtype=self.dtype, device=self.device) if not isinstance(gfk, torch.Tensor) else gfk.to(self.device)
+
+        derphi0 = torch.dot(gfk_tensor, pk_tensor)
+        self.dim = len(pk)
+        self.gms = torch.sqrt(torch.tensor(self.dim, dtype=self.dtype, device=self.device)) * maxstep
+        alpha1 = 1.
+        self.no_update = False
+
+        if isinstance(myfprime, tuple):
+            fprime = myfprime[0]
+            gradient = False
+        else:
+            fprime = myfprime
+            newargs = args
+            gradient = True
+
+        fval = phi0
+        gval = gfk_tensor
+        self.steps = []
+
+        while True:
+            stp = self.step(alpha1, phi0, derphi0, c1, c2,
+                            self.xtol,
+                            self.isave, self.dsave)
+
+            if self.task[:2] == 'FG':
+                alpha1 = stp
+
+                # Get function value and gradient
+                x_new = xk + stp * pk_tensor
+                fval = func(x_new).to(self.device)
+                self.fc += 1
+
+                gval = fprime(x_new).to(self.device)
+                if gradient:
+                    self.gc += 1
+                else:
+                    self.fc += len(xk) + 1
+
+                phi0 = fval
+                derphi0 = torch.dot(gval, pk_tensor)
+                self.old_stp = alpha1
+
+                if self.no_update:
+                    break
+            else:
+                break
+
+        if self.task[:5] == 'ERROR' or self.task[:4] == 'WARN':
+            stp = None  # failed
+
+        return stp, fval.item(), old_fval.item() if isinstance(old_fval, torch.Tensor) else old_fval, self.no_update
+
+    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
+        if self.task[:5] == 'START':
+            # Check the input arguments for errors.
+            if stp < self.stpmin:
+                self.task = 'ERROR: STP .LT. minstep'
+            if stp > self.stpmax:
+                self.task = 'ERROR: STP .GT. maxstep'
+            if g >= 0:
+                self.task = 'ERROR: INITIAL G >= 0'
+            if c1 < 0:
+                self.task = 'ERROR: c1 .LT. 0'
+            if c2 < 0:
+                self.task = 'ERROR: c2 .LT. 0'
+            if xtol < 0:
+                self.task = 'ERROR: XTOL .LT. 0'
+            if self.stpmin < 0:
+                self.task = 'ERROR: minstep .LT. 0'
+            if self.stpmax < self.stpmin:
+                self.task = 'ERROR: maxstep .LT. minstep'
+            if self.task[:5] == 'ERROR':
+                return stp
+
+            # Initialize local variables.
+            self.bracket = False
+            stage = 1
+            finit = f
+            ginit = g
+            gtest = c1 * ginit
+            width = self.stpmax - self.stpmin
+            width1 = width / .5
+
+            # The variables stx, fx, gx contain the values of the step,
+            # function, and derivative at the best step.
+            # The variables sty, fy, gy contain the values of the step,
+            # function, and derivative at sty.
+            # The variables stp, f, g contain the values of the step,
+            # function, and derivative at stp.
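+            # Both endpoints start out identical: (stx, fx, gx) and
+            # (sty, fy, gy) are seeded at step 0 with the initial function
+            # value and directional derivative, before any bracket exists.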
+ stx = 0.0 + fx = finit + gx = ginit + sty = 0.0 + fy = finit + gy = ginit + stmin = 0.0 + stmax = stp + self.xtrapu * stp + self.task = 'FG' + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + stp = self.determine_step(stp) + return stp + else: + if self.isave[0] == 1: + self.bracket = True + else: + self.bracket = False + stage = self.isave[1] + (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, + width, width1) = self.dsave + + # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the + # algorithm enters the second stage. + ftest = finit + stp * gtest + if stage == 1 and f < ftest and g >= 0.: + stage = 2 + + # Test for warnings. + if self.bracket and (stp <= stmin or stp >= stmax): + self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS' + if self.bracket and stmax - stmin <= self.xtol * stmax: + self.task = 'WARNING: XTOL TEST SATISFIED' + if stp == self.stpmax and f <= ftest and g <= gtest: + self.task = 'WARNING: STP = maxstep' + if stp == self.stpmin and (f > ftest or g >= gtest): + self.task = 'WARNING: STP = minstep' + + # Test for convergence. + # if f <= ftest and abs(g) <= c2 * (- ginit): + # self.task = 'CONVERGENCE' + if (f < ftest or math.isclose(f, ftest, rel_tol=1e-6, abs_tol=1e-5)) and (abs(g) < c2 * (- ginit) or math.isclose(abs(g), c2 * (- ginit), rel_tol=1e-6, abs_tol=1e-5)): + self.task = 'CONVERGENCE' + + # Test for termination. + if self.task[:4] == 'WARN' or self.task[:4] == 'CONV': + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + return stp + + stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty, + fy, gy, stp, f, g, + stmin, stmax) + + # Decide if a bisection step is needed. + if self.bracket: + if abs(sty - stx) >= .66 * width1: + stp = stx + .5 * (sty - stx) + width1 = width + width = abs(sty - stx) + + # Set the minimum and maximum steps allowed for stp. + if self.bracket: + stmin = min(stx, sty) + stmax = max(stx, sty) + else: + stmin = stp + self.xtrapl * (stp - stx) + stmax = stp + self.xtrapu * (stp - stx) + + # Force the step to be within the bounds maxstep and minstep. + stp = max(stp, self.stpmin) + stp = min(stp, self.stpmax) + + if (stx == stp and stp == self.stpmax and stmin > self.stpmax): + self.no_update = True + + # If further progress is not possible, let stp be the best + # point obtained during the search. + if (self.bracket and stp < stmin or stp >= stmax) \ + or (self.bracket and stmax - stmin < self.xtol * stmax): + stp = stx + + # Obtain another function and derivative. + self.task = 'FG' + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + return stp + + def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp, + stpmin, stpmax): + sign = gp * (gx / abs(gx)) + + # First case: A higher function value. The minimum is bracketed. + # If the cubic step is closer to stx than the quadratic step, the + # cubic step is taken, otherwise the average of the cubic and + # quadratic steps is taken. + if fp > fx: # case1 + self.case = 1 + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2. - (gx / s) * (gp / s)) + if stp < stx: + gamma = -gamma + p = (gamma - gx) + theta + q = ((gamma - gx) + gamma) + gp + r = p / q + stpc = stx + r * (stp - stx) + stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) 
\ + * (stp - stx) + if (abs(stpc - stx) < abs(stpq - stx)): + stpf = stpc + else: + stpf = stpc + (stpq - stpc) / 2. + + self.bracket = True + + # Second case: A lower function value and derivatives of opposite + # sign. The minimum is bracketed. If the cubic step is farther from + # stp than the secant step, the cubic step is taken, otherwise the + # secant step is taken. + elif sign < 0: # case2 + self.case = 2 + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2 - (gx / s) * (gp / s)) + if stp > stx: + gamma = -gamma + p = (gamma - gp) + theta + q = ((gamma - gp) + gamma) + gx + r = p / q + stpc = stp + r * (stx - stp) + stpq = stp + (gp / (gp - gx)) * (stx - stp) + if (abs(stpc - stp) > abs(stpq - stp)): + stpf = stpc + else: + stpf = stpq + self.bracket = True + + # Third case: A lower function value, derivatives of the same sign, + # and the magnitude of the derivative decreases. + elif abs(gp) < abs(gx): # case3 + self.case = 3 + # The cubic step is computed only if the cubic tends to infinity + # in the direction of the step or if the minimum of the cubic + # is beyond stp. Otherwise the cubic step is defined to be the + # secant step. + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + + # The case gamma = 0 only arises if the cubic does not tend + # to infinity in the direction of the step. + gamma = s * math.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s))) + if stp > stx: + gamma = -gamma + p = (gamma - gp) + theta + q = (gamma + (gx - gp)) + gamma + r = p / q + if r < 0. and gamma != 0: + stpc = stp + r * (stx - stp) + elif stp > stx: + stpc = stpmax + else: + stpc = stpmin + stpq = stp + (gp / (gp - gx)) * (stx - stp) + + if self.bracket: + # A minimizer has been bracketed. If the cubic step is + # closer to stp than the secant step, the cubic step is + # taken, otherwise the secant step is taken. + if abs(stpc - stp) < abs(stpq - stp): + stpf = stpc + else: + stpf = stpq + if stp > stx: + stpf = min(stp + .66 * (sty - stp), stpf) + else: + stpf = max(stp + .66 * (sty - stp), stpf) + else: + # A minimizer has not been bracketed. If the cubic step is + # farther from stp than the secant step, the cubic step is + # taken, otherwise the secant step is taken. + if abs(stpc - stp) > abs(stpq - stp): + stpf = stpc + else: + stpf = stpq + stpf = min(stpmax, stpf) + stpf = max(stpmin, stpf) + + # Fourth case: A lower function value, derivatives of the same sign, + # and the magnitude of the derivative does not decrease. If the + # minimum is not bracketed, the step is either minstep or maxstep, + # otherwise the cubic step is taken. + else: # case4 + self.case = 4 + if self.bracket: + theta = 3. * (fp - fy) / (sty - stp) + gy + gp + s = max(max(abs(theta), abs(gy)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2 - (gy / s) * (gp / s)) + if stp > sty: + gamma = -gamma + p = (gamma - gp) + theta + q = ((gamma - gp) + gamma) + gy + r = p / q + stpc = stp + r * (sty - stp) + stpf = stpc + elif stp > stx: + stpf = stpmax + else: + stpf = stpmin + + # Update the interval which contains a minimizer. + if fp > fx: + sty = stp + fy = fp + gy = gp + else: + if sign < 0: + sty = stx + fy = fx + gy = gx + stx = stp + fx = fp + gx = gp + + # Compute the new step. 
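+        # determine_step additionally rescales the step update so that no
+        # atom is displaced by more than self.maxstep along the search
+        # direction p_k.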
+ stp = self.determine_step(stpf) + + return stx, sty, stp, gx, fx, gy, fy + + def determine_step(self, stp): + dr = stp - self.old_stp + x = torch.reshape(self.pk.to(self.device), (-1, 3)) + steplengths = ((dr * x)**2).sum(1)**0.5 + maxsteplength = max(steplengths) + if maxsteplength >= self.maxstep: + dr *= self.maxstep / maxsteplength + stp = self.old_stp + dr + return stp + + def save(self, data): + if self.bracket: + self.isave[0] = 1 + else: + self.isave[0] = 0 + self.isave[1] = data[0] + self.dsave = data[1:] + +class LineSearchBatch: + + def __init__(self, batch_indices, device='cpu', dtype=torch.float64): + self.device = device + self.dtype = dtype + self.batch_indices = batch_indices.to(self.device) + self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) + self.batch_size = len(torch.unique(batch_indices)) + self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)] + self.steps = [1.] * self.batch_size + self.phi0_values = [None] * self.batch_size + self.derphi0_values = [None] * self.batch_size + + def restart_from_earlystop(self, restart_indices, batch_indices_new): + self.batch_indices = batch_indices_new.to(self.device) + self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) + self.batch_size = len(torch.unique(batch_indices_new)) + + linesearch_list_new = [] + steps_new = [] + phi0_values_new = [] + derphi0_values_new = [] + + for i, idx in enumerate(restart_indices): + linesearch_list_new.append(self.linesearch_list[idx]) + steps_new.append(self.steps[idx]) + phi0_values_new.append(self.phi0_values[idx]) + derphi0_values_new.append(self.derphi0_values[idx]) + + for i in range(len(restart_indices), self.batch_size): + linesearch_list_new.append(LineSearch(device=self.device)) + steps_new.append(1.) 
+            phi0_values_new.append(None)
+            derphi0_values_new.append(None)
+
+        self.linesearch_list = linesearch_list_new
+        self.steps = steps_new
+        self.phi0_values = phi0_values_new
+        self.derphi0_values = derphi0_values_new
+
+
+
+    def _linesearch_batch(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
+                          maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
+                          stpmax=50., stpmin=1e-8, continue_search=None, max_iter=15):
+        if continue_search is None:
+            self.linesearch_list = [LineSearch(device=self.device) for _ in range(self.batch_size)]
+            # Treat every structure as a fresh search so that the
+            # per-element continue_search[i] checks below stay valid.
+            continue_search = [False] * self.batch_size
+        else:
+            assert len(continue_search) == self.batch_size
+            for i in range(len(continue_search)):
+                if not continue_search[i]:
+                    self.linesearch_list[i] = LineSearch(device=self.device)
+
+        if isinstance(xk, torch.Tensor):
+            xk = xk.to(self.device)
+        for i in range(len(pk)):
+            pk[i] = pk[i].to(self.device)
+        if isinstance(gfk, torch.Tensor):
+            gfk = gfk.to(self.device)
+        if isinstance(old_fval, torch.Tensor):
+            old_fval = old_fval.to(self.device)
+        if isinstance(old_old_fval, torch.Tensor):
+            old_old_fval = old_old_fval.to(self.device)
+
+
+        # results for each batch element
+        alpha_results = []
+        e_result = []
+        e0_result = []
+        no_update_result = []
+
+        # Initialize step sizes and line search state for each batch element
+        completed = [False] * self.batch_size
+
+        # Initialize iteration counter
+        iter_count = 0
+
+        # Initialize all line searches using the initialize method
+        for i in range(self.batch_size):
+            if continue_search[i]:
+                continue
+
+            ls = self.linesearch_list[i]
+            mask = (i == self.batch_indices_flatten)
+
+            # Use the initialize method to set up line search parameters
+            alpha1, phi0, derphi0 = ls.initialize(
+                xk[mask], pk[i], gfk[mask], old_fval[i], old_old_fval,
+                maxstep, c1, c2, xtrapl, xtrapu, stpmax, stpmin
+            )
+
+            # Store the initialization values
+            self.steps[i] = alpha1
+            self.phi0_values[i] = phi0
+            self.derphi0_values[i] = derphi0
+
+        # Main optimization loop
+        while True:
+            # 1. step forward
+            # logging.info(f"step's input: alpha1: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}")
+            for i in range(self.batch_size):
+                if completed[i]:
+                    continue
+                ls = self.linesearch_list[i]
+                if ls.fc > max_iter:
+                    completed[i] = True
+                    logging.warning(f"LineSearchBatch[{i}] reached max_iter: {max_iter}")
+                    continue
+                stp = ls.step(self.steps[i], self.phi0_values[i], self.derphi0_values[i],
+                              c1, c2, ls.xtol, ls.isave, ls.dsave)
+                if ls.task[:2] == 'FG':
+                    self.steps[i] = stp
+                else:
+                    completed[i] = True
+
+            # 2. calculate new function value and gradient
+            x_new_batch = torch.zeros_like(xk)
+            for i in range(self.batch_size):
+                mask = (i == self.batch_indices_flatten)
+                x_new_batch[mask] = xk[mask] + self.steps[i] * pk[i]
+            f_batch = func(x_new_batch).to(self.device)
+            g_batch = myfprime(x_new_batch).to(self.device)
+
+            # 3. update function value and gradient
+            for i in range(self.batch_size):
+                ls = self.linesearch_list[i]
+                mask = (i == self.batch_indices_flatten)
+                if ls.task[:2] == 'FG':
+                    # Update function value and gradient
+                    f_val = f_batch[i:i+1]
+                    g_val = g_batch[mask]
+                    ls.fc += 1
+                    phi0, derphi0 = ls.prologue(f_val, g_val, pk[i], self.steps[i])
+                    # logging.info(f"phi0, derphi0: {phi0}, {derphi0}")
+                    self.phi0_values[i] = phi0
+                    self.derphi0_values[i] = derphi0  # TODO: why do we set derphi0 here instead of inside the LineSearch class?
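+                    # Keeping phi0/derphi0 in this batch container lets steps
+                    # 1-3 above evaluate energies and forces for the whole
+                    # batch in one fused call before handing the per-structure
+                    # scalars back to each LineSearch.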
+                    if ls.no_update:
+                        completed[i] = True
+                else:
+                    completed[i] = True
+
+            iter_count += 1
+            logging.info(f"LineSearchBatch iter: {iter_count}: alpha: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}")
+            if any(completed):
+                break
+
+            # 4. set a linesearch upper limit
+            # if iter_count > max_iter:
+            #     for i in range(self.batch_size):
+            #         completed[i] = True
+            #     logging.warning(f"LineSearchBatch reached max_iter: {max_iter}")
+            #     break
+
+        # Collect results
+        for i in range(self.batch_size):
+            ls = self.linesearch_list[i]
+            if ls.task[:5] == 'ERROR' or ls.task[:4] == 'WARN':
+                stp = torch.tensor(1., device=self.device)
+            else:
+                stp = self.steps[i] if isinstance(self.steps[i], torch.Tensor) else torch.tensor(self.steps[i], device=self.device)
+
+            alpha_results.append(stp)
+            e_result.append(self.phi0_values[i].item() if self.phi0_values[i] is not None else None)
+            e0_result.append(old_fval[i].item() if isinstance(old_fval[i], torch.Tensor) else old_fval[i])
+            no_update_result.append(ls.no_update)
+
+        logging.info(f"LineSearchBatch finished in {iter_count} iterations. "
+                     f"LineSearch Status: {completed}")
+
+        return alpha_results, e_result, e0_result, no_update_result, completed
diff --git a/mace-bench/src/batchopt/relaxengine.py b/mace-bench/src/batchopt/relaxengine.py
index 09c3dae75f0477c38761fb3946134ec6130b0149..617145fe31a036e90355fc896c4a6e545fc04e79 100644
--- a/mace-bench/src/batchopt/relaxengine.py
+++ b/mace-bench/src/batchopt/relaxengine.py
@@ -1,1433 +1,1433 @@
-"""
-Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}
-
-This source code is licensed under the MIT license found in the
-LICENSE file in the root directory of this source tree.
-"""
-
-from ase.io import read
-
-# from ase.optimize import ASE_LBFGS
-import torch
-from torch.multiprocessing import Process, set_start_method
-from batchopt.atoms_to_graphs import AtomsToGraphs
-from batchopt.utils import data_list_collater
-from batchopt.relaxation.optimizers import (
-    BFGS,
-    BFGSFusedLS,
-)
-from batchopt.relaxation import OptimizableBatch, OptimizableUnitCellBatch
-import logging
-import time
-import csv
-from multiprocessing import Queue
-import os
-import psutil
-import multiprocessing
-import json
-import subprocess
-
-try:
-    from chgnet.model.dynamics import CHGNetCalculator
-except ImportError:
-    logging.warning("Failed to import CHGNet modules")
-
-try:
-    from sevenn.calculator import SevenNetCalculator, SevenNetD3Calculator
-except ImportError:
-    logging.warning("Failed to import SevenNet modules")
-
-try:
-    from fairchem.core import pretrained_mlip, FAIRChemCalculator
-except ImportError:
-    logging.warning("Failed to import FAIRChem modules")
-
-try:
-    from mace.calculators import mace_off
-except ImportError:
-    logging.warning("Failed to import MACE modules")
-
-import threading
-from .utils import count_atoms_cif
-from collections import deque
-
-
-class Scheduler:
-    """
-    Scheduler distributes relaxation tasks to workers.
- """ - - def __init__( - self, - files, - num_workers, - devices, - batch_size, - max_steps, - filter1, - filter2, - optimizer1, - optimizer2, - skip_second_stage, - scalar_pressure, - compile_mode, - profile, - num_threads, - bind_cores, - cueq, - molecule_single, - output_path, - model, - ): - - self.files = files - self.num_workers = num_workers - self.devices = devices - self.batch_size = batch_size - self.max_steps = max_steps - self.filter1 = filter1 - self.filter2 = filter2 - self.optimizer1 = optimizer1 - self.optimizer2 = optimizer2 - self.skip_second_stage = skip_second_stage - self.scalar_pressure = scalar_pressure - self.compile_mode = compile_mode - self.profile = profile - self.num_threads = num_threads - self.cueq = cueq - self.molecule_single = molecule_single - self.output_path = ( - output_path - if os.path.isabs(output_path) - else os.path.abspath(output_path) - ) - self.model = model - - try: - set_start_method("spawn") - except RuntimeError: - logging.warning( - "set_start_method('spawn') failed, trying 'forkserver' instead." - ) - - if bind_cores is not None: - self.cpu_mask = self._parse_bind_cores(bind_cores) - else: - self.cpu_mask = None - - def _parse_bind_cores(self, bind_cores): - # Expect custom_bind_str to be like "0-15,16-31,..." - ranges = bind_cores.split(",") - if len(ranges) != self.num_workers: - return None - binding = [] - for r in ranges: - try: - start_str, end_str = r.split("-") - start = int(start_str) - end = int(end_str) - except ValueError: - logging.error("Custom binding format should be 'start-end'.") - return None - - binding.append(set(range(start, end + 1))) - return binding - - def _get_physical_logical_core_mapping(self): - """Get the mapping between logical cores and their physical core IDs.""" - try: - # This information is available in Linux systems - mapping = {} - logical_cores = psutil.cpu_count(logical=True) - - for i in range(logical_cores): - try: - # Read core_id from /sys/devices/system/cpu/cpu{i}/topology/core_id - with open( - f"/sys/devices/system/cpu/cpu{i}/topology/core_id" - ) as f: - core_id = int(f.read().strip()) - # Read physical_package_id (socket) for more complete information - with open( - f"/sys/devices/system/cpu/cpu{i}/topology/physical_package_id" - ) as f: - package_id = int(f.read().strip()) - mapping[i] = (package_id, core_id) - except (FileNotFoundError, ValueError, IOError): - mapping[i] = None - return mapping - except Exception as e: - logging.error(f"Failed to get core mapping: {e}") - return {} - - def _get_physical_core_mask(self): - # Get the number of physical and logical cores - physical_cores = psutil.cpu_count(logical=False) - logical_cores = psutil.cpu_count(logical=True) - - if physical_cores is None or physical_cores < 1: - # Fallback to multiprocessing if psutil fails - logical_cores = multiprocessing.cpu_count() - physical_cores = logical_cores // 2 - if physical_cores < 1: - physical_cores = 1 - print(f"Using estimated physical cores: {physical_cores}") - - # Get the mapping between logical and physical cores - core_mapping = self._get_physical_logical_core_mapping() - - # Create a CPU mask that includes all physical cores (first core of each physical core) - physical_core_mask = set() - if core_mapping: - # Group by physical core ID - cores_by_physical = {} - for logical_id, physical_info in core_mapping.items(): - if physical_info is not None: - package_id, core_id = physical_info - key = (package_id, core_id) - if key not in cores_by_physical: - cores_by_physical[key] = [] - 
cores_by_physical[key].append(logical_id) - - # Select one logical core from each physical core - for physical_cores_list in cores_by_physical.values(): - physical_core_mask.add( - physical_cores_list[0] - ) # First logical core of each physical core - else: - # If mapping fails, use a simple assumption (may not be accurate on all systems) - threads_per_core = logical_cores // physical_cores - physical_core_mask = set(range(0, logical_cores, threads_per_core)) - - return physical_core_mask - - def worker_task( - self, files, device, batch_size, result_queue, physical_cores - ): - if physical_cores is not None: - try: - # Bind the current process to physical cores - pid = os.getpid() - os.sched_setaffinity(pid, physical_cores) - logging.info(f"bind to physical_core_ids: {physical_cores}") - - # Verify the affinity was set correctly - current_affinity = os.sched_getaffinity(pid) - logging.info( - f"Process bound to {len(current_affinity)} cores: {sorted(current_affinity)}" - ) - - except AttributeError: - logging.error( - "sched_setaffinity not supported on this platform" - ) - except Exception as e: - logging.error(f"Failed to bind to physical cores: {e}") - - # pass the number of processes on each worker - nproc = self.num_workers // len(self.devices) - - worker = Worker( - files, - device, - batch_size, - self.max_steps, - self.filter1, - self.filter2, - self.optimizer1, - self.optimizer2, - self.skip_second_stage, - self.scalar_pressure, - self.compile_mode, - self.profile, - self.cueq, - self.molecule_single, - self.output_path, - self.model, - nproc, - ) - # results = worker.run() - results = worker.continuous_run() - result_queue.put(results) - - def _terminate_processes(self, processes): - """Helper method to terminate all processes.""" - for i, p in processes: - if p.is_alive(): - logging.info(f"Terminating process {p.pid}") - p.terminate() - p.join(timeout=3) # Wait for up to 3 seconds - if p.is_alive(): - logging.warning( - f"Process {p.pid} did not terminate, killing it" - ) - p.kill() - p.join() - - # create a thread to conduct "nvidia-smi" - @staticmethod - def _monitor_memory(interval=2, gpu_index=1): - try: - while True: - result = subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=memory.used,memory.total", - "--format=csv,nounits,noheader", - ] - ).decode("utf-8") - - lines = result.strip().split("\n") - used, total = map(int, lines[gpu_index].split(",")) - logging.info( - f"[nvidia-smi] Memory-Usage on GPU {gpu_index}: {used}MiB / {total}MiB" - ) - - time.sleep(interval) - except KeyboardInterrupt: - logging.info("Monitor interrupted.") - - except Exception as e: - logging.error(f"Unexpected error when monitor memory: {str(e)}") - - def run(self): - logging.info(f"Starting Scheduler with {self.num_workers} workers.") - processes = [] - result_queue = Queue() - start_time = time.perf_counter() - - if self.cpu_mask is not None: - physical_cores_per_worker = self.cpu_mask - logging.info( - f"Use customed cores binding. 
Physical cores per worker: {physical_cores_per_worker}" - ) - else: - # all_physical_cores = self._get_physical_core_mask() - # num_per_worker = len(all_physical_cores) // self.num_workers - # physical_cores_per_worker = [ - # list(all_physical_cores)[i:i + num_per_worker] for i in range(0, len(all_physical_cores), num_per_worker) - # ] - # logging.info(f"Physical cores per worker: {physical_cores_per_worker}") - physical_cores_per_worker = [None] * self.num_workers - - try: - # Start all worker processes - for i in range(self.num_workers): - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - logging.info( - f"Starting worker {i} with {len(files_for_worker)} files on device {device}." - ) - p = Process( - target=self.worker_task, - args=( - files_for_worker, - device, - self.batch_size, - result_queue, - physical_cores_per_worker[i], - ), - ) - p.start() - processes.append((i, p)) - - # monitor gpu memory usage to figure out what makes the differences of footprint among batches - # in each iteration. - use_memory_monitor = False - if use_memory_monitor: - monitor_proc = Process( - target=Scheduler._monitor_memory, args=() - ) - monitor_proc.start() - - # Monitor processes and collect results - csv_paths = [] - completed_processes = 0 - while completed_processes < self.num_workers: - for i, p in processes: - if not p.is_alive() and p.exitcode != 0: - if p.exitcode == -11 or p.exitcode == 1: - # Restart the process if exit code is -11 or -1 - logging.warning( - f"Worker process {p.pid} exited with code {p.exitcode}. Restarting worker {i}." - ) - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - new_process = Process( - target=self.worker_task, - args=( - files_for_worker, - device, - self.batch_size, - result_queue, - physical_cores_per_worker[i], - ), - ) - new_process.start() - processes[i] = ( - i, - new_process, - ) # Replace the old process with the new one - else: - # Raise an error for other exit codes - raise RuntimeError( - f"Worker process {p.pid} failed with exit code {p.exitcode}" - ) - - # Try to get result from queue with timeout - try: - result = result_queue.get(timeout=10) - csv_paths.append(result) - completed_processes += 1 - except Exception as e: - continue - - # terminate monitor - if use_memory_monitor: - monitor_proc.terminate() - monitor_proc.join() - - # Process results and create final CSV - merged_results = [] - for csv_path in csv_paths: - try: - with open(csv_path, mode="r") as f: - reader = csv.DictReader(f) - merged_results.extend(list(reader)) - except Exception as e: - logging.error(f"Error processing {csv_path}: {str(e)}") - - except Exception as e: - # Log the error and elapsed time - end_time = time.perf_counter() - elapsed_time = end_time - start_time - logging.error( - f"Error occurred after running for {elapsed_time:.2f} seconds: {str(e)}" - ) - - # Create error log file - error_log = f"scheduler_error_{int(time.time())}.log" - with open(error_log, "w") as f: - f.write(f"Error occurred after {elapsed_time:.2f} seconds\n") - f.write(f"Error message: {str(e)}\n") - f.write(f"Number of workers: {self.num_workers}\n") - f.write(f"Batch size: {self.batch_size}\n") - - # Terminate all processes - self._terminate_processes(processes) - raise # Re-raise the exception after cleanup - - finally: - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - # Write final results if we have any - if "merged_results" in locals() and merged_results: - 
csv_file = os.path.join( - self.output_path, "results_scheduler.csv" - ) - with open(csv_file, mode="w", newline="") as file: - writer = csv.DictWriter( - file, - fieldnames=[ - "file", - "stage1_steps", - "stage1_time", - "stage1_energy", - "stage1_density", - "stage2_steps", - "stage2_time", - "stage2_energy", - "stage2_density", - "total_steps", - "total_time", - ], - ) - writer.writeheader() - for row in merged_results: - try: - processed_row = { - "file": row["file"], - "stage1_steps": int(row["stage1_steps"]), - "stage1_time": float(row["stage1_time"]), - "stage1_energy": float(row["stage1_energy"]), - "stage1_density": float(row["stage1_density"]), - "stage2_steps": int(row["stage2_steps"]), - "stage2_time": float(row["stage2_time"]), - "stage2_energy": float(row["stage2_energy"]), - "stage2_density": float(row["stage2_density"]), - "total_steps": int(row["total_steps"]), - "total_time": float(row["total_time"]), - } - writer.writerow(processed_row) - except (KeyError, ValueError) as e: - logging.error( - f"Invalid data format in row {row}: {str(e)}" - ) - - # Write summary - summary_csv_file = os.path.join( - self.output_path, "summary_scheduler.csv" - ) - with open(summary_csv_file, mode="w", newline="") as file: - writer = csv.DictWriter( - file, - fieldnames=["elapsed_time", "num_workers", "batch_size"], - ) - writer.writeheader() - writer.writerow( - { - "elapsed_time": elapsed_time, - "num_workers": self.num_workers, - "batch_size": self.batch_size, - } - ) - - logging.info(f"Scheduler completed in {elapsed_time:.2f} seconds.") - - def run_debug(self): - logging.info("Starting Scheduler in debug mode (sequential execution).") - - def worker_task(files, device, batch_size): - worker = Worker( - files, device, batch_size, self.max_steps, self.filter1 - ) - worker.run() - - for i in range(self.num_workers): - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - logging.info( - f"Running worker {i} with {len(files_for_worker)} files on device {device}." - ) - worker_task(files_for_worker, device, self.batch_size) - - logging.info("All workers have completed their tasks in debug mode.") - - -class Worker: - """ - Worker is single process that runs a batch of optimization tasks. 
- """ - - def __init__( - self, - files, - device, - batch_size, - max_steps, - filter1, - filter2, - optimizer1, - optimizer2, - skip_second_stage, - scalar_pressure, - compile_mode, - profile, - cueq, - molecule_single, - output_path, - model, - nproc, - ): - self.files = files - self.device = device - self.batch_size = batch_size - self.max_steps = max_steps - self.filter1 = filter1 - self.filter2 = filter2 - self.optimizer1 = optimizer1 - self.optimizer2 = optimizer2 - self.skip_second_stage = skip_second_stage # Store skip_second_stage - self.scalar_pressure = scalar_pressure - self.compile_mode = compile_mode - self.profile = profile - self.cueq = cueq - self.molecule_single = molecule_single - self.output_path = ( - output_path - if os.path.isabs(output_path) - else os.path.abspath(output_path) - ) - self.model = model - self.nproc = nproc - - # Parse profiler options if provided - self.use_profiler = False - self.profiler_schedule_config = { - "wait": 48, - "warmup": 1, - "active": 1, - "repeat": 1, - } - self.profiler_log_dir = None - - if self.profile and self.profile != "False": - self.use_profiler = True - # Create directory for profiler output - self.profiler_log_dir = os.path.join(self.output_path, "log") - os.makedirs(self.profiler_log_dir, exist_ok=True) - if self.profile != "True": - try: - # Try to parse profile as a JSON string with schedule config - profile_config = json.loads(self.profile) - if isinstance(profile_config, dict): - for key in ["wait", "warmup", "active", "repeat"]: - if key in profile_config and isinstance( - profile_config[key], int - ): - self.profiler_schedule_config[key] = ( - profile_config[key] - ) - except json.JSONDecodeError: - logging.warning( - f"Could not parse profile config: {self.profile}, using defaults" - ) - - # For monitor thread - self.stop_event = threading.Event() - - def run(self): - logging.info( - f"Worker started on device {self.device} with {len(self.files)} files." 
- ) - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - # model = torch.load("/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", map_location=self.device) - # z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) - calculator = mace_off(model="small", device=self.device) - - results = [] - - for batch_files in self._batch_files(self.files, self.batch_size): - logging.info(f"Processing batch with {len(batch_files)} files.") - start_time = time.perf_counter() - - atoms_list = [] - for file in batch_files: - atoms = read(file) - atoms_list.append(atoms) - gbatch = data_list_collater( - [a2g.convert(atoms) for atoms in atoms_list] - ) - - gbatch = gbatch.to(self.device) - if self.filter1 == "UnitCellFilter": - from batchopt.relaxation import OptimizableUnitCellBatch - - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=self.scalar_pressure, - ) - else: - obatch = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - - # First optimization stage - if self.optimizer1 == "LBFGS": - batch_optimizer1 = LBFGS( - obatch, damping=1.0, alpha=70.0, maxstep=0.2 - ) - elif self.optimizer1 == "BFGS": - batch_optimizer1 = BFGS(obatch, alpha=70.0, maxstep=0.2) - elif self.optimizer1 == "BFGSLineSearch": - batch_optimizer1 = BFGSLineSearch(obatch, device=self.device) - elif self.optimizer1 == "BFGSFusedLS": - batch_optimizer1 = BFGSFusedLS(obatch, device=self.device) - else: - raise ValueError(f"Unknown optimizer: {self.optimizer1}") - - start_time1 = time.perf_counter() - batch_optimizer1.run(0.01, self.max_steps) - end_time1 = time.perf_counter() - elapsed_time1 = end_time1 - start_time1 - - # Save intermediate results - atoms_list = obatch.get_atoms_list() - for atoms, file_path in zip(atoms_list, batch_files): - file_name = file_path.split("/")[-1] - output_file = os.path.join( - self.output_path, - "cif_result_press", - file_name.replace(".cif", "_press.cif"), - ) - atoms.write(output_file) - - # Capture maximum force after first optimization stage - max_force1 = obatch.get_max_forces(apply_constraint=True) - - steps1 = batch_optimizer1.nsteps - - if self.skip_second_stage: - # If skipping second stage, set metrics to zero - for file, force in zip(batch_files, max_force1): - results.append( - { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": 0.0, - "stage2_steps": 0, - "total_time": elapsed_time1, - "total_steps": steps1, - "force1": force.item(), - "force2": 0.0, - } - ) - continue - - # Only proceed with second stage if not skipping - # Reload intermediate structures for second stage - atoms_list = [] - for file_path in batch_files: - file_name = file_path.split("/")[-1] - press_file = os.path.join( - self.output_path, - "cif_result_press", - file_name.replace(".cif", "_press.cif"), - ) - atoms = read(press_file) - atoms_list.append(atoms) - - # Rebuild batch from optimized structures - gbatch = data_list_collater( - [a2g.convert(atoms) for atoms in atoms_list] - ) - gbatch = gbatch.to(self.device) - - # Second optimization stage - if self.filter2 == "UnitCellFilter": - obatch2 = OptimizableUnitCellBatch( - gbatch, trainer=calculator, numpy=False, scalar_pressure=0.0 - ) - else: - obatch2 = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - - if self.optimizer2 == "LBFGS": - batch_optimizer2 = LBFGS( - obatch2, damping=1.0, alpha=70.0, maxstep=0.2 - ) - elif self.optimizer2 == "BFGS": - batch_optimizer2 = BFGS(obatch2, alpha=70.0, maxstep=0.2) - elif 
self.optimizer2 == "BFGSLineSearch": - batch_optimizer2 = BFGSLineSearch(obatch2, device=self.device) - elif self.optimizer2 == "BFGSFusedLS": - batch_optimizer2 = BFGSFusedLS(obatch2, device=self.device) - else: - raise ValueError(f"Unknown optimizer: {self.optimizer2}") - start_time2 = time.perf_counter() - batch_optimizer2.run(0.01, self.max_steps) - end_time2 = time.perf_counter() - elapsed_time2 = end_time2 - start_time2 - - # Save final results - atoms_list = obatch2.get_atoms_list() - for atoms, file_path in zip(atoms_list, batch_files): - file_name = file_path.split("/")[-1] - output_file = os.path.join( - self.output_path, - "cif_result_final", - file_name.replace(".cif", "_opt.cif"), - ) - atoms.write(output_file) - - # Capture maximum force after second optimization stage - max_force2 = obatch2.get_max_forces(apply_constraint=True) - - steps2 = batch_optimizer2.nsteps - - for file, f1, f2 in zip(batch_files, max_force1, max_force2): - results.append( - { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": elapsed_time2, - "stage2_steps": steps2, - "total_time": elapsed_time1 + elapsed_time2, - "total_steps": steps1 + steps2, - "force1": f1.item(), - "force2": f2.item(), - } - ) - - return results - - def _batch_files(self, files, batch_size): - for i in range(0, len(files), batch_size): - yield files[i : i + batch_size] - - @staticmethod - def _torch_memory_monitor(interval=2, device=None, stop_event=None): - try: - # explicitly CUDA initialization - torch.cuda._lazy_init() - while not stop_event.is_set(): - allocated = torch.cuda.memory_allocated(device=device) - reserved = torch.cuda.memory_reserved(device=device) - logging.info( - f"[torch] Allocated Memory: {allocated / 1024**2:.2f} MiB" - ) - logging.info( - f"[torch] Reserved Memory: {reserved / 1024**2:.2f} MiB" - ) - time.sleep(interval) - except Exception as e: - logging.error(f"Unexpected error when monitor memory: {str(e)}") - - def continuous_run(self): - """ - Execute a continuous run of the batching optimization process. - """ - logging.info("Starting continuous_run with two rounds of optimization.") - - # torch memory monitor api - use_torch_memory_monitor = False - if use_torch_memory_monitor: - memory_monitor = threading.Thread( - target=Worker._torch_memory_monitor, - args=(2, self.device, self.stop_event), - ) - memory_monitor.start() - - # First round of optimization - try: - logging.info("Starting first round of optimization.") - results_round1, new_atoms_files = self.continuous_batching( - atoms_path=self.files, - result_path_prefix=os.path.join( - self.output_path, "cif_result_press/" - ), - fmax=0.01, - maxstep=self.max_steps, - use_filter=self.filter1, - optimizer=self.optimizer1, - scalar_pressure=self.scalar_pressure, - dtype=torch.float64, - ) - logging.info( - f"Completed first round of optimization. 
Results: {len(results_round1)}" - ) - except KeyboardInterrupt as e: - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - logging.error(f"Error during first round of optimization: {e}") - raise - except Exception as e: - logging.error(f"Error during first round of optimization: {e}") - raise - - if self.skip_second_stage: - logging.info("Skipping second round of optimization.") - return results_round1 - - # Second round of optimization without pressure - try: - logging.info("Starting second round of optimization.") - results_round2, _ = self.continuous_batching( - atoms_path=new_atoms_files, - result_path_prefix=os.path.join( - self.output_path, "cif_result_final/" - ), - fmax=0.01, - maxstep=self.max_steps, - # maxstep=3000, - use_filter=self.filter2, - optimizer=self.optimizer2, - scalar_pressure=0.0, - dtype=torch.float64, - ) - logging.info( - f"Completed second round of optimization. Results: {len(results_round2)}" - ) - except KeyboardInterrupt as e: - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - logging.error(f"Error during second round of optimization: {e}") - raise - except Exception as e: - logging.error(f"Error during second round of optimization: {e}") - raise - - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - - return self._save_results_to_csv(results_round1, results_round2) - - def _save_results_to_csv(self, results_round1, results_round2): - """Helper method to save results to CSV file and return the path.""" - combined_results = [] - results_map = {} - - # Process first round results - for result in results_round1: - file_name = result["file"] - results_map[file_name] = { - "file": file_name, - "stage1_steps": result["steps"], - "stage1_time": result["runtime"], - "stage1_energy": result["energy"], - "stage1_density": result["density"], - "stage2_steps": 0, - "stage2_time": 0.0, - "stage2_energy": 0.0, - "stage2_density": 0, - "total_steps": result["steps"], - "total_time": result["runtime"], - } - - # Process second round results - for result in results_round2: - file_name = result["file"] - if file_name in results_map: - results_map[file_name].update( - { - "stage2_steps": result["steps"], - "stage2_time": result["runtime"], - "stage2_energy": result["energy"], - "stage2_density": result["density"], - "total_steps": results_map[file_name]["stage1_steps"] - + result["steps"], - "total_time": results_map[file_name]["stage1_time"] - + result["runtime"], - } - ) - else: - results_map[file_name] = { - "file": file_name, - "stage1_steps": 0, - "stage1_time": 0.0, - "stage1_energy": 0.0, - "stage1_density": 0, - "stage2_steps": result["steps"], - "stage2_time": result["runtime"], - "stage2_energy": result["energy"], - "stage2_density": result["density"], - "total_steps": result["steps"], - "total_time": result["runtime"], - } - - # Convert map to list - combined_results = list(results_map.values()) - - logging.info( - f"Combined results from both rounds. 
Total results: {len(combined_results)}" - ) - - worker_id = os.getpid() - timestamp = int(time.time()) - csv_filename = f"worker_{worker_id}_{timestamp}.csv" - csv_path = os.path.join( - self.output_path, "worker_results", csv_filename - ) - os.makedirs(os.path.dirname(csv_path), exist_ok=True) - - with open(csv_path, mode="w", newline="") as csvfile: - fieldnames = [ - "file", - "stage1_steps", - "stage1_time", - "stage1_energy", - "stage1_density", - "stage2_steps", - "stage2_time", - "stage2_energy", - "stage2_density", - "total_steps", - "total_time", - ] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for result in combined_results: - writer.writerow(result) - - return csv_path - - def _get_density(self, crystal): - # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 - total_mass = sum(crystal.get_masses()) # 转换为克 - - # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 - # 1 Å^3 = 1e-24 cm^3 - volume = crystal.get_volume() # 转换为立方厘米 - - # 计算密度,质量除以体积 - density = ( - total_mass / (volume * 10**-24) / (6.022140857 * 10**23) - ) # 单位是 g/cm^3 - return density - - @staticmethod - def select_factor(history: deque): - # TODO: when history is mix of different size, the smaller `values` should be selected. - boundaries = [0, 50, 100, 200, 400, 800] - values = [0.4, 0.8, 0.9, 0.6, 0.5, 0.4] - factor_result = [] - for graph_size in history: - for i in range(len(boundaries) - 1): - if boundaries[i] <= graph_size < boundaries[i + 1]: - factor_result.append(values[i]) - break - if len(factor_result) == 0: - return 0.4 - else: - return min(factor_result) - - def continuous_batching( - self, - atoms_path, - result_path_prefix, - fmax, - maxstep, - use_filter, - optimizer, - scalar_pressure, - dtype=torch.float64, - ): - """ - Performs continuous batched optimization of atomic structures. - - This method implements a continuous batching strategy for optimizing multiple atomic structures, - where converged structures are replaced with new ones to maintain batch efficiency. - - Parameters - ---------- - atoms_path : list - List of file paths to atomic structure files to be optimized - result_path_prefix : str - Prefix for output file paths where optimized structures will be saved - fmax : float, optional - Maximum force criterion for convergence, by default 0.01 - maxstep : int, optional - Maximum number of optimization steps per batch, by default 3000 - use_filter : str, optional - Filter to be used for optimization, by default "UnitCellFilter" - optimizer : str, optional - Optimizer to be used for optimization, by default "LBFGS" - scalar_pressure : float, optional - Scalar pressure to be applied, by default 0.0 - - Returns - ------- - None - The optimized structures are saved to disk - - Notes - ----- - The method: - - Processes structures in batches of predefined size - - Uses MACE neural network potential for energy/force calculations - - Employs LBFGS optimization with unit cell relaxation - - Dynamically replaces converged structures with new ones in the batch - - Tracks convergence and optimization steps for each structure - """ - # Load saved structures - result = [] - optimized_atoms_paths = [] - - json_dir = result_path_prefix.replace("cif", "json") - - remove_list = [] - # TODO: Why we read all CIF here? 
- for pre_cif in atoms_path: - cif_path = os.path.join(result_path_prefix, pre_cif.split("/")[-1]) - json_path = os.path.join( - json_dir, pre_cif.split("/")[-1].replace(".cif", ".json") - ) - if ( - os.path.exists(cif_path) - and os.path.exists(json_path) - and os.path.getsize(cif_path) > 0 - and os.path.getsize(json_path) > 0 - ): - with open(json_path, "r") as f: - result_data = json.load(f) - result.append(result_data) - optimized_atoms_paths.append(cif_path) - remove_list.append(pre_cif) - logging.info(f"File {cif_path} already exists, loaded.") - # else: - # try: - # read(pre_cif) - # except Exception as e: - # logging.info(f"Failed to read {pre_cif}: {e}") - # remove_list.append(pre_cif) - for i in remove_list: - atoms_path.remove(i) - - if self.batch_size > 0: - # Initialize variables - room_in_batch = self.batch_size - indices_to_process = 0 - cur_batch_path = atoms_path[ - indices_to_process : indices_to_process + room_in_batch - ] - if len(cur_batch_path) == 0: - logging.info("No structures to process.") - return result, optimized_atoms_paths - room_in_batch -= len(cur_batch_path) - indices_to_process += len(cur_batch_path) - cur_atoms_list = [read(path) for path in cur_batch_path] - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - gbatch = data_list_collater( - [a2g.convert(read(path)) for path in cur_batch_path] - ) - else: - # Set Maximum Number of atoms per batch - history = deque(maxlen=10) - history.append(1000) - max_bnatoms = 24080 - safe_factor = self.select_factor(history) - - indices_to_process = 0 - bnatoms = 0 - cur_batch_path = [] - graphs_list = [] - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - - while indices_to_process < len(atoms_path): - graph_natoms = count_atoms_cif(atoms_path[indices_to_process]) - if ( - bnatoms + graph_natoms - > max_bnatoms * safe_factor // self.nproc - ): - break - graph = a2g.convert(read(atoms_path[indices_to_process])) - bnatoms += graph_natoms - cur_batch_path.append(atoms_path[indices_to_process]) - graphs_list.append(graph) - indices_to_process += 1 - history.append(graph_natoms) - safe_factor = self.select_factor(history) - if len(graphs_list) == 0: - logging.info("No structures to process.") - return result, optimized_atoms_paths - gbatch = data_list_collater(graphs_list) - logging.info(f"current batch size: {len(cur_batch_path)}") - - total_natoms = sum([graph.natoms for graph in graphs_list]) - logging.info(f"total_natoms: {total_natoms}") - - gbatch = gbatch.to(self.device) - batch_optimizer = None - - # Initial calculator - if self.model == "mace": - if dtype == torch.float32: - calculator = mace_off( - model="small", - device=self.device, - enable_cueq=self.cueq, - default_dtype="float32", - ) - else: - calculator = mace_off( - model="small", device=self.device, enable_cueq=self.cueq - ) - elif self.model == "chgnet": - calculator = CHGNetCalculator( - use_device=self.device, enable_cueq=self.cueq - ) - elif self.model == "sevennet": - # calculator = SevenNetCalculator(device=self.device, enable_cueq=self.cueq) - calculator = SevenNetD3Calculator( - device=self.device, - enable_cueq=self.cueq, - batch_size=self.batch_size, - ) - # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) - # calculator = MACECalculator(model_paths="/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", device=self.device, compile_mode=self.compile_mode) - if use_filter == "UnitCellFilter": - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=scalar_pressure, - ) - 
else: - obatch = OptimizableBatch(gbatch, trainer=calculator, numpy=False) - - orig_cells = obatch.orig_cells.clone() - - converged_atoms_count = 0 - converge_indices = [] - all_indices = [] - cur_batch_steps = [0] * len(cur_batch_path) - cur_batch_times = [time.perf_counter()] * len( - cur_batch_path - ) # Track start times - - while converged_atoms_count < len(atoms_path): - # Update batch - if len(all_indices) > 0: - if self.batch_size > 0: - room_in_batch += len(all_indices) - new_batch_path = atoms_path[ - indices_to_process : indices_to_process + room_in_batch - ] - logging.info(f"new_batch_path: {new_batch_path}") - room_in_batch -= len(new_batch_path) - indices_to_process += len(new_batch_path) - - optimized_atoms_new = [] - cur_batch_path_new = [] - cur_batch_steps_new = [] - cur_batch_times_new = [] - orig_cells_new = torch.zeros( - [self.batch_size - room_in_batch, 3, 3], - device=self.device, - ) - cell_offset = 0 - - restart_indices = [] - old_batch_indices = obatch.batch_indices - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - else: - restart_indices.append(i) - optimized_atoms_new.append(optimized_atoms[i]) - cur_batch_path_new.append(cur_batch_path[i]) - cur_batch_steps_new.append(cur_batch_steps[i]) - cur_batch_times_new.append(cur_batch_times[i]) - - orig_cells_new[cell_offset] = orig_cells[i] - cell_offset += 1 - - for new_path in new_batch_path: - optimized_atoms_new.append(read(new_path)) - cur_batch_path_new.append(new_path) - cur_batch_steps_new.append(0) - cur_batch_times_new.append(time.perf_counter()) - - # Update the batch with new structures - optimized_atoms = optimized_atoms_new - cur_batch_path = cur_batch_path_new - cur_batch_steps = cur_batch_steps_new - cur_batch_times = cur_batch_times_new - else: - bnatoms = 0 - optimized_atoms_new = [] - cur_batch_path_new = [] - cur_batch_steps_new = [] - cur_batch_times_new = [] - - restart_indices = [] - old_batch_indices = obatch.batch_indices - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - restart_indices.append(i) - optimized_atoms_new.append(optimized_atoms[i]) - cur_batch_path_new.append(cur_batch_path[i]) - cur_batch_steps_new.append(cur_batch_steps[i]) - cur_batch_times_new.append(cur_batch_times[i]) - bnatoms += a2g.convert(read(cur_batch_path[i])).natoms - - while indices_to_process < len(atoms_path): - new_path = atoms_path[indices_to_process] - graph_natoms = count_atoms_cif(new_path) - if ( - bnatoms + graph_natoms - > max_bnatoms * safe_factor // self.nproc - ): - break - bnatoms += graph_natoms - optimized_atoms_new.append(read(new_path)) - cur_batch_path_new.append(new_path) - cur_batch_steps_new.append(0) - cur_batch_times_new.append(time.perf_counter()) - indices_to_process += 1 - history.append(graph_natoms) - safe_factor = self.select_factor(history) - - orig_cells_new = torch.zeros( - [len(optimized_atoms_new), 3, 3], device=self.device - ) - cell_offset = 0 - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - orig_cells_new[cell_offset] = orig_cells[i] - cell_offset += 1 - - # Update the batch with new structures - optimized_atoms = optimized_atoms_new - cur_batch_path = cur_batch_path_new - cur_batch_steps = cur_batch_steps_new - cur_batch_times = cur_batch_times_new - - logging.info(f"current batch size: {len(optimized_atoms)}") - - graphs_list = [a2g.convert(atoms) for atoms in optimized_atoms] - total_natoms = sum([graph.natoms for graph in graphs_list]) - logging.info(f"total_natoms: {total_natoms}") - 
logging.info(f"cur_batch_path to processing: {cur_batch_path}") - gbatch = data_list_collater(graphs_list) - gbatch = gbatch.to(self.device) - if self.model == "sevennet": - # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) - calculator = SevenNetD3Calculator( - device=self.device, - enable_cueq=self.cueq, - batch_size=self.batch_size, - ) - if use_filter == "UnitCellFilter": - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=scalar_pressure, - ) - else: - obatch = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - for i in range(cell_offset): - obatch.orig_cells[i] = orig_cells_new[i] - orig_cells = obatch.orig_cells.clone() - - # Optimize the current batch - if optimizer == "LBFGS": - batch_optimizer = LBFGS( - obatch, - damping=1.0, - alpha=70.0, - maxstep=0.2, - early_stop=True, - ) - elif optimizer == "BFGS": - if len(all_indices) > 0: - logging.info(f"Restarting with indices: {restart_indices}") - batch_optimizer.optimizable = obatch - else: - batch_optimizer = BFGS( - obatch, alpha=70.0, maxstep=0.2, early_stop=True - ) - elif optimizer == "BFGSLineSearch": - batch_optimizer = BFGSLineSearch( - obatch, - device=self.device, - early_stop=True, - use_profiler=self.use_profiler, - profiler_log_dir=self.profiler_log_dir, - profiler_schedule_config=self.profiler_schedule_config, - ) - elif optimizer == "BFGSFusedLS": - if len(all_indices) > 0: - logging.info(f"Restarting with indices: {restart_indices}") - batch_optimizer.optimizable = obatch - else: - batch_optimizer = BFGSFusedLS( - obatch, - device=self.device, - early_stop=True, - use_profiler=self.use_profiler, - profiler_log_dir=self.profiler_log_dir, - profiler_schedule_config=self.profiler_schedule_config, - ) - else: - raise ValueError(f"Unknown optimizer: {optimizer}") - - # 动态计算剩余可用步数(基于当前批次最大已执行步数) - current_max_steps = max(cur_batch_steps) if cur_batch_steps else 0 - remaining_steps = max( - maxstep - current_max_steps, 1 - ) # 保证至少运行1步 - - # 执行优化并获取收敛的索引 - if (optimizer == "BFGSFusedLS" or optimizer == "BFGS") and len( - all_indices - ) > 0: - converge_indices = batch_optimizer.run( - fmax, - remaining_steps, - is_restart_earlystop=True, - restart_indices=restart_indices, - old_batch_indices=old_batch_indices, - ) - else: - converge_indices = batch_optimizer.run(fmax, remaining_steps) - - # Print energies of all structures - # logging.info(f"Final energies of all structures: {batch_optimizer.energies}") - energies_list = ( - batch_optimizer.optimizable.get_potential_energies().tolist() - ) - logging.info(f"Final energies of all structures: {energies_list}") - - # 更新所有结构的累计步数 - cur_batch_steps = [ - steps + batch_optimizer.nsteps for steps in cur_batch_steps - ] - - # 找出超过最大步数的结构索引 - over_maxstep_indices = [ - i - for i, steps in enumerate(cur_batch_steps) - if steps >= maxstep - 1 - ] - - # 合并收敛和超限的索引(去重) - all_indices = list(set(converge_indices + over_maxstep_indices)) - - # Get optimized atoms - optimized_atoms = obatch.get_atoms_list() - converged_atoms_count += len(all_indices) - - end_time = time.perf_counter() - # 处理所有需要退出的结构(包括收敛和超限) - for idx in all_indices: - runtime = end_time - cur_batch_times[idx] - - energy_per_mol = ( - energies_list[idx] - / ( - len(optimized_atoms[idx].get_atomic_numbers()) - / self.molecule_single - ) - * 96.485 - ) - density = self._get_density(optimized_atoms[idx]) - - # Save results - result_data = { - "file": cur_batch_path[idx].split("/")[-1].split(".")[0], - "steps": 
cur_batch_steps[idx], - "runtime": runtime, - "energy": energy_per_mol, - "density": density, - } - result.append(result_data) - - # Save optimized structure - # converged_atoms_path = os.path.join(result_path_prefix, cur_batch_path[idx].split('/')[-1].replace('.cif', '.traj')) - converged_atoms_path = os.path.join( - result_path_prefix, cur_batch_path[idx].split("/")[-1] - ) - optimized_atoms[idx].write(converged_atoms_path) - optimized_atoms_paths.append(converged_atoms_path) - - # write a json file to store reslt_data - os.makedirs(json_dir, exist_ok=True) - # json_path = os.path.join(json_dir, cur_batch_path[idx].split('/')[-1]+'.json') - json_path = os.path.join( - json_dir, - cur_batch_path[idx].split("/")[-1].replace(".cif", ".json"), - ) - with open(json_path, "w") as f: - json.dump(result_data, f) - - logging.info(f"cur_batch_path: {cur_batch_path}") - logging.info(f"cur_batch_steps: {cur_batch_steps}") - logging.info(f"all_indices: {all_indices}") - logging.info(f"length of optimized_atoms: {len(optimized_atoms)}") - - return result, optimized_atoms_paths +""" +Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from ase.io import read + +# from ase.optimize import ASE_LBFGS +import torch +from torch.multiprocessing import Process, set_start_method +from batchopt.atoms_to_graphs import AtomsToGraphs +from batchopt.utils import data_list_collater +from batchopt.relaxation.optimizers import ( + BFGS, + BFGSFusedLS, +) +from batchopt.relaxation import OptimizableBatch, OptimizableUnitCellBatch +import logging +import time +import csv +from multiprocessing import Queue +import os +import psutil +import multiprocessing +import json +import subprocess + +try: + from chgnet.model.dynamics import CHGNetCalculator +except ImportError: + logging.warning("Failed to import CHGNet modules") + +try: + from sevenn.calculator import SevenNetCalculator, SevenNetD3Calculator +except ImportError: + logging.warning("Failed to import SevenNet modules") + +try: + from fairchem.core import pretrained_mlip, FAIRChemCalculator +except ImportError: + logging.warning("Failed to import FAIRChem modules") + +try: + from mace.calculators import mace_off +except ImportError: + logging.warning("Failed to import MACE modules") + +import threading +from .utils import count_atoms_cif +from collections import deque + + +class Scheduler: + """ + Scheduler distributes relaxation tasks to workers. 
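+
+    Files are sharded round-robin across `num_workers` worker processes;
+    each worker is pinned to one of `devices`, and workers that die with
+    exit code -11 or 1 (e.g. a segfault) are restarted on the same shard
+    (see `run`).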
+ """ + + def __init__( + self, + files, + num_workers, + devices, + batch_size, + max_steps, + filter1, + filter2, + optimizer1, + optimizer2, + skip_second_stage, + scalar_pressure, + compile_mode, + profile, + num_threads, + bind_cores, + cueq, + molecule_single, + output_path, + model, + ): + + self.files = files + self.num_workers = num_workers + self.devices = devices + self.batch_size = batch_size + self.max_steps = max_steps + self.filter1 = filter1 + self.filter2 = filter2 + self.optimizer1 = optimizer1 + self.optimizer2 = optimizer2 + self.skip_second_stage = skip_second_stage + self.scalar_pressure = scalar_pressure + self.compile_mode = compile_mode + self.profile = profile + self.num_threads = num_threads + self.cueq = cueq + self.molecule_single = molecule_single + self.output_path = ( + output_path + if os.path.isabs(output_path) + else os.path.abspath(output_path) + ) + self.model = model + + try: + set_start_method("spawn") + except RuntimeError: + logging.warning( + "set_start_method('spawn') failed, trying 'forkserver' instead." + ) + + if bind_cores is not None: + self.cpu_mask = self._parse_bind_cores(bind_cores) + else: + self.cpu_mask = None + + def _parse_bind_cores(self, bind_cores): + # Expect custom_bind_str to be like "0-15,16-31,..." + ranges = bind_cores.split(",") + if len(ranges) != self.num_workers: + return None + binding = [] + for r in ranges: + try: + start_str, end_str = r.split("-") + start = int(start_str) + end = int(end_str) + except ValueError: + logging.error("Custom binding format should be 'start-end'.") + return None + + binding.append(set(range(start, end + 1))) + return binding + + def _get_physical_logical_core_mapping(self): + """Get the mapping between logical cores and their physical core IDs.""" + try: + # This information is available in Linux systems + mapping = {} + logical_cores = psutil.cpu_count(logical=True) + + for i in range(logical_cores): + try: + # Read core_id from /sys/devices/system/cpu/cpu{i}/topology/core_id + with open( + f"/sys/devices/system/cpu/cpu{i}/topology/core_id" + ) as f: + core_id = int(f.read().strip()) + # Read physical_package_id (socket) for more complete information + with open( + f"/sys/devices/system/cpu/cpu{i}/topology/physical_package_id" + ) as f: + package_id = int(f.read().strip()) + mapping[i] = (package_id, core_id) + except (FileNotFoundError, ValueError, IOError): + mapping[i] = None + return mapping + except Exception as e: + logging.error(f"Failed to get core mapping: {e}") + return {} + + def _get_physical_core_mask(self): + # Get the number of physical and logical cores + physical_cores = psutil.cpu_count(logical=False) + logical_cores = psutil.cpu_count(logical=True) + + if physical_cores is None or physical_cores < 1: + # Fallback to multiprocessing if psutil fails + logical_cores = multiprocessing.cpu_count() + physical_cores = logical_cores // 2 + if physical_cores < 1: + physical_cores = 1 + print(f"Using estimated physical cores: {physical_cores}") + + # Get the mapping between logical and physical cores + core_mapping = self._get_physical_logical_core_mapping() + + # Create a CPU mask that includes all physical cores (first core of each physical core) + physical_core_mask = set() + if core_mapping: + # Group by physical core ID + cores_by_physical = {} + for logical_id, physical_info in core_mapping.items(): + if physical_info is not None: + package_id, core_id = physical_info + key = (package_id, core_id) + if key not in cores_by_physical: + cores_by_physical[key] = [] + 
cores_by_physical[key].append(logical_id)
+
+            # Select one logical core from each physical core
+            for physical_cores_list in cores_by_physical.values():
+                physical_core_mask.add(
+                    physical_cores_list[0]
+                )  # First logical core of each physical core
+        else:
+            # If mapping fails, use a simple assumption (may not be accurate on all systems)
+            threads_per_core = logical_cores // physical_cores
+            physical_core_mask = set(range(0, logical_cores, threads_per_core))
+
+        return physical_core_mask
+
+    def worker_task(
+        self, files, device, batch_size, result_queue, physical_cores
+    ):
+        if physical_cores is not None:
+            try:
+                # Bind the current process to physical cores
+                pid = os.getpid()
+                os.sched_setaffinity(pid, physical_cores)
+                logging.info(f"bind to physical_core_ids: {physical_cores}")
+
+                # Verify the affinity was set correctly
+                current_affinity = os.sched_getaffinity(pid)
+                logging.info(
+                    f"Process bound to {len(current_affinity)} cores: {sorted(current_affinity)}"
+                )
+
+            except AttributeError:
+                logging.error(
+                    "sched_setaffinity not supported on this platform"
+                )
+            except Exception as e:
+                logging.error(f"Failed to bind to physical cores: {e}")
+
+        # number of worker processes sharing each device
+        nproc = self.num_workers // len(self.devices)
+
+        worker = Worker(
+            files,
+            device,
+            batch_size,
+            self.max_steps,
+            self.filter1,
+            self.filter2,
+            self.optimizer1,
+            self.optimizer2,
+            self.skip_second_stage,
+            self.scalar_pressure,
+            self.compile_mode,
+            self.profile,
+            self.cueq,
+            self.molecule_single,
+            self.output_path,
+            self.model,
+            nproc,
+        )
+        # results = worker.run()
+        results = worker.continuous_run()
+        result_queue.put(results)
+
+    def _terminate_processes(self, processes):
+        """Helper method to terminate all processes."""
+        for i, p in processes:
+            if p.is_alive():
+                logging.info(f"Terminating process {p.pid}")
+                p.terminate()
+                p.join(timeout=3)  # Wait for up to 3 seconds
+                if p.is_alive():
+                    logging.warning(
+                        f"Process {p.pid} did not terminate, killing it"
+                    )
+                    p.kill()
+                    p.join()
+
+    # Periodically poll "nvidia-smi" to log GPU memory usage (run in a separate monitor process)
+    @staticmethod
+    def _monitor_memory(interval=2, gpu_index=1):
+        try:
+            while True:
+                result = subprocess.check_output(
+                    [
+                        "nvidia-smi",
+                        "--query-gpu=memory.used,memory.total",
+                        "--format=csv,nounits,noheader",
+                    ]
+                ).decode("utf-8")
+
+                lines = result.strip().split("\n")
+                used, total = map(int, lines[gpu_index].split(","))
+                logging.info(
+                    f"[nvidia-smi] Memory-Usage on GPU {gpu_index}: {used}MiB / {total}MiB"
+                )
+
+                time.sleep(interval)
+        except KeyboardInterrupt:
+            logging.info("Monitor interrupted.")
+
+        except Exception as e:
+            logging.error(f"Unexpected error while monitoring memory: {str(e)}")
+
+    def run(self):
+        logging.info(f"Starting Scheduler with {self.num_workers} workers.")
+        processes = []
+        result_queue = Queue()
+        start_time = time.perf_counter()
+
+        if self.cpu_mask is not None:
+            physical_cores_per_worker = self.cpu_mask
+            logging.info(
+                f"Use custom core binding. 
Physical cores per worker: {physical_cores_per_worker}" + ) + else: + # all_physical_cores = self._get_physical_core_mask() + # num_per_worker = len(all_physical_cores) // self.num_workers + # physical_cores_per_worker = [ + # list(all_physical_cores)[i:i + num_per_worker] for i in range(0, len(all_physical_cores), num_per_worker) + # ] + # logging.info(f"Physical cores per worker: {physical_cores_per_worker}") + physical_cores_per_worker = [None] * self.num_workers + + try: + # Start all worker processes + for i in range(self.num_workers): + files_for_worker = self.files[i :: self.num_workers] + device = self.devices[i % len(self.devices)] + logging.info( + f"Starting worker {i} with {len(files_for_worker)} files on device {device}." + ) + p = Process( + target=self.worker_task, + args=( + files_for_worker, + device, + self.batch_size, + result_queue, + physical_cores_per_worker[i], + ), + ) + p.start() + processes.append((i, p)) + + # monitor gpu memory usage to figure out what makes the differences of footprint among batches + # in each iteration. + use_memory_monitor = False + if use_memory_monitor: + monitor_proc = Process( + target=Scheduler._monitor_memory, args=() + ) + monitor_proc.start() + + # Monitor processes and collect results + csv_paths = [] + completed_processes = 0 + while completed_processes < self.num_workers: + for i, p in processes: + if not p.is_alive() and p.exitcode != 0: + if p.exitcode == -11 or p.exitcode == 1: + # Restart the process if exit code is -11 or -1 + logging.warning( + f"Worker process {p.pid} exited with code {p.exitcode}. Restarting worker {i}." + ) + files_for_worker = self.files[i :: self.num_workers] + device = self.devices[i % len(self.devices)] + new_process = Process( + target=self.worker_task, + args=( + files_for_worker, + device, + self.batch_size, + result_queue, + physical_cores_per_worker[i], + ), + ) + new_process.start() + processes[i] = ( + i, + new_process, + ) # Replace the old process with the new one + else: + # Raise an error for other exit codes + raise RuntimeError( + f"Worker process {p.pid} failed with exit code {p.exitcode}" + ) + + # Try to get result from queue with timeout + try: + result = result_queue.get(timeout=10) + csv_paths.append(result) + completed_processes += 1 + except Exception as e: + continue + + # terminate monitor + if use_memory_monitor: + monitor_proc.terminate() + monitor_proc.join() + + # Process results and create final CSV + merged_results = [] + for csv_path in csv_paths: + try: + with open(csv_path, mode="r") as f: + reader = csv.DictReader(f) + merged_results.extend(list(reader)) + except Exception as e: + logging.error(f"Error processing {csv_path}: {str(e)}") + + except Exception as e: + # Log the error and elapsed time + end_time = time.perf_counter() + elapsed_time = end_time - start_time + logging.error( + f"Error occurred after running for {elapsed_time:.2f} seconds: {str(e)}" + ) + + # Create error log file + error_log = f"scheduler_error_{int(time.time())}.log" + with open(error_log, "w") as f: + f.write(f"Error occurred after {elapsed_time:.2f} seconds\n") + f.write(f"Error message: {str(e)}\n") + f.write(f"Number of workers: {self.num_workers}\n") + f.write(f"Batch size: {self.batch_size}\n") + + # Terminate all processes + self._terminate_processes(processes) + raise # Re-raise the exception after cleanup + + finally: + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + # Write final results if we have any + if "merged_results" in locals() and merged_results: + 
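+                # Merge the per-worker CSV files into one typed results file.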
csv_file = os.path.join(
+                    self.output_path, "results_scheduler.csv"
+                )
+                with open(csv_file, mode="w", newline="") as file:
+                    writer = csv.DictWriter(
+                        file,
+                        fieldnames=[
+                            "file",
+                            "stage1_steps",
+                            "stage1_time",
+                            "stage1_energy",
+                            "stage1_density",
+                            "stage2_steps",
+                            "stage2_time",
+                            "stage2_energy",
+                            "stage2_density",
+                            "total_steps",
+                            "total_time",
+                        ],
+                    )
+                    writer.writeheader()
+                    for row in merged_results:
+                        try:
+                            processed_row = {
+                                "file": row["file"],
+                                "stage1_steps": int(row["stage1_steps"]),
+                                "stage1_time": float(row["stage1_time"]),
+                                "stage1_energy": float(row["stage1_energy"]),
+                                "stage1_density": float(row["stage1_density"]),
+                                "stage2_steps": int(row["stage2_steps"]),
+                                "stage2_time": float(row["stage2_time"]),
+                                "stage2_energy": float(row["stage2_energy"]),
+                                "stage2_density": float(row["stage2_density"]),
+                                "total_steps": int(row["total_steps"]),
+                                "total_time": float(row["total_time"]),
+                            }
+                            writer.writerow(processed_row)
+                        except (KeyError, ValueError) as e:
+                            logging.error(
+                                f"Invalid data format in row {row}: {str(e)}"
+                            )
+
+            # Write summary
+            summary_csv_file = os.path.join(
+                self.output_path, "summary_scheduler.csv"
+            )
+            with open(summary_csv_file, mode="w", newline="") as file:
+                writer = csv.DictWriter(
+                    file,
+                    fieldnames=["elapsed_time", "num_workers", "batch_size"],
+                )
+                writer.writeheader()
+                writer.writerow(
+                    {
+                        "elapsed_time": elapsed_time,
+                        "num_workers": self.num_workers,
+                        "batch_size": self.batch_size,
+                    }
+                )
+
+            logging.info(f"Scheduler completed in {elapsed_time:.2f} seconds.")
+
+    def run_debug(self):
+        logging.info("Starting Scheduler in debug mode (sequential execution).")
+
+        def worker_task(files, device, batch_size):
+            # Mirror Scheduler.worker_task so Worker receives its full argument list.
+            nproc = self.num_workers // len(self.devices)
+            worker = Worker(
+                files,
+                device,
+                batch_size,
+                self.max_steps,
+                self.filter1,
+                self.filter2,
+                self.optimizer1,
+                self.optimizer2,
+                self.skip_second_stage,
+                self.scalar_pressure,
+                self.compile_mode,
+                self.profile,
+                self.cueq,
+                self.molecule_single,
+                self.output_path,
+                self.model,
+                nproc,
+            )
+            worker.run()
+
+        for i in range(self.num_workers):
+            files_for_worker = self.files[i :: self.num_workers]
+            device = self.devices[i % len(self.devices)]
+            logging.info(
+                f"Running worker {i} with {len(files_for_worker)} files on device {device}."
+            )
+            worker_task(files_for_worker, device, self.batch_size)
+
+        logging.info("All workers have completed their tasks in debug mode.")
+
+
+class Worker:
+    """
+    Worker is a single process that runs batches of optimization tasks.
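+    It converts structures into graph batches and relaxes them with the selected MLIP calculator and batched optimizer.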
+ """ + + def __init__( + self, + files, + device, + batch_size, + max_steps, + filter1, + filter2, + optimizer1, + optimizer2, + skip_second_stage, + scalar_pressure, + compile_mode, + profile, + cueq, + molecule_single, + output_path, + model, + nproc, + ): + self.files = files + self.device = device + self.batch_size = batch_size + self.max_steps = max_steps + self.filter1 = filter1 + self.filter2 = filter2 + self.optimizer1 = optimizer1 + self.optimizer2 = optimizer2 + self.skip_second_stage = skip_second_stage # Store skip_second_stage + self.scalar_pressure = scalar_pressure + self.compile_mode = compile_mode + self.profile = profile + self.cueq = cueq + self.molecule_single = molecule_single + self.output_path = ( + output_path + if os.path.isabs(output_path) + else os.path.abspath(output_path) + ) + self.model = model + self.nproc = nproc + + # Parse profiler options if provided + self.use_profiler = False + self.profiler_schedule_config = { + "wait": 48, + "warmup": 1, + "active": 1, + "repeat": 1, + } + self.profiler_log_dir = None + + if self.profile and self.profile != "False": + self.use_profiler = True + # Create directory for profiler output + self.profiler_log_dir = os.path.join(self.output_path, "log") + os.makedirs(self.profiler_log_dir, exist_ok=True) + if self.profile != "True": + try: + # Try to parse profile as a JSON string with schedule config + profile_config = json.loads(self.profile) + if isinstance(profile_config, dict): + for key in ["wait", "warmup", "active", "repeat"]: + if key in profile_config and isinstance( + profile_config[key], int + ): + self.profiler_schedule_config[key] = ( + profile_config[key] + ) + except json.JSONDecodeError: + logging.warning( + f"Could not parse profile config: {self.profile}, using defaults" + ) + + # For monitor thread + self.stop_event = threading.Event() + + def run(self): + logging.info( + f"Worker started on device {self.device} with {len(self.files)} files." 
+ ) + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + # model = torch.load("/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", map_location=self.device) + # z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + calculator = mace_off(model="small", device=self.device) + + results = [] + + for batch_files in self._batch_files(self.files, self.batch_size): + logging.info(f"Processing batch with {len(batch_files)} files.") + start_time = time.perf_counter() + + atoms_list = [] + for file in batch_files: + atoms = read(file) + atoms_list.append(atoms) + gbatch = data_list_collater( + [a2g.convert(atoms) for atoms in atoms_list] + ) + + gbatch = gbatch.to(self.device) + if self.filter1 == "UnitCellFilter": + from batchopt.relaxation import OptimizableUnitCellBatch + + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=self.scalar_pressure, + ) + else: + obatch = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + + # First optimization stage + if self.optimizer1 == "LBFGS": + batch_optimizer1 = LBFGS( + obatch, damping=1.0, alpha=70.0, maxstep=0.2 + ) + elif self.optimizer1 == "BFGS": + batch_optimizer1 = BFGS(obatch, alpha=70.0, maxstep=0.2) + elif self.optimizer1 == "BFGSLineSearch": + batch_optimizer1 = BFGSLineSearch(obatch, device=self.device) + elif self.optimizer1 == "BFGSFusedLS": + batch_optimizer1 = BFGSFusedLS(obatch, device=self.device) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer1}") + + start_time1 = time.perf_counter() + batch_optimizer1.run(0.01, self.max_steps) + end_time1 = time.perf_counter() + elapsed_time1 = end_time1 - start_time1 + + # Save intermediate results + atoms_list = obatch.get_atoms_list() + for atoms, file_path in zip(atoms_list, batch_files): + file_name = file_path.split("/")[-1] + output_file = os.path.join( + self.output_path, + "cif_result_press", + file_name.replace(".cif", "_press.cif"), + ) + atoms.write(output_file) + + # Capture maximum force after first optimization stage + max_force1 = obatch.get_max_forces(apply_constraint=True) + + steps1 = batch_optimizer1.nsteps + + if self.skip_second_stage: + # If skipping second stage, set metrics to zero + for file, force in zip(batch_files, max_force1): + results.append( + { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": 0.0, + "stage2_steps": 0, + "total_time": elapsed_time1, + "total_steps": steps1, + "force1": force.item(), + "force2": 0.0, + } + ) + continue + + # Only proceed with second stage if not skipping + # Reload intermediate structures for second stage + atoms_list = [] + for file_path in batch_files: + file_name = file_path.split("/")[-1] + press_file = os.path.join( + self.output_path, + "cif_result_press", + file_name.replace(".cif", "_press.cif"), + ) + atoms = read(press_file) + atoms_list.append(atoms) + + # Rebuild batch from optimized structures + gbatch = data_list_collater( + [a2g.convert(atoms) for atoms in atoms_list] + ) + gbatch = gbatch.to(self.device) + + # Second optimization stage + if self.filter2 == "UnitCellFilter": + obatch2 = OptimizableUnitCellBatch( + gbatch, trainer=calculator, numpy=False, scalar_pressure=0.0 + ) + else: + obatch2 = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + + if self.optimizer2 == "LBFGS": + batch_optimizer2 = LBFGS( + obatch2, damping=1.0, alpha=70.0, maxstep=0.2 + ) + elif self.optimizer2 == "BFGS": + batch_optimizer2 = BFGS(obatch2, alpha=70.0, maxstep=0.2) + elif 
self.optimizer2 == "BFGSLineSearch": + batch_optimizer2 = BFGSLineSearch(obatch2, device=self.device) + elif self.optimizer2 == "BFGSFusedLS": + batch_optimizer2 = BFGSFusedLS(obatch2, device=self.device) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer2}") + start_time2 = time.perf_counter() + batch_optimizer2.run(0.01, self.max_steps) + end_time2 = time.perf_counter() + elapsed_time2 = end_time2 - start_time2 + + # Save final results + atoms_list = obatch2.get_atoms_list() + for atoms, file_path in zip(atoms_list, batch_files): + file_name = file_path.split("/")[-1] + output_file = os.path.join( + self.output_path, + "cif_result_final", + file_name.replace(".cif", "_opt.cif"), + ) + atoms.write(output_file) + + # Capture maximum force after second optimization stage + max_force2 = obatch2.get_max_forces(apply_constraint=True) + + steps2 = batch_optimizer2.nsteps + + for file, f1, f2 in zip(batch_files, max_force1, max_force2): + results.append( + { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": elapsed_time2, + "stage2_steps": steps2, + "total_time": elapsed_time1 + elapsed_time2, + "total_steps": steps1 + steps2, + "force1": f1.item(), + "force2": f2.item(), + } + ) + + return results + + def _batch_files(self, files, batch_size): + for i in range(0, len(files), batch_size): + yield files[i : i + batch_size] + + @staticmethod + def _torch_memory_monitor(interval=2, device=None, stop_event=None): + try: + # explicitly CUDA initialization + torch.cuda._lazy_init() + while not stop_event.is_set(): + allocated = torch.cuda.memory_allocated(device=device) + reserved = torch.cuda.memory_reserved(device=device) + logging.info( + f"[torch] Allocated Memory: {allocated / 1024**2:.2f} MiB" + ) + logging.info( + f"[torch] Reserved Memory: {reserved / 1024**2:.2f} MiB" + ) + time.sleep(interval) + except Exception as e: + logging.error(f"Unexpected error when monitor memory: {str(e)}") + + def continuous_run(self): + """ + Execute a continuous run of the batching optimization process. + """ + logging.info("Starting continuous_run with two rounds of optimization.") + + # torch memory monitor api + use_torch_memory_monitor = False + if use_torch_memory_monitor: + memory_monitor = threading.Thread( + target=Worker._torch_memory_monitor, + args=(2, self.device, self.stop_event), + ) + memory_monitor.start() + + # First round of optimization + try: + logging.info("Starting first round of optimization.") + results_round1, new_atoms_files = self.continuous_batching( + atoms_path=self.files, + result_path_prefix=os.path.join( + self.output_path, "cif_result_press/" + ), + fmax=0.01, + maxstep=self.max_steps, + use_filter=self.filter1, + optimizer=self.optimizer1, + scalar_pressure=self.scalar_pressure, + dtype=torch.float64, + ) + logging.info( + f"Completed first round of optimization. 
Results: {len(results_round1)}" + ) + except KeyboardInterrupt as e: + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + logging.error(f"Error during first round of optimization: {e}") + raise + except Exception as e: + logging.error(f"Error during first round of optimization: {e}") + raise + + if self.skip_second_stage: + logging.info("Skipping second round of optimization.") + return results_round1 + + # Second round of optimization without pressure + try: + logging.info("Starting second round of optimization.") + results_round2, _ = self.continuous_batching( + atoms_path=new_atoms_files, + result_path_prefix=os.path.join( + self.output_path, "cif_result_final/" + ), + fmax=0.01, + maxstep=self.max_steps, + # maxstep=3000, + use_filter=self.filter2, + optimizer=self.optimizer2, + scalar_pressure=0.0, + dtype=torch.float64, + ) + logging.info( + f"Completed second round of optimization. Results: {len(results_round2)}" + ) + except KeyboardInterrupt as e: + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + logging.error(f"Error during second round of optimization: {e}") + raise + except Exception as e: + logging.error(f"Error during second round of optimization: {e}") + raise + + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + + return self._save_results_to_csv(results_round1, results_round2) + + def _save_results_to_csv(self, results_round1, results_round2): + """Helper method to save results to CSV file and return the path.""" + combined_results = [] + results_map = {} + + # Process first round results + for result in results_round1: + file_name = result["file"] + results_map[file_name] = { + "file": file_name, + "stage1_steps": result["steps"], + "stage1_time": result["runtime"], + "stage1_energy": result["energy"], + "stage1_density": result["density"], + "stage2_steps": 0, + "stage2_time": 0.0, + "stage2_energy": 0.0, + "stage2_density": 0, + "total_steps": result["steps"], + "total_time": result["runtime"], + } + + # Process second round results + for result in results_round2: + file_name = result["file"] + if file_name in results_map: + results_map[file_name].update( + { + "stage2_steps": result["steps"], + "stage2_time": result["runtime"], + "stage2_energy": result["energy"], + "stage2_density": result["density"], + "total_steps": results_map[file_name]["stage1_steps"] + + result["steps"], + "total_time": results_map[file_name]["stage1_time"] + + result["runtime"], + } + ) + else: + results_map[file_name] = { + "file": file_name, + "stage1_steps": 0, + "stage1_time": 0.0, + "stage1_energy": 0.0, + "stage1_density": 0, + "stage2_steps": result["steps"], + "stage2_time": result["runtime"], + "stage2_energy": result["energy"], + "stage2_density": result["density"], + "total_steps": result["steps"], + "total_time": result["runtime"], + } + + # Convert map to list + combined_results = list(results_map.values()) + + logging.info( + f"Combined results from both rounds. 
Total results: {len(combined_results)}"
+        )
+
+        worker_id = os.getpid()
+        timestamp = int(time.time())
+        csv_filename = f"worker_{worker_id}_{timestamp}.csv"
+        csv_path = os.path.join(
+            self.output_path, "worker_results", csv_filename
+        )
+        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
+
+        with open(csv_path, mode="w", newline="") as csvfile:
+            fieldnames = [
+                "file",
+                "stage1_steps",
+                "stage1_time",
+                "stage1_energy",
+                "stage1_density",
+                "stage2_steps",
+                "stage2_time",
+                "stage2_energy",
+                "stage2_density",
+                "total_steps",
+                "total_time",
+            ]
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+            for result in combined_results:
+                writer.writerow(result)
+
+        return csv_path
+
+    def _get_density(self, crystal):
+        # Total mass: ASE's get_masses returns an array with the mass of every atom (in amu)
+        total_mass = sum(crystal.get_masses())
+
+        # Volume: ASE's get_volume returns the unit-cell volume in Å^3
+        # 1 Å^3 = 1e-24 cm^3
+        volume = crystal.get_volume()
+
+        # Density = mass / volume; dividing by Avogadro's number converts amu to grams
+        density = (
+            total_mass / (volume * 10**-24) / (6.022140857 * 10**23)
+        )  # g/cm^3
+        return density
+
+    @staticmethod
+    def select_factor(history: deque):
+        # TODO: when the history holds graphs of mixed sizes, the smaller values should be selected.
+        boundaries = [0, 50, 100, 200, 400, 800]
+        values = [0.4, 0.8, 0.9, 0.6, 0.5, 0.4]
+        factor_result = []
+        for graph_size in history:
+            for i in range(len(boundaries) - 1):
+                if boundaries[i] <= graph_size < boundaries[i + 1]:
+                    factor_result.append(values[i])
+                    break
+        if len(factor_result) == 0:
+            return 0.4
+        else:
+            return min(factor_result)
+
+    def continuous_batching(
+        self,
+        atoms_path,
+        result_path_prefix,
+        fmax,
+        maxstep,
+        use_filter,
+        optimizer,
+        scalar_pressure,
+        dtype=torch.float64,
+    ):
+        """
+        Performs continuous batched optimization of atomic structures.
+
+        This method implements a continuous batching strategy for optimizing multiple atomic structures,
+        where converged structures are replaced with new ones to maintain batch efficiency.
+
+        Parameters
+        ----------
+        atoms_path : list
+            List of file paths to atomic structure files to be optimized
+        result_path_prefix : str
+            Prefix for output file paths where optimized structures will be saved
+        fmax : float
+            Maximum force criterion for convergence
+        maxstep : int
+            Maximum number of optimization steps per structure
+        use_filter : str
+            Filter to be used for optimization, e.g. "UnitCellFilter"
+        optimizer : str
+            Optimizer to be used, one of "LBFGS", "BFGS", "BFGSLineSearch", "BFGSFusedLS"
+        scalar_pressure : float
+            Scalar pressure to be applied
+        dtype : torch.dtype, optional
+            Floating-point precision for the calculator, by default torch.float64
+
+        Returns
+        -------
+        tuple
+            A list of per-structure result dicts and a list of paths to the
+            optimized structures written to disk
+
+        Notes
+        -----
+        The method:
+        - Processes structures in batches of a fixed size (batch_size > 0) or a dynamic, atom-count-bounded size
+        - Uses the configured MLIP calculator (MACE, CHGNet, or SevenNet) for energy/force calculations
+        - Relaxes with the selected batched optimizer, optionally with unit cell relaxation
+        - Dynamically replaces converged structures with new ones in the batch
+        - Tracks convergence and optimization steps for each structure
+        """
+        # Load saved structures
+        result = []
+        optimized_atoms_paths = []
+
+        json_dir = result_path_prefix.replace("cif", "json")
+
+        remove_list = []
+        # TODO: Why do we read all the CIF files here?
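+        # The loop below makes runs restartable: structures whose optimized CIF and JSON outputs already exist are loaded and skipped.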
+ for pre_cif in atoms_path: + cif_path = os.path.join(result_path_prefix, pre_cif.split("/")[-1]) + json_path = os.path.join( + json_dir, pre_cif.split("/")[-1].replace(".cif", ".json") + ) + if ( + os.path.exists(cif_path) + and os.path.exists(json_path) + and os.path.getsize(cif_path) > 0 + and os.path.getsize(json_path) > 0 + ): + with open(json_path, "r") as f: + result_data = json.load(f) + result.append(result_data) + optimized_atoms_paths.append(cif_path) + remove_list.append(pre_cif) + logging.info(f"File {cif_path} already exists, loaded.") + # else: + # try: + # read(pre_cif) + # except Exception as e: + # logging.info(f"Failed to read {pre_cif}: {e}") + # remove_list.append(pre_cif) + for i in remove_list: + atoms_path.remove(i) + + if self.batch_size > 0: + # Initialize variables + room_in_batch = self.batch_size + indices_to_process = 0 + cur_batch_path = atoms_path[ + indices_to_process : indices_to_process + room_in_batch + ] + if len(cur_batch_path) == 0: + logging.info("No structures to process.") + return result, optimized_atoms_paths + room_in_batch -= len(cur_batch_path) + indices_to_process += len(cur_batch_path) + cur_atoms_list = [read(path) for path in cur_batch_path] + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + gbatch = data_list_collater( + [a2g.convert(read(path)) for path in cur_batch_path] + ) + else: + # Set Maximum Number of atoms per batch + history = deque(maxlen=10) + history.append(1000) + max_bnatoms = 24080 + safe_factor = self.select_factor(history) + + indices_to_process = 0 + bnatoms = 0 + cur_batch_path = [] + graphs_list = [] + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + + while indices_to_process < len(atoms_path): + graph_natoms = count_atoms_cif(atoms_path[indices_to_process]) + if ( + bnatoms + graph_natoms + > max_bnatoms * safe_factor // self.nproc + ): + break + graph = a2g.convert(read(atoms_path[indices_to_process])) + bnatoms += graph_natoms + cur_batch_path.append(atoms_path[indices_to_process]) + graphs_list.append(graph) + indices_to_process += 1 + history.append(graph_natoms) + safe_factor = self.select_factor(history) + if len(graphs_list) == 0: + logging.info("No structures to process.") + return result, optimized_atoms_paths + gbatch = data_list_collater(graphs_list) + logging.info(f"current batch size: {len(cur_batch_path)}") + + total_natoms = sum([graph.natoms for graph in graphs_list]) + logging.info(f"total_natoms: {total_natoms}") + + gbatch = gbatch.to(self.device) + batch_optimizer = None + + # Initial calculator + if self.model == "mace": + if dtype == torch.float32: + calculator = mace_off( + model="small", + device=self.device, + enable_cueq=self.cueq, + default_dtype="float32", + ) + else: + calculator = mace_off( + model="small", device=self.device, enable_cueq=self.cueq + ) + elif self.model == "chgnet": + calculator = CHGNetCalculator( + use_device=self.device, enable_cueq=self.cueq + ) + elif self.model == "sevennet": + # calculator = SevenNetCalculator(device=self.device, enable_cueq=self.cueq) + calculator = SevenNetD3Calculator( + device=self.device, + enable_cueq=self.cueq, + batch_size=self.batch_size, + ) + # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) + # calculator = MACECalculator(model_paths="/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", device=self.device, compile_mode=self.compile_mode) + if use_filter == "UnitCellFilter": + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=scalar_pressure, + ) + 
else: + obatch = OptimizableBatch(gbatch, trainer=calculator, numpy=False) + + orig_cells = obatch.orig_cells.clone() + + converged_atoms_count = 0 + converge_indices = [] + all_indices = [] + cur_batch_steps = [0] * len(cur_batch_path) + cur_batch_times = [time.perf_counter()] * len( + cur_batch_path + ) # Track start times + + while converged_atoms_count < len(atoms_path): + # Update batch + if len(all_indices) > 0: + if self.batch_size > 0: + room_in_batch += len(all_indices) + new_batch_path = atoms_path[ + indices_to_process : indices_to_process + room_in_batch + ] + logging.info(f"new_batch_path: {new_batch_path}") + room_in_batch -= len(new_batch_path) + indices_to_process += len(new_batch_path) + + optimized_atoms_new = [] + cur_batch_path_new = [] + cur_batch_steps_new = [] + cur_batch_times_new = [] + orig_cells_new = torch.zeros( + [self.batch_size - room_in_batch, 3, 3], + device=self.device, + ) + cell_offset = 0 + + restart_indices = [] + old_batch_indices = obatch.batch_indices + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + else: + restart_indices.append(i) + optimized_atoms_new.append(optimized_atoms[i]) + cur_batch_path_new.append(cur_batch_path[i]) + cur_batch_steps_new.append(cur_batch_steps[i]) + cur_batch_times_new.append(cur_batch_times[i]) + + orig_cells_new[cell_offset] = orig_cells[i] + cell_offset += 1 + + for new_path in new_batch_path: + optimized_atoms_new.append(read(new_path)) + cur_batch_path_new.append(new_path) + cur_batch_steps_new.append(0) + cur_batch_times_new.append(time.perf_counter()) + + # Update the batch with new structures + optimized_atoms = optimized_atoms_new + cur_batch_path = cur_batch_path_new + cur_batch_steps = cur_batch_steps_new + cur_batch_times = cur_batch_times_new + else: + bnatoms = 0 + optimized_atoms_new = [] + cur_batch_path_new = [] + cur_batch_steps_new = [] + cur_batch_times_new = [] + + restart_indices = [] + old_batch_indices = obatch.batch_indices + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + restart_indices.append(i) + optimized_atoms_new.append(optimized_atoms[i]) + cur_batch_path_new.append(cur_batch_path[i]) + cur_batch_steps_new.append(cur_batch_steps[i]) + cur_batch_times_new.append(cur_batch_times[i]) + bnatoms += a2g.convert(read(cur_batch_path[i])).natoms + + while indices_to_process < len(atoms_path): + new_path = atoms_path[indices_to_process] + graph_natoms = count_atoms_cif(new_path) + if ( + bnatoms + graph_natoms + > max_bnatoms * safe_factor // self.nproc + ): + break + bnatoms += graph_natoms + optimized_atoms_new.append(read(new_path)) + cur_batch_path_new.append(new_path) + cur_batch_steps_new.append(0) + cur_batch_times_new.append(time.perf_counter()) + indices_to_process += 1 + history.append(graph_natoms) + safe_factor = self.select_factor(history) + + orig_cells_new = torch.zeros( + [len(optimized_atoms_new), 3, 3], device=self.device + ) + cell_offset = 0 + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + orig_cells_new[cell_offset] = orig_cells[i] + cell_offset += 1 + + # Update the batch with new structures + optimized_atoms = optimized_atoms_new + cur_batch_path = cur_batch_path_new + cur_batch_steps = cur_batch_steps_new + cur_batch_times = cur_batch_times_new + + logging.info(f"current batch size: {len(optimized_atoms)}") + + graphs_list = [a2g.convert(atoms) for atoms in optimized_atoms] + total_natoms = sum([graph.natoms for graph in graphs_list]) + logging.info(f"total_natoms: {total_natoms}") + 
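+                # Rebuild the graph batch (and, for SevenNet, the calculator) for the refreshed structure set.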
logging.info(f"cur_batch_path to processing: {cur_batch_path}") + gbatch = data_list_collater(graphs_list) + gbatch = gbatch.to(self.device) + if self.model == "sevennet": + # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) + calculator = SevenNetD3Calculator( + device=self.device, + enable_cueq=self.cueq, + batch_size=self.batch_size, + ) + if use_filter == "UnitCellFilter": + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=scalar_pressure, + ) + else: + obatch = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + for i in range(cell_offset): + obatch.orig_cells[i] = orig_cells_new[i] + orig_cells = obatch.orig_cells.clone() + + # Optimize the current batch + if optimizer == "LBFGS": + batch_optimizer = LBFGS( + obatch, + damping=1.0, + alpha=70.0, + maxstep=0.2, + early_stop=True, + ) + elif optimizer == "BFGS": + if len(all_indices) > 0: + logging.info(f"Restarting with indices: {restart_indices}") + batch_optimizer.optimizable = obatch + else: + batch_optimizer = BFGS( + obatch, alpha=70.0, maxstep=0.2, early_stop=True + ) + elif optimizer == "BFGSLineSearch": + batch_optimizer = BFGSLineSearch( + obatch, + device=self.device, + early_stop=True, + use_profiler=self.use_profiler, + profiler_log_dir=self.profiler_log_dir, + profiler_schedule_config=self.profiler_schedule_config, + ) + elif optimizer == "BFGSFusedLS": + if len(all_indices) > 0: + logging.info(f"Restarting with indices: {restart_indices}") + batch_optimizer.optimizable = obatch + else: + batch_optimizer = BFGSFusedLS( + obatch, + device=self.device, + early_stop=True, + use_profiler=self.use_profiler, + profiler_log_dir=self.profiler_log_dir, + profiler_schedule_config=self.profiler_schedule_config, + ) + else: + raise ValueError(f"Unknown optimizer: {optimizer}") + + # 动态计算剩余可用步数(基于当前批次最大已执行步数) + current_max_steps = max(cur_batch_steps) if cur_batch_steps else 0 + remaining_steps = max( + maxstep - current_max_steps, 1 + ) # 保证至少运行1步 + + # 执行优化并获取收敛的索引 + if (optimizer == "BFGSFusedLS" or optimizer == "BFGS") and len( + all_indices + ) > 0: + converge_indices = batch_optimizer.run( + fmax, + remaining_steps, + is_restart_earlystop=True, + restart_indices=restart_indices, + old_batch_indices=old_batch_indices, + ) + else: + converge_indices = batch_optimizer.run(fmax, remaining_steps) + + # Print energies of all structures + # logging.info(f"Final energies of all structures: {batch_optimizer.energies}") + energies_list = ( + batch_optimizer.optimizable.get_potential_energies().tolist() + ) + logging.info(f"Final energies of all structures: {energies_list}") + + # 更新所有结构的累计步数 + cur_batch_steps = [ + steps + batch_optimizer.nsteps for steps in cur_batch_steps + ] + + # 找出超过最大步数的结构索引 + over_maxstep_indices = [ + i + for i, steps in enumerate(cur_batch_steps) + if steps >= maxstep - 1 + ] + + # 合并收敛和超限的索引(去重) + all_indices = list(set(converge_indices + over_maxstep_indices)) + + # Get optimized atoms + optimized_atoms = obatch.get_atoms_list() + converged_atoms_count += len(all_indices) + + end_time = time.perf_counter() + # 处理所有需要退出的结构(包括收敛和超限) + for idx in all_indices: + runtime = end_time - cur_batch_times[idx] + + energy_per_mol = ( + energies_list[idx] + / ( + len(optimized_atoms[idx].get_atomic_numbers()) + / self.molecule_single + ) + * 96.485 + ) + density = self._get_density(optimized_atoms[idx]) + + # Save results + result_data = { + "file": cur_batch_path[idx].split("/")[-1].split(".")[0], + "steps": 
cur_batch_steps[idx], + "runtime": runtime, + "energy": energy_per_mol, + "density": density, + } + result.append(result_data) + + # Save optimized structure + # converged_atoms_path = os.path.join(result_path_prefix, cur_batch_path[idx].split('/')[-1].replace('.cif', '.traj')) + converged_atoms_path = os.path.join( + result_path_prefix, cur_batch_path[idx].split("/")[-1] + ) + optimized_atoms[idx].write(converged_atoms_path) + optimized_atoms_paths.append(converged_atoms_path) + + # write a json file to store reslt_data + os.makedirs(json_dir, exist_ok=True) + # json_path = os.path.join(json_dir, cur_batch_path[idx].split('/')[-1]+'.json') + json_path = os.path.join( + json_dir, + cur_batch_path[idx].split("/")[-1].replace(".cif", ".json"), + ) + with open(json_path, "w") as f: + json.dump(result_data, f) + + logging.info(f"cur_batch_path: {cur_batch_path}") + logging.info(f"cur_batch_steps: {cur_batch_steps}") + logging.info(f"all_indices: {all_indices}") + logging.info(f"length of optimized_atoms: {len(optimized_atoms)}") + + return result, optimized_atoms_paths diff --git a/mace-bench/src/batchopt/utils.py b/mace-bench/src/batchopt/utils.py index 7b899f4ccee3b451d2e57ab04f9a02c4b868fccc..ff707072696bfb4ab2dfaf0f219ea858fd8965c7 100644 --- a/mace-bench/src/batchopt/utils.py +++ b/mace-bench/src/batchopt/utils.py @@ -1,121 +1,121 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import ast -import collections -import copy -import datetime -import errno -import functools -import importlib -import itertools -import json -import logging -import os -import subprocess -import sys -import time -from bisect import bisect -from contextlib import contextmanager -from dataclasses import dataclass -from functools import wraps -from itertools import product -from pathlib import Path -from typing import TYPE_CHECKING, Any -from uuid import uuid4 - -import numpy as np -import torch -import torch.nn as nn -import torch_geometric -import yaml -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas -from matplotlib.figure import Figure -from torch_geometric.data import Data -from torch_geometric.utils import remove_self_loops -from torch_scatter import scatter, segment_coo, segment_csr - -from torch_geometric.data.data import BaseData -from torch_geometric.data import Batch - -# sort files by atomic number in descending order -def count_atoms_cif(file): - in_atom_site = False - natoms = 0 - with open(file, 'r') as f: - while line := f.readline(): - if line.lower().startswith("loop_"): - in_atom_site = False - continue - # if line.lower().startswith("_atom_site_"): - if "_atom_site_" in line.lower(): - in_atom_site = True - continue - if in_atom_site: - if line.startswith("_"): - in_atom_site = False - continue - elif line: - natoms += 1 - return natoms - -# Override the collation method in `pytorch_geometric.data.InMemoryDataset` -def collate(data_list): - keys = data_list[0].keys - data = data_list[0].__class__() - - for key in keys: - data[key] = [] - slices = {key: [0] for key in keys} - - for item, key in product(data_list, keys): - data[key].append(item[key]) - if torch.is_tensor(item[key]): - s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key])) - elif isinstance(item[key], (int, float)): - s = slices[key][-1] + 1 - else: - raise ValueError("Unsupported attribute type") - slices[key].append(s) - - if hasattr(data_list[0], 
"__num_nodes__"): - data.__num_nodes__ = [] - for item in data_list: - data.__num_nodes__.append(item.num_nodes) - - for key in keys: - if torch.is_tensor(data_list[0][key]): - data[key] = torch.cat( - data[key], dim=data.__cat_dim__(key, data_list[0][key]) - ) - else: - data[key] = torch.tensor(data[key]) - slices[key] = torch.tensor(slices[key], dtype=torch.long) - - return data, slices - -def data_list_collater( - data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False -) -> BaseData | dict[str, torch.Tensor]: - batch = Batch.from_data_list(data_list) - - if not otf_graph: - try: - n_neighbors = [] - for _, data in enumerate(data_list): - n_index = data.edge_index[1, :] - n_neighbors.append(n_index.shape[0]) - batch.neighbors = torch.tensor(n_neighbors) - except (NotImplementedError, TypeError): - logging.warning( - "LMDB does not contain edge index information, set otf_graph=True" - ) - - if to_dict: - batch = dict(batch.items()) - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import ast +import collections +import copy +import datetime +import errno +import functools +import importlib +import itertools +import json +import logging +import os +import subprocess +import sys +import time +from bisect import bisect +from contextlib import contextmanager +from dataclasses import dataclass +from functools import wraps +from itertools import product +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric +import yaml +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from torch_geometric.data import Data +from torch_geometric.utils import remove_self_loops +from torch_scatter import scatter, segment_coo, segment_csr + +from torch_geometric.data.data import BaseData +from torch_geometric.data import Batch + +# sort files by atomic number in descending order +def count_atoms_cif(file): + in_atom_site = False + natoms = 0 + with open(file, 'r') as f: + while line := f.readline(): + if line.lower().startswith("loop_"): + in_atom_site = False + continue + # if line.lower().startswith("_atom_site_"): + if "_atom_site_" in line.lower(): + in_atom_site = True + continue + if in_atom_site: + if line.startswith("_"): + in_atom_site = False + continue + elif line: + natoms += 1 + return natoms + +# Override the collation method in `pytorch_geometric.data.InMemoryDataset` +def collate(data_list): + keys = data_list[0].keys + data = data_list[0].__class__() + + for key in keys: + data[key] = [] + slices = {key: [0] for key in keys} + + for item, key in product(data_list, keys): + data[key].append(item[key]) + if torch.is_tensor(item[key]): + s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key])) + elif isinstance(item[key], (int, float)): + s = slices[key][-1] + 1 + else: + raise ValueError("Unsupported attribute type") + slices[key].append(s) + + if hasattr(data_list[0], "__num_nodes__"): + data.__num_nodes__ = [] + for item in data_list: + data.__num_nodes__.append(item.num_nodes) + + for key in keys: + if torch.is_tensor(data_list[0][key]): + data[key] = torch.cat( + data[key], dim=data.__cat_dim__(key, data_list[0][key]) + ) + else: + data[key] = torch.tensor(data[key]) + slices[key] = torch.tensor(slices[key], dtype=torch.long) + + return data, slices + 
+def data_list_collater( + data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False +) -> BaseData | dict[str, torch.Tensor]: + batch = Batch.from_data_list(data_list) + + if not otf_graph: + try: + n_neighbors = [] + for _, data in enumerate(data_list): + n_index = data.edge_index[1, :] + n_neighbors.append(n_index.shape[0]) + batch.neighbors = torch.tensor(n_neighbors) + except (NotImplementedError, TypeError): + logging.warning( + "LMDB does not contain edge index information, set otf_graph=True" + ) + + if to_dict: + batch = dict(batch.items()) + return batch \ No newline at end of file diff --git a/mace-bench/util/env.sh b/mace-bench/util/env.sh index 00b7806c2914fbf1188759ac5e38cc7971ecfcc9..9a47c6a752f9ea0acb2c9090b021b93998211bcb 100644 --- a/mace-bench/util/env.sh +++ b/mace-bench/util/env.sh @@ -1,7 +1,6 @@ -#!/bin/bash - -export CUDA_HOME=/usr/local/cuda -export PATH=$CUDA_HOME/bin:$PATH -export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH + +export CUDA_HOME=/usr/local/cuda +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH export CPATH=$CUDA_HOME/include:$CPATH \ No newline at end of file diff --git a/mace-bench/util/mps_clean.sh b/mace-bench/util/mps_clean.sh index 2a1354b8586494f31767bbf1f36bb36dba86a9cd..15ef55800b833f994dead5baa2f8e3d445c6e797 100644 --- a/mace-bench/util/mps_clean.sh +++ b/mace-bench/util/mps_clean.sh @@ -1,11 +1,10 @@ -#!/bin/bash - -echo quit | nvidia-cuda-mps-control -nvidia-smi -i 0 -c DEFAULT -nvidia-smi -i 1 -c DEFAULT -nvidia-smi -i 2 -c DEFAULT -nvidia-smi -i 3 -c DEFAULT -nvidia-smi -i 4 -c DEFAULT -nvidia-smi -i 5 -c DEFAULT -nvidia-smi -i 6 -c DEFAULT + +echo quit | nvidia-cuda-mps-control +nvidia-smi -i 0 -c DEFAULT +nvidia-smi -i 1 -c DEFAULT +nvidia-smi -i 2 -c DEFAULT +nvidia-smi -i 3 -c DEFAULT +nvidia-smi -i 4 -c DEFAULT +nvidia-smi -i 5 -c DEFAULT +nvidia-smi -i 6 -c DEFAULT nvidia-smi -i 7 -c DEFAULT \ No newline at end of file diff --git a/mace-bench/util/mps_start.sh b/mace-bench/util/mps_start.sh index 46b0966e13ea992c6e2ff909aa3c5710944e6b37..9304a8582a7ef3e2b9bffe210472b1f0b80db2c9 100644 --- a/mace-bench/util/mps_start.sh +++ b/mace-bench/util/mps_start.sh @@ -1,10 +1,10 @@ -#!/bin/bash -nvidia-smi -i 0 -c EXCLUSIVE_PROCESS # Set GPU 0 to exclusive mode. -nvidia-smi -i 1 -c EXCLUSIVE_PROCESS # Set GPU 1 to exclusive mode. -nvidia-smi -i 2 -c EXCLUSIVE_PROCESS # Set GPU 2 to exclusive mode. -nvidia-smi -i 3 -c EXCLUSIVE_PROCESS # Set GPU 3 to exclusive mode. -nvidia-smi -i 4 -c EXCLUSIVE_PROCESS # Set GPU 4 to exclusive mode. -nvidia-smi -i 5 -c EXCLUSIVE_PROCESS # Set GPU 5 to exclusive mode. -nvidia-smi -i 6 -c EXCLUSIVE_PROCESS # Set GPU 6 to exclusive mode. -nvidia-smi -i 7 -c EXCLUSIVE_PROCESS # Set GPU 7 to exclusive mode. + +nvidia-smi -i 0 -c EXCLUSIVE_PROCESS # Set GPU 0 to exclusive mode. +nvidia-smi -i 1 -c EXCLUSIVE_PROCESS # Set GPU 1 to exclusive mode. +nvidia-smi -i 2 -c EXCLUSIVE_PROCESS # Set GPU 2 to exclusive mode. +nvidia-smi -i 3 -c EXCLUSIVE_PROCESS # Set GPU 3 to exclusive mode. +nvidia-smi -i 4 -c EXCLUSIVE_PROCESS # Set GPU 4 to exclusive mode. +nvidia-smi -i 5 -c EXCLUSIVE_PROCESS # Set GPU 5 to exclusive mode. +nvidia-smi -i 6 -c EXCLUSIVE_PROCESS # Set GPU 6 to exclusive mode. +nvidia-smi -i 7 -c EXCLUSIVE_PROCESS # Set GPU 7 to exclusive mode. nvidia-cuda-mps-control -d # Start the daemon. 
\ No newline at end of file diff --git a/main.py b/main.py index 093d67c646e3ddbb783792cff983209999362324..187c4df34c402d36aa97559594e51f4813c85a8f 100644 --- a/main.py +++ b/main.py @@ -1,95 +1,104 @@ -from basic_function import format_parser -from basic_function import packaged_function -from basic_function import conformer_search -import time -import argparse -import os -import itertools - - -if __name__ == '__main__': - - time_start = time.time() - - # initiate configuration - ############################################################################################## - parser = argparse.ArgumentParser() - parser.add_argument('--path', type=str, default="./", help='Path to process') - parser.add_argument('--smiles', type=str, required=True, help='SMILES string of the molecules, split by . if multiple molecules are used') - parser.add_argument('--generate_conformers', type=int, default=20, help='Number of conformers to generate. When it is <=0, only load existing conformers to generate structures') - parser.add_argument('--use_conformers', type=int, default=4, help='Number of conformers used to generate structure. When it is <=0, no structure generation would be done') - parser.add_argument('--molecule_num_in_cell', type=str, nargs='+', default=['1'], help='number of molecules in a unit cell, split by comma for multiple molecules, and split by space for multiple packings') - parser.add_argument('--num_generation', type=int, nargs='+', default=[100], help='number of structures to generate, split by space for multiple packings') - parser.add_argument('--space_group_list', type=str, nargs='+', default=["2,14"], help='Space group list for structure generation, spilt by comma to add mutiple groups, split by space for multiple packings') - parser.add_argument('--add_name', type=str, nargs='+', default=["CRYSTAL"], help='Add name for the generated structures, split by space for multiple packings') - parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of workers for parallel processing') - args = parser.parse_args() - - target_folder = args.path - smiles_list = args.smiles.split('.') - generate_conformers = args.generate_conformers - use_conformers = args.use_conformers - molecule_num_in_cell = [list(map(int, num.split(','))) for num in args.molecule_num_in_cell] - num_generation = args.num_generation - space_group_list = [list(map(int, group.split(','))) for group in args.space_group_list] - add_name = args.add_name - max_workers = args.max_workers - - num_molecules = len(smiles_list) - num_packings = max(len(molecule_num_in_cell), len(space_group_list)) - - for i in range(len(molecule_num_in_cell)): - if len(molecule_num_in_cell[i]) < num_molecules: - molecule_num_in_cell[i].extend([1] * (num_molecules - len(molecule_num_in_cell[i]))) - elif len(molecule_num_in_cell[i]) > num_molecules: - molecule_num_in_cell[i] = molecule_num_in_cell[i][:num_molecules] - - while len(molecule_num_in_cell) < num_packings: - molecule_num_in_cell.append(molecule_num_in_cell[-1]) - - while len(space_group_list) < num_packings: - space_group_list.append(space_group_list[-1]) - - while len(add_name) < num_packings: - add_name.append(add_name[-1]) - - while len(num_generation) < num_packings: - num_generation.append(num_generation[-1]) - - - # step1: conformer search - ############################################################################################## - molecule_data = [] - for i in range(num_molecules): - molecule_folder = os.path.join(target_folder, f"molecule_{i+1}") - 
molecule_data.append([]) - if generate_conformers > 0: - conformer_search.conformer_search(smiles_list[i], molecule_folder, num_conformers=generate_conformers, max_attempts=10000, rms_thresh=0.1) - with open(os.path.join(molecule_folder, "info.txt"), "w") as smiles_file: - smiles_file.write(f"SMILES: {smiles_list[i]}") - for j in range(use_conformers): - temp_path = os.path.join(molecule_folder, "conformers", f"conformer_{j}.xyz") - if not os.path.exists(temp_path): - break - molecule_data[i].append(format_parser.read_xyz_file(temp_path)) - if len(molecule_data[i]) <= 0: - print(f"No conformer loaded for molecule_{i+1}. Check configurations!") - break - - idx_data = [list(range(len(item))) for item in molecule_data] - combinations = list(itertools.product(*idx_data)) - - - # step2: structure generation - ############################################################################################## - for i in range(num_packings): - for combination in combinations: - molecule_list = [] - for j in range(num_molecules): - for cnt in range(molecule_num_in_cell[i][j]): - molecule_list.append(molecule_data[j][combination[j]]) - c_name = "".join(map(str, combination)) - packaged_function.CSP_generater_parallel(molecule_list, target_folder, need_structure=num_generation[i], space_group_list=space_group_list[i],add_name=f"{add_name[i]}_C{c_name}", max_workers=max_workers,start_seed=1) - - time_end=time.time() - print('time cost',time_end-time_start,'s') +from basic_function import format_parser +from basic_function import packaged_function +from basic_function import conformer_search +import time +import argparse +import os +import itertools + + +if __name__ == '__main__': + + time_start = time.time() + + # initiate configuration + ############################################################################################## + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default="./", help='Path to process') + parser.add_argument('--smiles', type=str, required=True, help='SMILES string of the molecules, split by . if multiple molecules are used') + parser.add_argument('--generate_conformers', type=int, default=20, help='Number of conformers to generate. When it is <=0, only load existing conformers to generate structures') + parser.add_argument('--use_conformers', type=int, default=4, help='Number of conformers used to generate structure. 
When it is <=0, no structure generation will be done')
+    parser.add_argument('--molecule_num_in_cell', type=str, nargs='+', default=['1'], help='number of molecules in a unit cell, split by comma for multiple molecules, and split by space for multiple packings')
+    parser.add_argument('--num_generation', type=int, nargs='+', default=[100], help='number of structures to generate, split by space for multiple packings')
+    parser.add_argument('--space_group_list', type=str, nargs='+', default=["2,14"], help='Space group list for structure generation, split by comma to add multiple groups, split by space for multiple packings')
+    parser.add_argument('--add_name', type=str, nargs='+', default=["CRYSTAL"], help='Add name for the generated structures, split by space for multiple packings')
+    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of workers for parallel processing')
+    parser.add_argument('--mode', type=str, default='all', choices=['all', 'conformer_only', 'structure_only'], help='which jobs to run')
+    args = parser.parse_args()
+
+    target_folder = args.path
+    smiles_list = args.smiles.split('.')
+    generate_conformers = args.generate_conformers
+    use_conformers = args.use_conformers
+    molecule_num_in_cell = [list(map(int, num.split(','))) for num in args.molecule_num_in_cell]
+    num_generation = args.num_generation
+    space_group_list = [list(map(int, group.split(','))) for group in args.space_group_list]
+    add_name = args.add_name
+    max_workers = args.max_workers
+    mode = args.mode
+
+    num_molecules = len(smiles_list)
+    num_packings = max(len(molecule_num_in_cell), len(space_group_list))
+
+    for i in range(len(molecule_num_in_cell)):
+        if len(molecule_num_in_cell[i]) < num_molecules:
+            molecule_num_in_cell[i].extend([1] * (num_molecules - len(molecule_num_in_cell[i])))
+        elif len(molecule_num_in_cell[i]) > num_molecules:
+            molecule_num_in_cell[i] = molecule_num_in_cell[i][:num_molecules]
+
+    while len(molecule_num_in_cell) < num_packings:
+        molecule_num_in_cell.append(molecule_num_in_cell[-1])
+
+    while len(space_group_list) < num_packings:
+        space_group_list.append(space_group_list[-1])
+
+    while len(add_name) < num_packings:
+        add_name.append(add_name[-1])
+
+    while len(num_generation) < num_packings:
+        num_generation.append(num_generation[-1])
+
+
+    # step1: conformer search
+    ##############################################################################################
+    molecule_data = []
+    for i in range(num_molecules):
+        molecule_folder = os.path.join(target_folder, f"molecule_{i+1}")
+        molecule_data.append([])
+        if generate_conformers > 0 and mode != "structure_only":
+            conformer_search.conformer_search(smiles_list[i], molecule_folder, num_conformers=generate_conformers, max_attempts=10000, rms_thresh=0.1)
+            with open(os.path.join(molecule_folder, "info.txt"), "w") as smiles_file:
+                smiles_file.write(f"SMILES: {smiles_list[i]}")
+        file_num = len(os.listdir(os.path.join(molecule_folder, "conformers")))
+        cnt = 0
+        for j in range(file_num):
+            if cnt >= use_conformers:
+                break
+            temp_path = os.path.join(molecule_folder, "conformers", f"conformer_{j}.xyz")
+            if not os.path.exists(temp_path):
+                break
+            molecule_data[i].append(format_parser.read_xyz_file(temp_path))
+            cnt += 1
+
+        if len(molecule_data[i]) <= 0:
+            print(f"No conformer loaded for molecule_{i+1}. 
+            break
+
+    idx_data = [list(range(len(item))) for item in molecule_data]
+    combinations = list(itertools.product(*idx_data))
+
+
+    # step2: structure generation
+    ##############################################################################################
+    if mode != "conformer_only":
+        for i in range(num_packings):
+            for combination in combinations:
+                molecule_list = []
+                for j in range(num_molecules):
+                    for cnt in range(molecule_num_in_cell[i][j]):
+                        molecule_list.append(molecule_data[j][combination[j]])
+                c_name = "".join(map(str, combination))
+                packaged_function.CSP_generater_parallel(molecule_list, target_folder, need_structure=num_generation[i], space_group_list=space_group_list[i], add_name=f"{add_name[i]}_C{c_name}", max_workers=max_workers, start_seed=1)
+
+    time_end = time.time()
+    print('time cost', time_end - time_start, 's')
diff --git a/post_process/check_match.py b/post_process/check_match.py
index 5f27d3052cd17317210cee68fe13ef85b3d6350e..a4b8953270090ae857a1cfea7396ffc4b2b1a9ed 100644
--- a/post_process/check_match.py
+++ b/post_process/check_match.py
@@ -1,154 +1,154 @@
-import warnings
-warnings.filterwarnings("ignore", category=DeprecationWarning)
-warnings.filterwarnings("ignore", category=UserWarning)
-
-from ccdc.crystal import PackingSimilarity
-from ccdc.io import CrystalReader
-import glob
-import os
-import sys
-import random
-import pandas as pd
-from multiprocessing import Pool, cpu_count, TimeoutError as mpTimeoutError # import TimeoutError
-import argparse
-
-# --- Global Configuration ---
-REPORT_TARGET = 15
-LARGE_CONFORMER_DIFF = True
-
-# --- Worker Process Initializer ---
-def init_worker(ref_path, engine_settings):
-    """
-    Initializes a worker process.
-    This function loads the reference structure and creates the similarity engine
-    once per process, storing them in global variables for that process.
-    """
-    # print(f"Worker process {os.getpid()} initializing...")
-    # sys.stdout.flush()
-    global worker_ref_crystal
-    global worker_similarity_engine
-    worker_ref_crystal = CrystalReader(ref_path)[0]
-    worker_similarity_engine = PackingSimilarity()
-    worker_similarity_engine.settings.allow_molecular_differences = engine_settings['allow_molecular_differences']
-    worker_similarity_engine.settings.distance_tolerance = engine_settings['distance_tolerance']
-    worker_similarity_engine.settings.angle_tolerance = engine_settings['angle_tolerance']
-    worker_similarity_engine.settings.packing_shell_size = engine_settings['packing_shell_size']
-    worker_similarity_engine.settings.ignore_hydrogen_positions = engine_settings['ignore_hydrogen_positions']
-    worker_similarity_engine.settings.ignore_bond_counts = engine_settings['ignore_bond_counts']
-    worker_similarity_engine.settings.ignore_hydrogen_counts = engine_settings['ignore_hydrogen_counts']
-
-# --- Single Task Processing Function ---
-def process_single_cif(csp_file_path):
-    """
-    Compares a single candidate structure against the reference structure
-    loaded in the worker's global scope.
-    Returns a tuple indicating the result type ('matched' or 'failed') and the file path.
-    """
-    global worker_ref_crystal
-    global worker_similarity_engine
-    try:
-        try_structure = CrystalReader(csp_file_path)[0]
-        h = worker_similarity_engine.compare(try_structure, worker_ref_crystal)
-        if h.nmatched_molecules >= REPORT_TARGET:
-            print(f"MATCH: {os.path.basename(csp_file_path)} | Matched Molecules: {h.nmatched_molecules}, RMSD: {h.rmsd:.3f}")
-            sys.stdout.flush()
-            return ("matched", csp_file_path)
-    except Exception as e:
-        if not LARGE_CONFORMER_DIFF:
-            print(f"FAIL: {os.path.basename(csp_file_path)} | Reason: {e}")
-            sys.stdout.flush()
-        return ("failed", csp_file_path)
-    return None
-
-# --- Main Execution Block ---
-if __name__ == '__main__':
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--path', type=str, default="./", help='Path to process')
-    parser.add_argument('--ref_path', type=str, default="../refs", help='Path to find reference structrues')
-    parser.add_argument('--workers', type=int, default=80, help='Max worker number limit')
-    parser.add_argument('--timeout', type=int, default=20, help='Timeout for each task in seconds')
-
-    args = parser.parse_args()
-    base_path = args.path
-    refs_dir = args.ref_path
-    PROCESS_NUM = min(args.workers, cpu_count()) # Use the specified number of workers or the max available
-    TIMEOUT_SECONDS = args.timeout # Set the timeout for each task
-    all_results = []
-
-    print(f"Starting checking match using up to {PROCESS_NUM} processes with a {TIMEOUT_SECONDS}s timeout per task...")
-
-    folders_to_process = []
-    csp_dir = os.path.join(base_path, "cif_result_final")
-    if os.path.exists(csp_dir) and os.path.exists(refs_dir):
-        folders_to_process.append((csp_dir, refs_dir))
-
-    for csp_dir, refs_dir in folders_to_process:
-        for ref_filename in os.listdir(refs_dir):
-            if not ref_filename.endswith(".cif"):
-                continue
-
-            ref_full_path = os.path.join(refs_dir, ref_filename)
-            print(f"\n--- Processing Reference File: {ref_full_path} ---")
-
-            csp_files = glob.glob(os.path.join(csp_dir, '*.cif'))
-            random.shuffle(csp_files)
-
-            if not csp_files:
-                print("No candidate .cif files found, skipping.")
-                continue
-
-            engine_settings = {
-                'allow_molecular_differences': False,
-                'distance_tolerance': 0.2,
-                'angle_tolerance': 20,
-                'packing_shell_size': 15,
-                'ignore_hydrogen_positions': True,
-                'ignore_bond_counts': True,
-                'ignore_hydrogen_counts': True
-            }
-
-            with Pool(processes=PROCESS_NUM, initializer=init_worker, initargs=(ref_full_path, engine_settings)) as pool:
-
-                async_results = []
-                for f in csp_files:
-                    res = pool.apply_async(process_single_cif, args=(f,))
-                    async_results.append(res)
-
-                results_list = []
-                for i, res_obj in enumerate(async_results):
-                    try:
-                        result = res_obj.get(timeout=TIMEOUT_SECONDS)
-                        results_list.append(result)
-                    except mpTimeoutError:
-                        timed_out_file = csp_files[i]
-                        print(f"TIMEOUT: {timed_out_file} | Task exceeded {TIMEOUT_SECONDS}s limit.")
-                        sys.stdout.flush()
-                        results_list.append(("failed", timed_out_file))
-
-            matched_structures = []
-            failed_structures = []
-            for res in results_list:
-                if res:
-                    status, path = res
-                    if status == "matched":
-                        matched_structures.append(os.path.basename(path))
-                    elif status == "failed":
-                        failed_structures.append(os.path.basename(path))
-
-            all_results.append({
-                "ref_name": ref_filename,
-                "matched_count": len(matched_structures),
-                "matched_structures": ";".join(matched_structures),
-                "failed_count": len(failed_structures),
-                "failed_structures": ";".join(failed_structures)
-            })
-            print(f"--- Finished {ref_filename}. Matched: {len(matched_structures)}, Failed: {len(failed_structures)} ---")
-
-    if all_results:
-        df = pd.DataFrame(all_results)
-        output_filename = "match_results.csv"
-        df.to_csv(output_filename, index=False)
-        print(f"\nAll processing finished. Results saved to {output_filename}")
-    else:
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+
+from ccdc.crystal import PackingSimilarity
+from ccdc.io import CrystalReader
+import glob
+import os
+import sys
+import random
+import pandas as pd
+from multiprocessing import Pool, cpu_count, TimeoutError as mpTimeoutError # import TimeoutError
+import argparse
+
+# --- Global Configuration ---
+REPORT_TARGET = 15
+LARGE_CONFORMER_DIFF = True
+
+# --- Worker Process Initializer ---
+def init_worker(ref_path, engine_settings):
+    """
+    Initializes a worker process.
+    This function loads the reference structure and creates the similarity engine
+    once per process, storing them in global variables for that process.
+    """
+    # print(f"Worker process {os.getpid()} initializing...")
+    # sys.stdout.flush()
+    global worker_ref_crystal
+    global worker_similarity_engine
+    worker_ref_crystal = CrystalReader(ref_path)[0]
+    worker_similarity_engine = PackingSimilarity()
+    worker_similarity_engine.settings.allow_molecular_differences = engine_settings['allow_molecular_differences']
+    worker_similarity_engine.settings.distance_tolerance = engine_settings['distance_tolerance']
+    worker_similarity_engine.settings.angle_tolerance = engine_settings['angle_tolerance']
+    worker_similarity_engine.settings.packing_shell_size = engine_settings['packing_shell_size']
+    worker_similarity_engine.settings.ignore_hydrogen_positions = engine_settings['ignore_hydrogen_positions']
+    worker_similarity_engine.settings.ignore_bond_counts = engine_settings['ignore_bond_counts']
+    worker_similarity_engine.settings.ignore_hydrogen_counts = engine_settings['ignore_hydrogen_counts']
+
+# --- Single Task Processing Function ---
+def process_single_cif(csp_file_path):
+    """
+    Compares a single candidate structure against the reference structure
+    loaded in the worker's global scope.
+    Returns a tuple indicating the result type ('matched' or 'failed') and the file path.
+    """
+    global worker_ref_crystal
+    global worker_similarity_engine
+    try:
+        try_structure = CrystalReader(csp_file_path)[0]
+        h = worker_similarity_engine.compare(try_structure, worker_ref_crystal)
+        if h.nmatched_molecules >= REPORT_TARGET:
+            print(f"MATCH: {os.path.basename(csp_file_path)} | Matched Molecules: {h.nmatched_molecules}, RMSD: {h.rmsd:.3f}")
+            sys.stdout.flush()
+            return ("matched", csp_file_path)
+    except Exception as e:
+        if not LARGE_CONFORMER_DIFF:
+            print(f"FAIL: {os.path.basename(csp_file_path)} | Reason: {e}")
+            sys.stdout.flush()
+        return ("failed", csp_file_path)
+    return None
+
+# --- Main Execution Block ---
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, default="./", help='Path to process')
+    parser.add_argument('--ref_path', type=str, default="../refs", help='Path to find reference structures')
+    parser.add_argument('--workers', type=int, default=80, help='Max worker number limit')
+    parser.add_argument('--timeout', type=int, default=20, help='Timeout for each task in seconds')
+
+    args = parser.parse_args()
+    base_path = args.path
+    refs_dir = args.ref_path
+    PROCESS_NUM = min(args.workers, cpu_count()) # Use the specified number of workers or the max available
+    TIMEOUT_SECONDS = args.timeout # Set the timeout for each task
+    all_results = []
+
+    print(f"Starting match checking using up to {PROCESS_NUM} processes with a {TIMEOUT_SECONDS}s timeout per task...")
+
+    folders_to_process = []
+    csp_dir = os.path.join(base_path, "cif_result_final")
+    if os.path.exists(csp_dir) and os.path.exists(refs_dir):
+        folders_to_process.append((csp_dir, refs_dir))
+
+    for csp_dir, refs_dir in folders_to_process:
+        for ref_filename in os.listdir(refs_dir):
+            if not ref_filename.endswith(".cif"):
+                continue
+
+            ref_full_path = os.path.join(refs_dir, ref_filename)
+            print(f"\n--- Processing Reference File: {ref_full_path} ---")
+
+            csp_files = glob.glob(os.path.join(csp_dir, '*.cif'))
+            random.shuffle(csp_files)
+
+            if not csp_files:
+                print("No candidate .cif files found, skipping.")
+                continue
+
+            engine_settings = {
+                'allow_molecular_differences': False,
+                'distance_tolerance': 0.2,
+                'angle_tolerance': 20,
+                'packing_shell_size': 15,
+                'ignore_hydrogen_positions': True,
+                'ignore_bond_counts': True,
+                'ignore_hydrogen_counts': True
+            }
+
+            with Pool(processes=PROCESS_NUM, initializer=init_worker, initargs=(ref_full_path, engine_settings)) as pool:
+
+                async_results = []
+                for f in csp_files:
+                    res = pool.apply_async(process_single_cif, args=(f,))
+                    async_results.append(res)
+
+                results_list = []
+                for i, res_obj in enumerate(async_results):
+                    try:
+                        result = res_obj.get(timeout=TIMEOUT_SECONDS)
+                        results_list.append(result)
+                    except mpTimeoutError:
+                        timed_out_file = csp_files[i]
+                        print(f"TIMEOUT: {timed_out_file} | Task exceeded {TIMEOUT_SECONDS}s limit.")
+                        sys.stdout.flush()
+                        results_list.append(("failed", timed_out_file))
+
+            matched_structures = []
+            failed_structures = []
+            for res in results_list:
+                if res:
+                    status, path = res
+                    if status == "matched":
+                        matched_structures.append(os.path.basename(path))
+                    elif status == "failed":
+                        failed_structures.append(os.path.basename(path))
+
+            all_results.append({
+                "ref_name": ref_filename,
+                "matched_count": len(matched_structures),
+                "matched_structures": ";".join(matched_structures),
+                "failed_count": len(failed_structures),
+                "failed_structures": ";".join(failed_structures)
+            })
+            print(f"--- Finished {ref_filename}. Matched: {len(matched_structures)}, Failed: {len(failed_structures)} ---")
+
+    if all_results:
+        df = pd.DataFrame(all_results)
+        output_filename = "match_results.csv"
+        df.to_csv(output_filename, index=False)
+        print(f"\nAll processing finished. Results saved to {output_filename}")
+    else:
         print("\nNo valid data processed. No output file generated.")
\ No newline at end of file
diff --git a/post_process/clean_table.py b/post_process/clean_table.py
index d3cfa3e06b8edbfc4220ef1851fb197b71eb826b..2099ecc4f1fe0e76b5e3897c1f68266da278cb00 100644
--- a/post_process/clean_table.py
+++ b/post_process/clean_table.py
@@ -1,49 +1,49 @@
-import os
-import sys
-import pandas as pd
-
-
-target_folder = os.getcwd()
-
-
-if __name__ == '__main__':
-    print(f"Cleaning table in folder: {target_folder}")
-    error_folder = os.path.join(target_folder, "error_result")
-    if not os.path.exists(error_folder):
-        os.makedirs(error_folder)
-    cif_folder = os.path.join(target_folder, "cif_result_final")
-    csv_file = os.path.join(target_folder, "results_scheduler.csv")
-
-    if not os.path.exists(csv_file):
-        print(f"CSV file not found in {target_folder}, skipping...")
-        sys.exit(0)
-    if not os.path.exists(cif_folder):
-        print(f"CIF folder not found in {target_folder}, skipping...")
-        sys.exit(0)
-
-    all_crystals = pd.read_csv(csv_file)
-    cifs = os.listdir(cif_folder)
-    invalid_list = []
-    for i, item in all_crystals.iterrows():
-        if int(item['stage2_steps']) > 2990 or (item['stage1_steps'] >= 2999 and item["stage2_steps"] == 0) or abs(float(item['stage2_energy'])) > 1e15 or item['file']+".cif" not in cifs:
-            invalid_list.append(i)
-
-    if invalid_list:
-        invalid_df = all_crystals.iloc[invalid_list]
-        all_crystals.drop(index=invalid_list, inplace=True)
-
-    all_crystals.sort_values(by=["stage2_energy"], ascending=True, inplace=True)
-    all_crystals.reset_index(drop=True, inplace=True)
-
-    min_energy = all_crystals['stage2_energy'][0]
-    all_crystals['relative_energy'] = all_crystals['stage2_energy'] - min_energy
-
-    for cif in cifs:
-        if cif[:-4] not in all_crystals['file'].values:
-            # move cif to error folder
-            src_path = os.path.join(cif_folder, cif)
-            dest_path = os.path.join(error_folder, cif)
-            os.rename(src_path, dest_path)
-
-    # Save the cleaned DataFrame back to CSV
+import os
+import sys
+import pandas as pd
+
+
+target_folder = os.getcwd()
+
+
+if __name__ == '__main__':
+    print(f"Cleaning table in folder: {target_folder}")
+    error_folder = os.path.join(target_folder, "error_result")
+    if not os.path.exists(error_folder):
+        os.makedirs(error_folder)
+    cif_folder = os.path.join(target_folder, "cif_result_final")
+    csv_file = os.path.join(target_folder, "results_scheduler.csv")
+
+    if not os.path.exists(csv_file):
+        print(f"CSV file not found in {target_folder}, skipping...")
+        sys.exit(0)
+    if not os.path.exists(cif_folder):
+        print(f"CIF folder not found in {target_folder}, skipping...")
+        sys.exit(0)
+
+    all_crystals = pd.read_csv(csv_file)
+    cifs = os.listdir(cif_folder)
+    invalid_list = []
+    for i, item in all_crystals.iterrows():
+        if int(item['stage2_steps']) > 2990 or (item['stage1_steps'] >= 2999 and item["stage2_steps"] == 0) or abs(float(item['stage2_energy'])) > 1e15 or item['file']+".cif" not in cifs:
+            invalid_list.append(i)
+
+    if invalid_list:
+        invalid_df = all_crystals.iloc[invalid_list]
+        all_crystals.drop(index=invalid_list, inplace=True)
+
+    all_crystals.sort_values(by=["stage2_energy"], ascending=True, inplace=True)
+    all_crystals.reset_index(drop=True, inplace=True)
+
+    min_energy = all_crystals['stage2_energy'][0]
+    all_crystals['relative_energy'] = all_crystals['stage2_energy'] - min_energy
+
+    for cif in cifs:
+        if cif[:-4] not in all_crystals['file'].values:
+            # move cif to error folder
+            src_path = os.path.join(cif_folder, cif)
+            dest_path = os.path.join(error_folder, cif)
+            os.rename(src_path, dest_path)
+
+    # Save the cleaned DataFrame back to CSV
     all_crystals.to_csv(csv_file, index=False)
\ No newline at end of file
diff --git a/post_process/duplicate_remove.py b/post_process/duplicate_remove.py
index a8da37301c5635989e54632be77c9620f03f0fbc..a594adb15739525b595d08c3e0376a28ffc8f1d3 100644
--- a/post_process/duplicate_remove.py
+++ b/post_process/duplicate_remove.py
@@ -1,247 +1,247 @@
-import sys
-import os
-import glob
-import pandas as pd
-import warnings
-from pathlib import Path
-from tqdm import tqdm
-import multiprocessing as mp
-from ccdc.crystal import PackingSimilarity
-from ccdc.io import CrystalReader
-import shutil
-import argparse
-
-warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-#########################
-# Global Settings
-#########################
-
-# The number of matching molecules required in a packing shell to consider two crystals similar.
-molecule_shell_size = 15
-
-# Initialize the packing similarity engine from the CCDC API.
-similarity_engine = PackingSimilarity()
-
-# Configure the similarity engine settings.
-similarity_engine.settings.allow_molecular_differences = False
-similarity_engine.settings.distance_tolerance = 0.2
-similarity_engine.settings.angle_tolerance = 20
-similarity_engine.settings.packing_shell_size = molecule_shell_size
-
-# Settings to make the comparison less strict regarding hydrogen atoms and bond counts.
-similarity_engine.settings.ignore_hydrogen_positions = True
-similarity_engine.settings.ignore_bond_counts = True
-similarity_engine.settings.ignore_hydrogen_counts = True
-
-# Pre-filtering thresholds to avoid expensive comparisons for vastly different structures.
-ENERGY_DIFF = 5.0
-DENSITY_DIFF = 0.2
-
-
-def compare_sim_pack_pair(args):
-    """
-    Compares two crystal structures to determine if they are identical based on packing similarity.
-    Accepts a tuple (candidate_fp, ufp) containing file paths to the two structures.
-    Returns True if the structures are considered the same, False otherwise.
-    """
-    cand_fp, ufp = args
-    try:
-        c1 = CrystalReader(cand_fp)[0]
-        c2 = CrystalReader(ufp)[0]
-        h = similarity_engine.compare(c1, c2)
-        # Structures are deemed identical if the number of matched molecules meets the global shell size.
-        return h.nmatched_molecules >= molecule_shell_size
-    except Exception:
-        # Return False if any error occurs during file reading or comparison.
-        return False
-
-def find_top_unique(struct_list, target_count, n_workers):
-    """
-    Identifies a target number of unique structures from a list, sorted by energy.
-    It iterates through structures from lowest to highest energy and uses parallel processing
-    to compare each candidate against the list of already confirmed unique structures.
-
-    Args:
-        struct_list (list): A list of tuples, where each tuple is (density, energy, filepath).
-        target_count (int): The desired number of unique structures to find.
-        n_workers (int): The number of worker processes for parallel comparison.
-
-    Returns:
-        tuple: A tuple containing:
-            - A list of file paths for the unique structures found.
-            - A dictionary mapping each unique structure's path to a list of its duplicate paths.
-    """
-    unique_list = []
-
-    # A dictionary to store each unique structure and its corresponding duplicates.
-    # Format: {unique_fp: [dup_fp1, dup_fp2, ...]}
-    duplicate_map = {}
-
-    # Initialize the multiprocessing pool.
-    pool = mp.Pool(processes=n_workers)
-    try:
-        # Iterate through the main structure list, which is pre-sorted by energy.
-        for dens, ene, cand_fp in struct_list:
-            # Stop if the target count of unique structures has been reached.
-            if len(unique_list) >= target_count:
-                break
-
-            # Create comparison tasks only for unique structures within the density and energy thresholds.
-            tasks = [(cand_fp, ufp) for dens2, ene2, ufp in unique_list if abs(dens2 - dens) < DENSITY_DIFF and abs(ene2 - ene) < ENERGY_DIFF]
-            is_dup = False
-
-            if tasks:
-                # Use tqdm for a progress bar during parallel comparison.
-                results_iterator = tqdm(pool.imap(compare_sim_pack_pair, tasks),
-                                        total=len(tasks),
-                                        desc=f"Comparing {os.path.basename(cand_fp)}",
-                                        leave=False)
-                # Convert iterator to a list to process results.
-                results_list = list(results_iterator)
-
-                # Pair the boolean results with the original tasks to identify which comparison was successful.
-                for result, task in zip(results_list, tasks):
-                    if result:
-                        is_dup = True
-                        # The second element of the task tuple is the path of the matched unique structure.
-                        matched_unique_fp = task[1]
-
-                        # Record the current candidate as a duplicate of the matched unique structure.
-                        duplicate_map[matched_unique_fp].append(cand_fp)
-
-                        # Once a duplicate is found, no more comparisons are needed for this candidate.
-                        break
-
-            if not is_dup:
-                # If the candidate is not a duplicate of any existing unique structure, add it to the list.
-                unique_list.append((dens, ene, cand_fp))
-
-                # Create a new entry in the duplicate map for this new unique structure.
-                duplicate_map[cand_fp] = []
-
-                # Print real-time progress.
-                print(f">>> Unique count: {len(unique_list)} (added {os.path.basename(cand_fp)})")
-
-    finally:
-        # Ensure the multiprocessing pool is properly closed.
-        pool.close()
-        pool.join()
-
-    # Return the list of unique file paths and the map of duplicates.
-    return [fp for _, _, fp in unique_list], duplicate_map
-
-
-def process_folder(folder_path):
-    """
-    Main workflow function to process a given folder. It reads structural data,
-    finds unique structures, and organizes the results into folders and a CSV file.
-    """
-    base_path = folder_path
-
-    # Define output directories for unique structures and their duplicates.
-    t_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}")
-    d_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}_duplicates")
-
-    # Load the summary data from the CSV file.
-    df = pd.read_csv(os.path.join(base_path, "results_scheduler.csv"))
-    cif_folder = os.path.join(base_path, "cif_result_final")
-    #########################
-
-    # Create a map from a simplified filename to its full file path for quick lookup.
-    file_paths = glob.glob(os.path.join(cif_folder, "*.cif"))
-    file_map = {}
-    for fp in file_paths:
-        base = Path(fp).stem
-        # Standardize the name by removing "_opt" suffix if it exists.
-        if base.endswith("_opt"):
-            key = base[:-4]
-        else:
-            key = base
-        file_map[key] = fp
-
-    # Build the list of candidate structures, filtering by the energy threshold.
-    struct_list = []
-    for _, row in df.iterrows():
-        name = row["file"]
-        dens = row["stage2_density"]
-        ene = row["relative_energy"]
-        if ene > ENERGY_THRESHOLD:
-            continue
-        fp = file_map.get(name)
-        if fp:
-            struct_list.append((dens, ene, fp))
-        else:
-            print(f"Warning: no CIF for {name}")
-
-    # Sort the list of structures by energy in ascending order.
-    struct_list.sort(key=lambda x: x[1])
-
-    # Call the main function to find unique structures and the duplicate map.
-    unique_paths, duplicate_map = find_top_unique(
-        struct_list,
-        target_count=TARGET_UNIQUE_COUNT,
-        n_workers=WORKER_COUNT
-    )
-
-    # Prepare output directories, removing them first if they already exist.
-    if os.path.exists(t_folder):
-        shutil.rmtree(t_folder)
-    os.makedirs(t_folder)
-
-    if os.path.exists(d_folder):
-        shutil.rmtree(d_folder)
-    os.makedirs(d_folder)
-
-    print(f"\nFound {len(unique_paths)} unique structures "
-          f"(energy <= {ENERGY_THRESHOLD}, target {TARGET_UNIQUE_COUNT}).")
-
-    # Copy the unique structure files to the target folder.
-    for p in unique_paths:
-        shutil.copy(p, t_folder)
-
-    # Helper function to get a clean filename without the "_opt" suffix.
-    get_clean_name = lambda p: Path(p).stem.replace('_opt', '')
-
-    # Create a new DataFrame containing only the data for the unique structures.
-    unique_names = [get_clean_name(p) for p in unique_paths]
-    unique_df = df[df['file'].isin(unique_names)].copy()
-    unique_df['duplicates'] = '' # Add a new column to store duplicate names.
-
-    # Populate the 'duplicates' column and copy duplicate files to their folder.
-    for unique_fp, duplicates_list in duplicate_map.items():
-        unique_name = get_clean_name(unique_fp)
-        if unique_name in unique_df['file'].values:
-            du_name = [get_clean_name(p) for p in duplicates_list]
-            # Add the list of duplicate names as a comma-separated string.
-            unique_df.loc[unique_df['file'] == unique_name, 'duplicates'] = ', '.join(du_name)
-            # Copy duplicate files to the duplicates folder.
-            for p in duplicates_list:
-                shutil.copy(p, d_folder)
-
-    # Save the final DataFrame with unique structures and their duplicates to a new CSV file.
-    unique_csv_path = os.path.join(base_path, 'unique_structures.csv')
-    unique_df.to_csv(unique_csv_path, index=False)
-
-if __name__ == '__main__':
-    # Required for freezing the application when creating executables with multiprocessing.
-    mp.freeze_support()
-
-    # Set up command-line argument parsing.
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--path', type=str, default="./", help='Path to process')
-    parser.add_argument('--energy', type=float, default=30, help='Energy threshold for filtering structures')
-    parser.add_argument('--count', type=int, default=200, help='Target number of unique structures to find')
-    parser.add_argument('--workers', type=int, default=80, help='Max worker number limit')
-    args = parser.parse_args()
-
-    # Set global variables from command-line arguments.
-    target_folder = args.path
-    ENERGY_THRESHOLD = args.energy
-    TARGET_UNIQUE_COUNT = args.count
-    WORKER_COUNT = args.workers
-
-    # Start the processing if the specified folder exists.
-    if os.path.exists(target_folder):
-        print(f"Processing folder: {target_folder}")
+import sys
+import os
+import glob
+import pandas as pd
+import warnings
+from pathlib import Path
+from tqdm import tqdm
+import multiprocessing as mp
+from ccdc.crystal import PackingSimilarity
+from ccdc.io import CrystalReader
+import shutil
+import argparse
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+#########################
+# Global Settings
+#########################
+
+# The number of matching molecules required in a packing shell to consider two crystals similar.
+molecule_shell_size = 15
+
+# Initialize the packing similarity engine from the CCDC API.
+similarity_engine = PackingSimilarity()
+
+# Configure the similarity engine settings.
+similarity_engine.settings.allow_molecular_differences = False
+similarity_engine.settings.distance_tolerance = 0.2
+similarity_engine.settings.angle_tolerance = 20
+similarity_engine.settings.packing_shell_size = molecule_shell_size
+
+# Settings to make the comparison less strict regarding hydrogen atoms and bond counts.
+similarity_engine.settings.ignore_hydrogen_positions = True
+similarity_engine.settings.ignore_bond_counts = True
+similarity_engine.settings.ignore_hydrogen_counts = True
+
+# Pre-filtering thresholds to avoid expensive comparisons for vastly different structures.
+ENERGY_DIFF = 5.0
+DENSITY_DIFF = 0.2
+
+
+def compare_sim_pack_pair(args):
+    """
+    Compares two crystal structures to determine if they are identical based on packing similarity.
+    Accepts a tuple (candidate_fp, ufp) containing file paths to the two structures.
+    Returns True if the structures are considered the same, False otherwise.
+    """
+    cand_fp, ufp = args
+    try:
+        c1 = CrystalReader(cand_fp)[0]
+        c2 = CrystalReader(ufp)[0]
+        h = similarity_engine.compare(c1, c2)
+        # Structures are deemed identical if the number of matched molecules meets the global shell size.
+        return h.nmatched_molecules >= molecule_shell_size
+    except Exception:
+        # Return False if any error occurs during file reading or comparison.
+        return False
+
+def find_top_unique(struct_list, target_count, n_workers):
+    """
+    Identifies a target number of unique structures from a list, sorted by energy.
+    It iterates through structures from lowest to highest energy and uses parallel processing
+    to compare each candidate against the list of already confirmed unique structures.
+
+    Args:
+        struct_list (list): A list of tuples, where each tuple is (density, energy, filepath).
+        target_count (int): The desired number of unique structures to find.
+        n_workers (int): The number of worker processes for parallel comparison.
+
+    Returns:
+        tuple: A tuple containing:
+            - A list of file paths for the unique structures found.
+            - A dictionary mapping each unique structure's path to a list of its duplicate paths.
+    """
+    unique_list = []
+
+    # A dictionary to store each unique structure and its corresponding duplicates.
+    # Format: {unique_fp: [dup_fp1, dup_fp2, ...]}
+    duplicate_map = {}
+
+    # Initialize the multiprocessing pool.
+    pool = mp.Pool(processes=n_workers)
+    try:
+        # Iterate through the main structure list, which is pre-sorted by energy.
+        for dens, ene, cand_fp in struct_list:
+            # Stop if the target count of unique structures has been reached.
+            if len(unique_list) >= target_count:
+                break
+
+            # Create comparison tasks only for unique structures within the density and energy thresholds.
+            tasks = [(cand_fp, ufp) for dens2, ene2, ufp in unique_list if abs(dens2 - dens) < DENSITY_DIFF and abs(ene2 - ene) < ENERGY_DIFF]
+            is_dup = False
+
+            if tasks:
+                # Use tqdm for a progress bar during parallel comparison.
+                results_iterator = tqdm(pool.imap(compare_sim_pack_pair, tasks),
+                                        total=len(tasks),
+                                        desc=f"Comparing {os.path.basename(cand_fp)}",
+                                        leave=False)
+                # Convert iterator to a list to process results.
+                results_list = list(results_iterator)
+
+                # Pair the boolean results with the original tasks to identify which comparison was successful.
+                for result, task in zip(results_list, tasks):
+                    if result:
+                        is_dup = True
+                        # The second element of the task tuple is the path of the matched unique structure.
+                        matched_unique_fp = task[1]
+
+                        # Record the current candidate as a duplicate of the matched unique structure.
+                        duplicate_map[matched_unique_fp].append(cand_fp)
+
+                        # Once a duplicate is found, no more comparisons are needed for this candidate.
+                        break
+
+            if not is_dup:
+                # If the candidate is not a duplicate of any existing unique structure, add it to the list.
+                unique_list.append((dens, ene, cand_fp))
+
+                # Create a new entry in the duplicate map for this new unique structure.
+                duplicate_map[cand_fp] = []
+
+                # Print real-time progress.
+                print(f">>> Unique count: {len(unique_list)} (added {os.path.basename(cand_fp)})")
+
+    finally:
+        # Ensure the multiprocessing pool is properly closed.
+        pool.close()
+        pool.join()
+
+    # Return the list of unique file paths and the map of duplicates.
+    return [fp for _, _, fp in unique_list], duplicate_map
+
+
+def process_folder(folder_path):
+    """
+    Main workflow function to process a given folder. It reads structural data,
+    finds unique structures, and organizes the results into folders and a CSV file.
+    """
+    base_path = folder_path
+
+    # Define output directories for unique structures and their duplicates.
+    t_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}")
+    d_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}_duplicates")
+
+    # Load the summary data from the CSV file.
+    df = pd.read_csv(os.path.join(base_path, "results_scheduler.csv"))
+    cif_folder = os.path.join(base_path, "cif_result_final")
+    #########################
+
+    # Create a map from a simplified filename to its full file path for quick lookup.
+    file_paths = glob.glob(os.path.join(cif_folder, "*.cif"))
+    file_map = {}
+    for fp in file_paths:
+        base = Path(fp).stem
+        # Standardize the name by removing "_opt" suffix if it exists.
+        if base.endswith("_opt"):
+            key = base[:-4]
+        else:
+            key = base
+        file_map[key] = fp
+
+    # Build the list of candidate structures, filtering by the energy threshold.
+    struct_list = []
+    for _, row in df.iterrows():
+        name = row["file"]
+        dens = row["stage2_density"]
+        ene = row["relative_energy"]
+        if ene > ENERGY_THRESHOLD:
+            continue
+        fp = file_map.get(name)
+        if fp:
+            struct_list.append((dens, ene, fp))
+        else:
+            print(f"Warning: no CIF for {name}")
+
+    # Sort the list of structures by energy in ascending order.
+    struct_list.sort(key=lambda x: x[1])
+
+    # Call the main function to find unique structures and the duplicate map.
+    unique_paths, duplicate_map = find_top_unique(
+        struct_list,
+        target_count=TARGET_UNIQUE_COUNT,
+        n_workers=WORKER_COUNT
+    )
+
+    # Prepare output directories, removing them first if they already exist.
+    if os.path.exists(t_folder):
+        shutil.rmtree(t_folder)
+    os.makedirs(t_folder)
+
+    if os.path.exists(d_folder):
+        shutil.rmtree(d_folder)
+    os.makedirs(d_folder)
+
+    print(f"\nFound {len(unique_paths)} unique structures "
+          f"(energy <= {ENERGY_THRESHOLD}, target {TARGET_UNIQUE_COUNT}).")
+
+    # Copy the unique structure files to the target folder.
+    for p in unique_paths:
+        shutil.copy(p, t_folder)
+
+    # Helper function to get a clean filename without the "_opt" suffix.
+    get_clean_name = lambda p: Path(p).stem.replace('_opt', '')
+
+    # Create a new DataFrame containing only the data for the unique structures.
+    unique_names = [get_clean_name(p) for p in unique_paths]
+    unique_df = df[df['file'].isin(unique_names)].copy()
+    unique_df['duplicates'] = '' # Add a new column to store duplicate names.
+
+    # Populate the 'duplicates' column and copy duplicate files to their folder.
+    for unique_fp, duplicates_list in duplicate_map.items():
+        unique_name = get_clean_name(unique_fp)
+        if unique_name in unique_df['file'].values:
+            du_name = [get_clean_name(p) for p in duplicates_list]
+            # Add the list of duplicate names as a comma-separated string.
+            unique_df.loc[unique_df['file'] == unique_name, 'duplicates'] = ', '.join(du_name)
+            # Copy duplicate files to the duplicates folder.
+            for p in duplicates_list:
+                shutil.copy(p, d_folder)
+
+    # Save the final DataFrame with unique structures and their duplicates to a new CSV file.
+    unique_csv_path = os.path.join(base_path, 'unique_structures.csv')
+    unique_df.to_csv(unique_csv_path, index=False)
+
+if __name__ == '__main__':
+    # Required for freezing the application when creating executables with multiprocessing.
+    mp.freeze_support()
+
+    # Set up command-line argument parsing.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, default="./", help='Path to process')
+    parser.add_argument('--energy', type=float, default=30, help='Energy threshold for filtering structures')
+    parser.add_argument('--count', type=int, default=200, help='Target number of unique structures to find')
+    parser.add_argument('--workers', type=int, default=80, help='Max worker number limit')
+    args = parser.parse_args()
+
+    # Set global variables from command-line arguments.
+    target_folder = args.path
+    ENERGY_THRESHOLD = args.energy
+    TARGET_UNIQUE_COUNT = args.count
+    WORKER_COUNT = args.workers
+
+    # Start the processing if the specified folder exists.
+    if os.path.exists(target_folder):
+        print(f"Processing folder: {target_folder}")
         process_folder(target_folder)
\ No newline at end of file
diff --git a/post_process/run_remove.sh b/post_process/run_remove.sh
index 8fc67772f7cb65fd9f7734d94f171a2617f6ebf3..43f555c67664d77dd9a5e4e911daf09b843f3859 100644
--- a/post_process/run_remove.sh
+++ b/post_process/run_remove.sh
@@ -1,41 +1,40 @@
-#!/bin/sh -l
-#An example for GPU job.
-#SBATCH -D ./
-#SBATCH --export=ALL
-#SBATCH -J csp_test
-#SBATCH -o job-%j.log
-#SBATCH -e job-%j.err
-#SBATCH -p GPU-8A100
-#SBATCH -N 1 -n 8
-#SBATCH --gres=gpu:1
-#SBATCH --qos=gpu_8a100
-#SBATCH -t 3-00:00:00
-
-
-echo =========================================================
-echo SLURM job: submitted date = `date`
-date_start=`date +%s`
-
-echo =========================================================
-echo Job output begins
-echo -----------------
-echo
-
-
-python duplicate_remove.py
-
-
-echo
-echo ---------------
-echo Job output ends
-date_end=`date +%s`
-seconds=$((date_end-date_start))
-minutes=$((seconds/60))
-seconds=$((seconds-60*minutes))
-hours=$((minutes/60))
-minutes=$((minutes-60*hours))
-echo =========================================================
-echo SLURM job: finished date = `date`
-echo Total run time : $hours Hours $minutes Minutes $seconds Seconds
-echo =========================================================
-
+#An example for GPU job.
+#SBATCH -D ./
+#SBATCH --export=ALL
+#SBATCH -J csp_test
+#SBATCH -o job-%j.log
+#SBATCH -e job-%j.err
+#SBATCH -p GPU-8A100
+#SBATCH -N 1 -n 8
+#SBATCH --gres=gpu:1
+#SBATCH --qos=gpu_8a100
+#SBATCH -t 3-00:00:00
+
+
+echo =========================================================
+echo SLURM job: submitted date = `date`
+date_start=`date +%s`
+
+echo =========================================================
+echo Job output begins
+echo -----------------
+echo
+
+
+python duplicate_remove.py
+
+
+echo
+echo ---------------
+echo Job output ends
+date_end=`date +%s`
+seconds=$((date_end-date_start))
+minutes=$((seconds/60))
+seconds=$((seconds-60*minutes))
+hours=$((minutes/60))
+minutes=$((minutes-60*hours))
+echo =========================================================
+echo SLURM job: finished date = `date`
+echo Total run time : $hours Hours $minutes Minutes $seconds Seconds
+echo =========================================================
+
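For orientation, the post-processing scripts shown in the diffs above are normally chained after a CSP run has produced `results_scheduler.csv` and `cif_result_final/`. A minimal sketch of that sequence (the paths are illustrative, not fixed by the repository; `check_match.py` and `duplicate_remove.py` additionally require a licensed CCDC Python API (`ccdc`) installation):

```sh
# run from the folder that holds results_scheduler.csv and cif_result_final/
cd /path/to/csp_run

# 1) drop unconverged/broken entries and add the relative_energy column
python /path/to/BOMLIP-CSP/post_process/clean_table.py

# 2) keep the lowest-energy unique structures (needs relative_energy from step 1)
python /path/to/BOMLIP-CSP/post_process/duplicate_remove.py --path ./ --energy 30 --count 200 --workers 16

# 3) compare the survivors against experimental reference structures
python /path/to/BOMLIP-CSP/post_process/check_match.py --path ./ --ref_path /path/to/refs --workers 16 --timeout 20
```

`run_remove.sh` is a SLURM wrapper around step 2 (it calls `python duplicate_remove.py` with default arguments), so it expects to be submitted via `sbatch` from a directory containing both `duplicate_remove.py` and the result files, or edited to use absolute paths; the partition, QOS, and GPU directives in its `#SBATCH` header will also need to match your cluster.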