Unverified Commit 66462d91 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent b95dac35
"""
This module defines the core data structures for representing atomic structures:
Atom, Crystal, and Molecule.
These classes store information about atomic coordinates, lattice parameters,
and other physical properties, providing a foundational toolkit for geometric
and structural analysis in materials science simulations.
"""
import copy
from typing import List, Tuple, Union, Any, Optional, Dict
import numpy as np
import fractions
import re
from scipy.spatial.distance import cdist
from tqdm import tqdm
from basic_function import unit_cell_parser
from basic_function import chemical_knowledge
from basic_function import operation
class Atom:
    """
    A single atom of a chemical structure.

    Every attribute is filled from the constructor keyword arguments; any
    attribute that is not supplied holds the placeholder string "unknown"
    (or an empty container for `bonded_atom` / `comment`).

    Attributes:
        element (str): Chemical symbol of the atom (e.g., 'H', 'C', 'O').
        cart_xyz (np.ndarray): Cartesian coordinates [x, y, z] in Angstroms.
        frac_xyz (np.ndarray): Fractional coordinates [u, v, w] relative to a lattice.
        atom_id (int): Identifier of the atom inside a larger structure.
        force (np.ndarray): Force vector [fx, fy, fz] acting on the atom.
        atom_charge (float): Partial charge of the atom.
        atom_energy (float): Site potential energy of the atom.
        molecule (int): Identifier of the molecule the atom belongs to.
        bonded_atom (list): IDs of the atoms bonded to this one.
        descriptor (any): Slot for feature vectors or other descriptors.
        comment (dict): Free-form metadata.
    """

    def __init__(self, **kwargs: Any):
        """
        Build an Atom from keyword arguments.

        Args:
            **kwargs: Attribute values. 'element' plus one of 'cart_xyz' /
                'frac_xyz' are expected by `check()`. Optional keys are
                'atom_id', 'force_xyz' (stored as `self.force`),
                'atom_charge', 'atom_energy', 'molecule', 'bonded_atom',
                'descriptor' and 'comment'. Other keys are ignored.
        """
        fetch = kwargs.get
        # Assignment order is kept stable because info() prints in this order.
        self.element: str = fetch("element", "unknown")
        self.cart_xyz: Union[str, np.ndarray] = fetch('cart_xyz', "unknown")
        self.frac_xyz: Union[str, np.ndarray] = fetch('frac_xyz', "unknown")
        self.atom_id: Union[str, int] = fetch("atom_id", "unknown")
        # The keyword is 'force_xyz' although the attribute is named 'force'.
        self.force: Union[str, np.ndarray] = fetch('force_xyz', 'unknown')
        self.atom_charge: Union[str, float] = fetch('atom_charge', 'unknown')
        self.atom_energy: Union[str, float] = fetch('atom_energy', 'unknown')
        self.molecule: Union[str, int] = fetch('molecule', 'unknown')
        # Fresh containers are created on every call, so they are never shared.
        self.bonded_atom: list = fetch('bonded_atom', [])
        self.descriptor: Any = fetch("descriptor", "unknown")
        self.comment: dict = fetch("comment", {})

    def info(self) -> None:
        """Print every attribute of the atom as `name: value`, one per line."""
        for name, value in vars(self).items():
            print(f"{name}: {value}")

    def check(self) -> None:
        """
        Run basic sanity checks on the atom.

        Raises:
            AssertionError: If the element is still "unknown", or if neither
                the cartesian nor the fractional coordinates are a list or
                numpy array.
        """
        coordinate_types = (np.ndarray, list)
        assert self.element != "unknown", "Atom must have an element type."
        coordinates_given = [
            isinstance(xyz, coordinate_types)
            for xyz in (self.cart_xyz, self.frac_xyz)
        ]
        assert any(coordinates_given), "Atom needs either cart_xyz or frac_xyz."
class Crystal:
    """
    Represents a periodic crystal structure.

    Holds the lattice (cell parameters and cell vectors) together with the
    list of atoms inside the unit cell and optional metadata. Derived
    quantities (`volume`, `density`) are computed on construction.
    """

    def __init__(self, **kwargs: Any):
        """
        Initializes a Crystal object.

        Args:
            **kwargs: Keyword arguments to set crystal attributes.
                Required: 'atoms' and one of 'cell_vect' or 'cell_para'.
                    - cell_vect (list): 3x3 list or array of lattice vectors.
                    - cell_para (list): [[a, b, c], [alpha, beta, gamma]].
                    - atoms (List[Atom]): A list of Atom objects.
                Optional: 'energy', 'comment', 'descriptor', 'molecule_number',
                'system_name', 'virial', 'SYMM' (list of symmetry operation
                strings, default ["x,y,z"]) and 'space_group' (default 1).

        Raises:
            ValueError: If no lattice is given, or an atom has no coordinates.
            AssertionError: If both lattice descriptions are given but disagree.
        """
        self.cell_vect: Union[str, np.ndarray] = kwargs.get("cell_vect", "unknown")
        self.cell_para: Union[str, list] = kwargs.get("cell_para", "unknown")
        self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown")
        self.energy: Union[str, float] = kwargs.get("energy", "unknown")
        self.comment: Any = kwargs.get("comment", "unknown")
        self.descriptor: Any = kwargs.get("descriptor", "unknown")
        self.molecule_number: Union[str, int] = kwargs.get("molecule_number", "unknown")
        self.system_name: str = kwargs.get("system_name", "unknown")
        self.virial: Any = kwargs.get("virial", "unknown")
        self.SYMM: list = kwargs.get("SYMM", ["x,y,z"])
        self.space_group: int = kwargs.get("space_group", 1)
        self.other_properties: dict = {}
        # Fill in whichever lattice / coordinate representation is missing.
        self.lattice_and_atom_complete()

    def lattice_and_atom_complete(self) -> None:
        """
        Completes the lattice and atom data and recomputes derived properties.

        The missing lattice representation is derived from the one provided,
        each atom gets whichever of cartesian / fractional coordinates it
        lacks, and `volume` / `density` are recalculated.

        Raises:
            ValueError: If neither 'cell_vect' nor 'cell_para' is available, or
                an atom has neither coordinate set.
            AssertionError: If both lattice descriptions are present but differ
                by more than 1e-3.
        """
        # --- 1. Finalize Lattice Representation ---
        has_vect = isinstance(self.cell_vect, (np.ndarray, list))
        has_para = isinstance(self.cell_para, (np.ndarray, list))
        if has_vect and has_para:
            # Both supplied: they must describe the same cell.
            derived_para = np.array(unit_cell_parser.cell_vect_to_para(self.cell_vect)).flatten()
            provided_para = np.array(self.cell_para).flatten()
            assert np.allclose(derived_para, provided_para, atol=1e-3), \
                "Provided cell_para and cell_vect are inconsistent."
        elif has_vect and not has_para:
            self.cell_para = unit_cell_parser.cell_vect_to_para(self.cell_vect)
        elif not has_vect and has_para:
            self.cell_vect = unit_cell_parser.cell_para_to_vect(self.cell_para)
        else:
            raise ValueError("Crystal lattice is not defined. Provide 'cell_vect' or 'cell_para'.")

        # --- 2. Finalize Atom Coordinates ---
        if self.atoms == "unknown" or not self.atoms:
            print("Warning: Crystal initialized with no atoms.")
            self.atoms = []
        else:
            for atom in self.atoms:
                has_cart = isinstance(atom.cart_xyz, (np.ndarray, list))
                has_frac = isinstance(atom.frac_xyz, (np.ndarray, list))
                if not (has_cart and has_frac):
                    if has_frac:
                        atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect)
                    elif has_cart:
                        atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect)
                    else:
                        raise ValueError(f"Atom {atom.element} {atom.atom_id} has no coordinate information.")

        # --- 3. Calculate Derived Properties ---
        self._refresh_volume_and_density()

    def _refresh_volume_and_density(self) -> None:
        """Recomputes `self.volume` and `self.density` from the current cell and atoms."""
        self.volume = unit_cell_parser.calculate_volume(self.cell_para)
        if self.atoms:
            total_mass = sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms)
            # Density in g/cm^3: 1e-24 converts A^3 to cm^3 and the last factor
            # is Avogadro's number (assumes element_masses is in g/mol - TODO confirm).
            self.density = total_mass / (self.volume * 1e-24) / (6.022140857e23)
        else:
            self.density = 0.0

    def update_cart_by_frac(self) -> None:
        """Updates all atom cartesian coordinates from their fractional coordinates."""
        for atom in self.atoms:
            atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect)

    def update_frac_by_cart(self) -> None:
        """Updates all atom fractional coordinates from their cartesian coordinates."""
        for atom in self.atoms:
            atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect)

    def check(self) -> None:
        """
        Performs consistency checks on the crystal structure.

        Raises:
            AssertionError: If the lattice descriptions disagree, or an atom's
                cartesian and fractional coordinates differ by more than 1e-3.
            ValueError: Propagated from `lattice_and_atom_complete`.
        """
        print("Performing consistency checks...")
        # Check lattice consistency (also fills in any missing coordinates)
        self.lattice_and_atom_complete()
        # Check atom coordinate consistency
        for atom in self.atoms:
            derived_cart = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect)
            assert np.allclose(atom.cart_xyz, derived_cart, atol=1e-3), \
                f"Atom {atom.atom_id} cartesian and fractional coordinates do not match."
        # Check atom IDs
        if all(atom.atom_id != "unknown" for atom in self.atoms):
            print("All atoms have IDs.")
        else:
            print("Warning: Not all atoms have IDs. Use .give_atom_id_forced() to assign them.")
        print("Checks passed.")

    def give_atom_id_forced(self) -> None:
        """Assigns or resets atom IDs from 0 to N-1 and clears bonding info."""
        print("Warning: Resetting all atom IDs and bonding information!")
        for i, atom in enumerate(self.atoms):
            atom.atom_id = i
            atom.bonded_atom = []

    def move_atom_into_cell(self) -> None:
        """
        Wraps the fractional coordinates of every atom with modulo 1 and then
        refreshes the cartesian coordinates.
        """
        for atom in self.atoms:
            atom.frac_xyz = np.mod(atom.frac_xyz, 1.0)
        self.update_cart_by_frac()

    def find_molecule(self, tolerance: float = 1.15) -> None:
        """
        Identifies molecules within the crystal based on bonding.

        Atoms are first wrapped into the cell, then a breadth-first search
        groups atoms that `operation.is_bonding_crystal` reports as bonded.
        The result is written to `atom.molecule` (IDs start at 1) and
        `self.molecule_number` (number of molecules found).

        Args:
            tolerance: Passed through to `operation.is_bonding_crystal`;
                per the original documentation it scales the summed covalent
                radii used as bonding threshold.
        """
        self.move_atom_into_cell()
        atoms_to_visit = list(range(len(self.atoms)))
        molecule_id = 0
        while atoms_to_visit:
            molecule_id += 1
            # Start a BFS from the first atom that has no molecule yet.
            q = [atoms_to_visit[0]]
            visited_in_molecule = {atoms_to_visit[0]}
            head = 0  # read position in q; the list is never popped
            while head < len(q):
                current_atom_idx = q[head]
                head += 1
                self.atoms[current_atom_idx].molecule = molecule_id
                # Test the current atom against every other atom.
                for other_atom_idx in range(len(self.atoms)):
                    if current_atom_idx == other_atom_idx:
                        continue
                    # Original author's note: is_bonding_crystal handles periodic boundaries.
                    is_bonded, _ = operation.is_bonding_crystal(
                        self.atoms[current_atom_idx],
                        self.atoms[other_atom_idx],
                        self.cell_vect,
                        tolerance=tolerance,
                        update_atom2=False  # Do not modify coordinates during search
                    )
                    if is_bonded and other_atom_idx not in visited_in_molecule:
                        visited_in_molecule.add(other_atom_idx)
                        q.append(other_atom_idx)
            # Drop every atom of the finished molecule from the to-do list.
            atoms_to_visit = [idx for idx in atoms_to_visit if idx not in visited_in_molecule]
        self.molecule_number = molecule_id

    def get_element(self) -> List[str]:
        """Returns the unique element symbols, ordered by `chemical_knowledge.sort_by_atomic_number`."""
        return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms))

    def get_element_amount(self) -> List[int]:
        """Returns the count of each element, in the same order as `get_element()`."""
        all_elements = [atom.element for atom in self.atoms]
        return [all_elements.count(element) for element in self.get_element()]

    @staticmethod
    def _parse_symmetry_operation(sym_opt: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Converts one symmetry string such as "-x,y+1/2,-z" into (M, C).

        M is a 3x3 matrix and C a 1x3 translation so that a row vector of
        fractional coordinates `f` maps to `f @ M + C`.

        Raises:
            Exception: If a fragment is neither an axis term nor a number.
        """
        axis_terms = {
            'x': (0, 1), '+x': (0, 1), '-x': (0, -1),
            'y': (1, 1), '+y': (1, 1), '-y': (1, -1),
            'z': (2, 1), '+z': (2, 1), '-z': (2, -1),
        }
        matrix_M = np.zeros((3, 3))
        matrix_C = np.zeros((1, 3))
        components = sym_opt.lower().replace(" ", "").split(",")
        for idx, word in enumerate(components):
            # Split e.g. "-x+1/2" into the signed fragments "-x" and "+1/2".
            for sym_opt_frag in re.findall(r".*?([+-]*[xyz0-9\/\.]+)", word):
                if sym_opt_frag in axis_terms:
                    axis, sign = axis_terms[sym_opt_frag]
                    matrix_M[axis][idx] = sign
                elif operation.is_number(sym_opt_frag) is True:
                    matrix_C[0][idx] = float(fractions.Fraction(sym_opt_frag))
                else:
                    raise Exception("wrong sym opt of" + sym_opt_frag)
        return matrix_M, matrix_C

    def make_p1(self) -> None:
        """
        Expands the current atoms to the full P1 cell using `self.SYMM`.

        Every symmetry operation is applied to every atom; the images replace
        `self.atoms` (only element, fractional coordinates and a new running
        `atom_id` are kept). Afterwards the space group is 1, `SYMM` is the
        identity list ["x,y,z"], cartesian coordinates are rebuilt and
        volume / density are refreshed for the new atom list.

        Raises:
            Exception: If a symmetry operation string cannot be interpreted.
        """
        all_ele, all_frac = self.get_ele_and_frac()
        symmetry_images = []
        for sym_opt in self.SYMM:
            matrix_M, matrix_C = self._parse_symmetry_operation(sym_opt)
            # (n, 3) @ (3, 3) + (1, 3) -> (n, 3); stays 2-D even for one atom.
            symmetry_images.append(np.dot(all_frac, matrix_M) + matrix_C)
        # Images are ordered operation by operation, matching the repeated element list.
        expanded_frac = np.array(symmetry_images).reshape(-1, 3)
        expanded_ele = all_ele * len(self.SYMM)
        self.atoms = [
            Atom(element=element, frac_xyz=frac_xyz, atom_id=idx)
            for idx, (element, frac_xyz) in enumerate(zip(expanded_ele, expanded_frac))
        ]
        # Must stay a list of operation strings: other code iterates over SYMM.
        self.SYMM = ["x,y,z"]
        self.space_group = 1
        self.update_cart_by_frac()
        # The atom list changed, so the stored density would otherwise be stale.
        self._refresh_volume_and_density()

    def sort_by_element(self) -> None:
        """Sorts the atoms list in place by `chemical_knowledge.periodic_table_list[element]`."""
        self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element])

    def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]:
        """Returns all element symbols and their cartesian coordinates as an (N, 3) array."""
        if not self.atoms:
            return [], np.empty((0, 3))
        all_ele = [atom.element for atom in self.atoms]
        all_carts = np.array([atom.cart_xyz for atom in self.atoms])
        return all_ele, all_carts

    def get_ele_and_frac(self) -> Tuple[List[str], np.ndarray]:
        """Returns all element symbols and their fractional coordinates as an (N, 3) array."""
        if not self.atoms:
            return [], np.empty((0, 3))
        all_ele = [atom.element for atom in self.atoms]
        all_fracs = np.array([atom.frac_xyz for atom in self.atoms])
        return all_ele, all_fracs

    def info(self, all_info: bool = False) -> None:
        """
        Prints a formatted summary of the crystal structure.

        Args:
            all_info: If True, prints a table with atom ID, element,
                fractional and cartesian coordinates; otherwise only element
                and cartesian coordinates.
        """
        print("--- Crystal System ---")
        print(f"Name: {self.system_name}")
        print("Lattice Vectors (Angstrom):")
        for vec in self.cell_vect:
            print(f"{vec[0]:16.8f} {vec[1]:16.8f} {vec[2]:16.8f}")
        print("Lattice Parameters:")
        print(f"a, b, c (A): {self.cell_para[0][0]:.4f}, {self.cell_para[0][1]:.4f}, {self.cell_para[0][2]:.4f}")
        print(f"alpha, beta, gamma (deg): {self.cell_para[1][0]:.4f}, {self.cell_para[1][1]:.4f}, {self.cell_para[1][2]:.4f}")
        print(f"Volume (A^3): {self.volume:.4f} | Density (g/cm^3): {self.density:.4f}")
        print(f"\n--- Atomic Coordinates (Total: {len(self.atoms)}) ---")
        if not all_info:
            print(f"{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}")
            print("-" * 58)
            for atom in self.atoms:
                print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}")
        else:
            header = (
                f"{'ID':<5} {'Elem':<6} "
                f"{'Frac X':>10} {'Frac Y':>10} {'Frac Z':>10} | "
                f"{'Cart X':>12} {'Cart Y':>12} {'Cart Z':>12}"
            )
            print(header)
            print("-" * len(header))
            for atom in self.atoms:
                # Atoms without an ID are shown with a dash.
                aid = str(atom.atom_id) if atom.atom_id != 'unknown' else '-'
                print(
                    f"{aid:<5} {atom.element:<6} "
                    f"{atom.frac_xyz[0]:10.6f} {atom.frac_xyz[1]:10.6f} {atom.frac_xyz[2]:10.6f} | "
                    f"{atom.cart_xyz[0]:12.6f} {atom.cart_xyz[1]:12.6f} {atom.cart_xyz[2]:12.6f}"
                )
        print("\n--- Other Properties ---")
        print(f"Energy: {self.energy}")
        print(f"Comment: {self.comment}")
        print(f"Virial: {self.virial}")
class Molecule:
    """Represents a non-periodic molecule (a collection of atoms)."""

    def __init__(self, **kwargs: Any):
        """
        Initializes a Molecule object.

        Args:
            **kwargs: Keyword arguments to set molecule attributes.
                Required: 'atoms' (List[Atom]); when omitted a warning is
                printed and the molecule starts with an empty atom list.
                Optional: 'energy', 'comment', 'descriptor', 'name',
                'system_name' (each defaults to "unknown").
        """
        self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown")
        self.energy: Union[str, float] = kwargs.get("energy", "unknown")
        self.comment: Any = kwargs.get("comment", "unknown")
        self.descriptor: Any = kwargs.get("descriptor", "unknown")
        self.name: str = kwargs.get("name", "unknown")
        self.system_name: str = kwargs.get("system_name", "unknown")
        if self.atoms == "unknown":
            print("Warning: Molecule initialized with no atoms.")
            self.atoms = []

    def give_atom_id_forced(self) -> None:
        """Assigns or resets atom IDs from 0 to N-1 and clears bonding info."""
        print("Warning: Resetting all atom IDs and bonding information!")
        for i, atom in enumerate(self.atoms):
            atom.atom_id = i
            atom.bonded_atom = []

    def get_element(self) -> List[str]:
        """Returns the unique element symbols, ordered by `chemical_knowledge.sort_by_atomic_number`."""
        if not self.atoms:
            return []
        return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms))

    def get_element_amount(self) -> List[int]:
        """Returns the count of each element, in the same order as `get_element()`."""
        if not self.atoms:
            return []
        all_elements = [atom.element for atom in self.atoms]
        return [all_elements.count(element) for element in self.get_element()]

    def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]:
        """Returns all element symbols and their cartesian coordinates as an (N, 3) array."""
        if not self.atoms:
            return [], np.empty((0, 3))
        all_ele = [atom.element for atom in self.atoms]
        all_carts = np.array([atom.cart_xyz for atom in self.atoms])
        return all_ele, all_carts

    def put_ele_cart_back(self, all_ele: List[str], all_carts: np.ndarray) -> None:
        """
        Overwrites element and cartesian coordinates of the existing atoms.

        Args:
            all_ele: Element symbol for each atom, in atom order.
            all_carts: Cartesian coordinates for each atom, in atom order.

        Raises:
            IndexError: If either sequence is shorter than the atom list.
        """
        for i, atom in enumerate(self.atoms):
            atom.element = all_ele[i]
            atom.cart_xyz = all_carts[i]

    def build_molecules_by_ele_cart(self, all_ele: List[str], all_carts: np.ndarray) -> None:
        """
        Replaces the atom list with new atoms built from elements and coordinates.

        Atom IDs are assigned 0..N-1 in input order.

        Raises:
            AssertionError: If the two inputs differ in length.
        """
        assert len(all_ele) == len(all_carts), "Element and coordinate lists must have the same length."
        self.atoms = [
            Atom(element=ele, cart_xyz=cart, atom_id=i)
            for i, (ele, cart) in enumerate(zip(all_ele, all_carts))
        ]

    def get_mass(self) -> float:
        """Returns the sum of `chemical_knowledge.element_masses` over all atoms (0.0 if empty)."""
        if not self.atoms:
            return 0.0
        return sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms)

    def get_center_of_mass(self) -> np.ndarray:
        """Returns the mass-weighted mean cartesian position (zeros if empty or massless)."""
        if not self.atoms:
            return np.zeros(3)
        all_ele, all_carts = self.get_ele_and_cart()
        masses = np.array([chemical_knowledge.element_masses[x] for x in all_ele])
        total_mass = np.sum(masses)
        if total_mass == 0:
            return np.zeros(3)
        return np.sum(all_carts * masses[:, np.newaxis], axis=0) / total_mass

    def sort_by_element(self) -> None:
        """Sorts the atoms list in place by `chemical_knowledge.periodic_table_list[element]`."""
        self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element])

    def sort_by_id(self) -> None:
        """Sorts the atoms list in place by `atom_id` (IDs must be mutually comparable)."""
        self.atoms.sort(key=lambda atom: atom.atom_id)

    def info(self) -> None:
        """Prints a formatted summary of the molecule."""
        print("--- Molecule ---")
        print(f"Name: {self.name} | System: {self.system_name}")
        print(f"Number of atoms: {len(self.atoms)}")
        print(f"Total Mass (amu): {self.get_mass():.4f}")
        print(f"Energy: {self.energy}")
        print(f"Comment: {self.comment}")
        print(f"\n{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}")
        print("-" * 58)
        if self.atoms:
            for atom in self.atoms:
                print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}")

    @staticmethod
    def _connected_components(adj_matrix: np.ndarray) -> Dict[int, List[int]]:
        """
        Groups node indices of a boolean adjacency matrix into connected components.

        Returns:
            A dict mapping component ID (starting from 1, in order of the
            lowest contained index) to the member indices as plain Python
            ints, in depth-first visiting order.
        """
        num_atoms = len(adj_matrix)
        visited = [False] * num_atoms
        groups: Dict[int, List[int]] = {}
        group_id = 0
        for start in range(num_atoms):
            if visited[start]:
                continue
            group_id += 1
            members = groups[group_id] = []
            stack = [start]
            while stack:
                atom_idx = stack.pop()
                # An index can be pushed more than once before it is visited.
                if visited[atom_idx]:
                    continue
                visited[atom_idx] = True
                members.append(atom_idx)
                neighbors = np.where(adj_matrix[atom_idx])[0]
                # Convert to int so callers never receive numpy scalar indices,
                # and skip neighbours that are already assigned.
                stack.extend(int(n) for n in neighbors if not visited[n])
        return groups

    def find_fragment(self, tolerance: float = 1.15) -> Dict[int, List[int]]:
        """
        Identifies covalently bonded fragments within the molecule.

        Two atoms count as bonded when their distance is below
        (covalent_radius_i + covalent_radius_j) * tolerance.

        Args:
            tolerance: Scaling factor applied to the summed covalent radii.

        Returns:
            A dictionary mapping a fragment ID (starting from 1) to a list of
            atom indices (positions in `self.atoms`) of that fragment.
        """
        if not self.atoms:
            return {}
        cart_matrix = np.array([atom.cart_xyz for atom in self.atoms])
        radii = np.array([chemical_knowledge.element_covalent_radii[atom.element] for atom in self.atoms])
        # Pairwise bond thresholds (r_i + r_j) * tolerance
        bond_threshold_matrix = (radii[:, np.newaxis] + radii) * tolerance
        dist_matrix = cdist(cart_matrix, cart_matrix)
        adj_matrix = dist_matrix < bond_threshold_matrix
        # An atom is never bonded to itself.
        np.fill_diagonal(adj_matrix, False)
        return self._connected_components(adj_matrix)

    def give_molecule_id(self, tolerance: float = 1.15) -> None:
        """Writes the fragment ID from `find_fragment` into each atom's `molecule` attribute."""
        fragments = self.find_fragment(tolerance=tolerance)
        for group_id, atom_indices in fragments.items():
            for atom_idx in atom_indices:
                self.atoms[atom_idx].molecule = group_id

    def take_out_fragment(self, tolerance: float = 1.15) -> List['Molecule']:
        """
        Splits the molecule into one new Molecule per disconnected fragment.

        Note: this first calls `give_atom_id_forced()`, so the atom IDs and
        bonding info of this molecule are reset. The fragment atoms are deep
        copies carrying those freshly assigned IDs.

        Returns:
            The fragments as Molecule objects named "<name>_frag<i>".
        """
        if not self.atoms:
            return []
        self.give_atom_id_forced()
        fragments = self.find_fragment(tolerance=tolerance)
        new_molecules = []
        for i, atom_indices in fragments.items():
            fragment_atoms = [self.atoms[j] for j in atom_indices]
            new_mol = Molecule(
                atoms=copy.deepcopy(fragment_atoms),
                name=f"{self.name}_frag{i}",
                system_name=f"{self.system_name}_frag{i}"
            )
            new_molecules.append(new_mol)
        return new_molecules

    def calculate_frac_xyz_by_cell_para(self, cell_para: list) -> None:
        """Sets each atom's fractional coordinates from its cartesian ones for the given cell parameters."""
        for atom in self.atoms:
            atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_para(atom.cart_xyz, cell_para)

    def molecule_volume(self, num_samples: int = 100000) -> float:
        """
        Estimates the van der Waals volume by Monte Carlo integration.

        Random points are drawn uniformly in an axis-aligned box that encloses
        all atoms padded by the largest vdW radius; the volume is the box
        volume times the fraction of points inside at least one atom sphere.
        The result varies between calls because the global numpy random
        state is used.

        Args:
            num_samples: Number of random points to draw.

        Returns:
            The estimated volume (0.0 for an empty molecule).
        """
        if not self.atoms:
            return 0.0
        elements, coords = self.get_ele_and_cart()
        radii = np.array([chemical_knowledge.element_vdw_radii[el] for el in elements])
        # Bounding box for sampling
        min_bounds = np.min(coords, axis=0) - np.max(radii)
        max_bounds = np.max(coords, axis=0) + np.max(radii)
        bounding_box_volume = np.prod(max_bounds - min_bounds)
        random_points = np.random.uniform(min_bounds, max_bounds, (num_samples, 3))
        radii_sq = radii**2  # loop-invariant, computed once
        count_inside = 0
        for rp in tqdm(random_points, desc="Monte Carlo Volume", leave=False):
            # Squared distances from the sample point to every atom centre
            dist_sq = np.sum((coords - rp)**2, axis=1)
            if np.any(dist_sq <= radii_sq):
                count_inside += 1
        return (count_inside / num_samples) * bounding_box_volume
\ No newline at end of file
import re
from basic_function import operation
from basic_function import data_classes
from basic_function import chemical_knowledge
import copy
def read_xyz_file(file_path):
    """
    Read an XYZ file and return it as a Molecule.

    The first line must hold the atom count and the second line is used as
    the molecule name. Every following line that consists of exactly four
    whitespace-separated fields, the last three numeric, becomes one atom;
    all other lines are ignored. A warning is printed when the number of
    atoms read differs from the declared count.

    :param file_path: path of the XYZ file
    :return: data_classes.Molecule whose name and system_name are the
        comment line; atom_id is the zero-based line offset after the
        two header lines
    :raises ValueError: if the first line is not an integer
    :raises IndexError: if the file has fewer than two lines
    """
    # Context manager so the handle is closed even when parsing fails.
    with open(file_path, 'r') as input_file:
        lines = input_file.readlines()
    number_of_atoms = int(lines[0])
    # Strip only the line break; slicing off the last character would eat
    # real text when the comment line has no trailing newline.
    name = lines[1].rstrip("\n")
    atoms = []
    # Skip the two header lines so a numeric-looking comment line is never
    # mistaken for an atom record.
    for atom_id, line in enumerate(lines[2:]):
        split_line = line.split()
        if len(split_line)==4 and operation.is_number(split_line[1]) and \
                operation.is_number(split_line[2]) and operation.is_number(split_line[3]):
            atoms.append(data_classes.Atom(element=split_line[0],
                                           cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])],
                                           atom_id=atom_id))
    if number_of_atoms!=len(atoms):
        print("Warning! The length of atoms don't match the number of atoms given")
    molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name)
    return molecule
def write_cif_file(crystal, sym=False, name="zcx"):
    """
    Render a crystal as the lines of a CIF data block.

    :param crystal: crystal object; system_name, space_group, SYMM,
        cell_para and atoms (element, frac_xyz) are read from it
    :param sym: falsy: write every atom of the P1 cell (a crystal whose
        space_group is not 1 is deep-copied and expanded with make_p1(),
        the argument itself is left untouched); truthy: write the atoms
        as stored together with the crystal's own space group and SYMM
    :param name: data block name, used only when crystal.system_name is
        "unknown"
    :return: list of text lines, each ending with a newline

    The returned list can be written out directly, for example:
        with open("zcx.cif", "w") as target:
            target.writelines(cif_out)
    """
    if crystal.system_name!="unknown":
        name = crystal.system_name

    # Pick the structure and the symmetry header that go into the file.
    if sym:
        crystal_out = crystal
        space_group_symbol = chemical_knowledge.space_group[crystal.space_group][1]
        space_group_number = crystal.space_group
        symmetry_ops = crystal.SYMM
    else:
        if crystal.space_group==1:
            crystal_out = crystal
        else:
            # Work on a copy so the caller's crystal keeps its symmetry.
            crystal_out = copy.deepcopy(crystal)
            crystal_out.make_p1()
        space_group_symbol = "P1"
        space_group_number = 1
        symmetry_ops = ["x,y,z"]

    cif_file = ["data_"+str(name)+"\n"]
    cif_file.append("_symmetry_space_group_name_H-M '{}'\n".format(space_group_symbol))
    cif_file.append("_symmetry_Int_Tables_number {}\n".format(space_group_number))
    cif_file.append("loop_\n")
    cif_file.append("_symmetry_equiv_pos_site_id\n")
    cif_file.append("_symmetry_equiv_pos_as_xyz\n")
    for idx, symmetry_op in enumerate(symmetry_ops):
        cif_file.append("{} {}\n".format(idx+1, symmetry_op))

    lengths, angles = crystal_out.cell_para[0], crystal_out.cell_para[1]
    cif_file.append("_cell_length_a "+str(lengths[0])+"\n")
    cif_file.append("_cell_length_b "+str(lengths[1])+"\n")
    cif_file.append("_cell_length_c "+str(lengths[2])+"\n")
    cif_file.append("_cell_angle_alpha "+str(angles[0])+"\n")
    cif_file.append("_cell_angle_beta "+str(angles[1])+"\n")
    cif_file.append("_cell_angle_gamma "+str(angles[2])+"\n")

    cif_file.append("loop_\n")
    cif_file.append("_atom_site_label\n")
    cif_file.append("_atom_site_type_symbol\n")
    cif_file.append("_atom_site_fract_x\n")
    cif_file.append("_atom_site_fract_y\n")
    cif_file.append("_atom_site_fract_z\n")
    # The site label is simply the 1-based running number of the atom.
    for label, atom in enumerate(crystal_out.atoms, start=1):
        cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n"
                        .format(label, atom.element, atom.frac_xyz[0],
                                atom.frac_xyz[1], atom.frac_xyz[2]))
    return cif_file
def write_cifs_file(crystals, sym=False, name="zcx"):
    """
    Render several crystals into one multi-block CIF line list.

    Each crystal is passed to write_cif_file with the same `sym` and `name`
    arguments and the resulting lines are concatenated in input order.

    :param crystals: iterable of crystal objects
    :param sym: forwarded to write_cif_file
    :param name: forwarded to write_cif_file
    :return: flat list of text lines for all crystals
    """
    return [
        cif_line
        for crystal in crystals
        for cif_line in write_cif_file(crystal, sym=sym, name=name)
    ]
def read_cif_file(file_path,on_sym_check=False,shut_up=False,system_name="unknown",comment_name="unknown"):
    """
    Read a (possibly multi-block) CIF file into a list of crystals.

    The file is split at every line starting with "data_". For each block
    the symmetry operations, the cell parameters and the atom loop are
    parsed; each atom is then expanded with
    operation.space_group_transfer_for_single_atom and a
    data_classes.Crystal is built from the expanded atoms.

    :param file_path: path of the CIF file
    :param on_sym_check: when True an Exception is raised after the first
        block has been parsed (feature not implemented)
    :param shut_up: when False a progress line is printed for every
        100th block
    :param system_name: name stored on every crystal; when left as
        "unknown" the name is taken from each block's "data_" line
    :param comment_name: stored as the comment of every crystal
    :return: list of data_classes.Crystal, one per data block
    """
    # Read everything up front and close the handle immediately.
    with open(file_path, 'r') as input_file:
        lines = input_file.readlines()
    step_pickle = []
    crystal_all = []
    # Only derive the name from the file when the caller did not give one.
    no_name = system_name=="unknown"
    # first time scan: record the first line of every data block
    for index,line in enumerate(lines):
        if line.startswith("data_"):
            step_pickle.append(index)
    # Sentinel so block m always spans step_pickle[m]:step_pickle[m+1].
    step_pickle.append(len(lines))
    # treat every block and build one crystal from it
    for m in range(0,len(step_pickle)-1):
        atoms = []
        atoms_P1 = []
        SYMM = []
        cell_para = [["unknown","unknown","unknown"],["unknown","unknown","unknown"]]
        for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]):
            # `index` is relative to the block; step_pickle[m] + index is the
            # absolute line number used for the look-ahead below.
            split_line = list(filter(lambda x: x != '', re.split("\\s+", line)))
            if line.startswith("#"):
                pass
            elif len(split_line)==0:
                pass
            # symmetry loop whose only header is _symmetry_equiv_pos_as_xyz
            elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n":
                temp_number = 1
                # Data rows continue until the next line that contains "_".
                while "_" not in lines[step_pickle[m]+index+1+temp_number]:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    temp_number+=1
                    # Rows may or may not start with a running number.
                    if not operation.is_number(split_line_temp[0]):
                        SYMM.append(split_line_temp[0])
                    else:
                        SYMM.append(split_line_temp[1])
            # symmetry loop with a site-id header before _symmetry_equiv_pos_as_xyz
            elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ")=="_symmetry_equiv_pos_as_xyz\n":
                temp_number = 1
                while "_" not in lines[step_pickle[m]+index+2+temp_number]:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number])))
                    temp_number+=1
                    # Drop the id column and re-join an operation that contained spaces.
                    SYMM.append("".join(split_line_temp[1:]))
            elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]:
                # ase format
                temp_number = 1
                while "_" not in lines[step_pickle[m]+index+1+temp_number]:
                    if lines[step_pickle[m]+index+1+temp_number]=="\n":
                        temp_number += 1
                        continue
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    temp_number+=1
                    if not operation.is_number(split_line_temp[0]):
                        SYMM.append("".join(split_line_temp))
            # read the loop of atoms:
            elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \
                    (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"):
                temp_number = 0
                # Header part: remember in which column each needed field sits.
                while "_" in lines[step_pickle[m]+index+1+temp_number]:
                    if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n":
                        ele_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n":
                        x_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n":
                        y_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n":
                        z_pos = temp_number
                    temp_number+=1
                how_long = temp_number
                # Data part: rows with exactly as many fields as header lines.
                while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    atoms.append(data_classes.Atom(element=split_line_temp[ele_pos],
                                                   frac_xyz=[float(split_line_temp[x_pos]),float(split_line_temp[y_pos]),
                                                             float(split_line_temp[z_pos])]))
                    temp_number += 1
                    # Stop at end of file instead of indexing past the last line.
                    if step_pickle[m]+index+1+temp_number==len(lines):
                        break
            elif split_line[0] == "_cell_length_a":
                cell_para[0][0] = float(split_line[1])
            elif split_line[0] == "_cell_length_b":
                cell_para[0][1] = float(split_line[1])
            elif split_line[0] == "_cell_length_c":
                cell_para[0][2] = float(split_line[1])
            elif split_line[0] == "_cell_angle_alpha":
                cell_para[1][0] = float(split_line[1])
            elif split_line[0] == "_cell_angle_beta":
                cell_para[1][1] = float(split_line[1])
            elif split_line[0] == "_cell_angle_gamma":
                cell_para[1][2] = float(split_line[1])
            elif "data_" in line:
                if no_name == True:
                    # Everything after "data_" becomes the name, made path-safe.
                    system_name = line[5:]
                    system_name = system_name.replace(" ","_")
                    system_name = system_name.replace("\\", "_")
                    system_name = system_name.replace("\n", "")
        # Apply the collected symmetry operations to every parsed atom.
        for atom in atoms:
            all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM)
            for new_position in all_reflect_position:
                atoms_P1.append(data_classes.Atom(element=atom.element,
                                                  frac_xyz=[new_position[0], new_position[1], new_position[2]]))
        crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name))
        if on_sym_check == True:
            raise Exception("Not finished part, TODO in code")
        if shut_up==False:
            if m%100 == 0:
                print("{} structures have been treated".format(m))
    return crystal_all
def write_poscar_file(crystal, coordinates = 'frac', name = "parser_zcx_create"):
    """Render a crystal as the text lines of a VASP POSCAR file.

    The crystal is re-ordered in place through ``crystal.sort_by_element()``
    before the element header and the coordinate rows are produced.

    Args:
        crystal: Structure object exposing ``cell_vect``, ``atoms``,
            ``sort_by_element()``, ``get_element()`` and
            ``get_element_amount()``.
        coordinates: ``'frac'`` writes a ``Direct`` block taken from each
            atom's ``frac_xyz``; ``'cart'`` writes a ``Cartesian`` block taken
            from ``cart_xyz``.
        name: Text placed on the first (comment) line of the file.

    Returns:
        A list of newline-terminated strings, ready for ``writelines``.

    Raises:
        Exception: If ``coordinates`` is neither ``'frac'`` nor ``'cart'``.
    """
    row_template = "{:16.8f} {:16.8f} {:16.8f}\n"

    # Comment line, universal scale factor, then the three lattice vectors.
    poscar_lines = ['{}\n'.format(name), '1.0\n']
    lattice = crystal.cell_vect
    poscar_lines.extend(row_template.format(row[0], row[1], row[2]) for row in lattice)

    # Atoms are grouped per element so the rows line up with the header.
    crystal.sort_by_element()
    symbols_row = "".join("{:>6s}".format(symbol) for symbol in crystal.get_element())
    poscar_lines.append(symbols_row + "\n")
    counts_row = "".join("{:>6.0f}".format(count) for count in crystal.get_element_amount())
    poscar_lines.append(counts_row + "\n")

    if coordinates == 'frac':
        mode_line, position_attr = 'Direct\n', 'frac_xyz'
    elif coordinates == 'cart':
        mode_line, position_attr = 'Cartesian\n', 'cart_xyz'
    else:
        raise Exception("Wrong coordinates type: {}".format(coordinates))

    poscar_lines.append(mode_line)
    for symbol in crystal.get_element():
        for atom in crystal.atoms:
            if atom.element == symbol:
                xyz = getattr(atom, position_attr)
                poscar_lines.append(row_template.format(xyz[0], xyz[1], xyz[2]))
    return poscar_lines
def read_ase_pbc_file(file_path,shut_up=False):
    """Parse a step-by-step ASE text dump into a list of crystals.

    The first two lines of the file are discarded. Every line starting with
    ``"Step "`` opens a new frame; inside a frame the parser looks for

    * ``Positions:`` followed by bracketed rows of Cartesian coordinates,
    * ``Forces:`` followed by bracketed rows (recognised only so that they
      are not mistaken for positions; the values are not stored),
    * ``Elements: ['A', 'B', ...]``,
    * ``cell: Cell([[...], [...], [...]])``.

    If a frame does not repeat the ``Elements:`` or ``cell:`` line, the most
    recently parsed value is reused.

    Args:
        file_path: Path of the text file to read.
        shut_up: Accepted for signature compatibility with the other readers;
            this function prints nothing, so the flag has no effect.

    Returns:
        A list with one ``data_classes.Crystal`` per frame, in file order.

    Raises:
        ValueError: If a frame is reached before any ``Elements:`` or
            ``cell:`` line has been seen, or a numeric field is malformed.
        IndexError: If a frame has fewer position rows than elements.
    """
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(file_path, 'r') as input_file:
        lines = input_file.readlines()[2:]

    # First pass: locate the start of every frame, plus an end sentinel.
    step_pickle = [index for index, line in enumerate(lines) if line.startswith("Step ")]
    step_pickle.append(len(lines))

    crystal_all = []
    elements = None
    cell_vect = None

    # Second pass: turn every frame into one crystal.
    for m in range(len(step_pickle) - 1):
        position_matrix = []
        in_positions = False
        for raw_line in lines[step_pickle[m]:step_pickle[m + 1]]:
            line = raw_line.strip()

            if line.startswith("Forces:"):
                # Bracketed rows that follow belong to the forces table.
                in_positions = False
                continue
            if line.startswith("Positions:"):
                in_positions = True
                continue

            if in_positions and line.startswith("[") and line.endswith("]"):
                row_text = line.replace("[", "").replace("]", "")
                position_matrix.append([float(x) for x in row_text.split()])
                continue

            if line.startswith("Elements:"):
                elements_string = line.split(":", 1)[-1].strip()
                elements_string = elements_string[1:-1]  # drop the outer [ ]
                elements = [elem.strip().strip("'") for elem in elements_string.split(",")]
            if line.startswith("cell:"):
                matrix_string = line[len("cell: Cell("):-1]
                rows = matrix_string.split("], [")
                cell_vect = [
                    [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")]
                    for row in rows
                ]

        if elements is None or cell_vect is None:
            raise ValueError(
                "No 'Elements:' or 'cell:' line found up to frame {} of {}".format(m, file_path)
            )

        atoms_P1 = [
            data_classes.Atom(
                element=element,
                cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]],
            )
            for i, element in enumerate(elements)
        ]
        crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1))
    return crystal_all
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
A collection of functions for performing crystallographic and molecular operations,
such as symmetry application, supercell generation, and geometric analysis.
"""
# --- Standard Library Imports ---
import copy
import fractions
import re
from typing import Any, Dict, List, Optional, Tuple, Union
# --- Third-Party Imports ---
import networkx as nx
import numpy as np
import numpy.typing as npt
from scipy.spatial import cKDTree as KDTree
# --- Local Application Imports ---
from basic_function import chemical_knowledge, data_classes
# Type aliases for clarity
NDArrayFloat = npt.NDArray[np.float64]
CellVectors = List[List[float]]
SymmetryOperations = List[str]
def is_number(s: str) -> bool:
    """Checks if a string can be interpreted as a number (float or fraction).

    Args:
        s: The input string.

    Returns:
        True if the string parses as a float (e.g. "0.5") or as a fraction
        (e.g. "1/2"), False otherwise. A fraction with a zero denominator
        such as "1/0" is reported as not a number instead of raising.
    """
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        # Check for fractional representations like "1/2".
        float(fractions.Fraction(s))
        return True
    except (ValueError, ZeroDivisionError):
        # Fraction raises ZeroDivisionError (not ValueError) for "n/0".
        return False
def _parse_symmetry_operations(
sym_ops: SymmetryOperations,
) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]:
"""Parses a list of symmetry operation strings into matrices.
This is an internal helper function to avoid code duplication in public functions.
Args:
sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']).
Returns:
A tuple containing two lists:
- A list of 3x3 rotation/reflection matrices (M).
- A list of 1x3 translation vectors (C).
Raises:
ValueError: If a symmetry operation string is malformed.
"""
rotation_matrices = []
translation_vectors = []
for sym_op_str in sym_ops:
sym_op_parts = sym_op_str.lower().replace(" ", "").split(",")
if len(sym_op_parts) != 3:
raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.")
matrix_m = np.zeros((3, 3))
matrix_c = np.zeros((1, 3))
for i, part in enumerate(sym_op_parts):
# Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5'
tokens = re.findall(r"([+-]?[xyz0-9./]+)", part)
for token in tokens:
token = token.strip()
if not token:
continue
if "x" in token:
matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0
elif "y" in token:
matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0
elif "z" in token:
matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0
elif is_number(token):
matrix_c[0, i] += float(fractions.Fraction(token))
else:
raise ValueError(f"Invalid fragment '{token}' in symmetry operation.")
rotation_matrices.append(matrix_m)
translation_vectors.append(matrix_c)
return rotation_matrices, translation_vectors
def space_group_transfer_for_single_atom(
    frac_xyz: List[float], space_group_ops: SymmetryOperations
) -> List[List[float]]:
    """Generate every symmetry image of one fractional coordinate.

    Each operation string is parsed into a matrix ``R`` and a shift ``t``;
    the image is computed as ``frac_xyz · Rᵀ + t``. Images are not wrapped
    back into the unit cell and duplicates are not removed.

    Args:
        frac_xyz: Fractional coordinates [x, y, z] of the atom.
        space_group_ops: Symmetry operation strings of the space group.

    Returns:
        One ``[x, y, z]`` list per operation, in the order of
        ``space_group_ops``.
    """
    rotations, shifts = _parse_symmetry_operations(space_group_ops)
    origin = np.array(frac_xyz)
    return [
        (np.dot(origin, rotation.T) + shift.squeeze()).tolist()
        for rotation, shift in zip(rotations, shifts)
    ]
def super_cell(
    crystal: "data_classes.Crystal",
    cell_range: Optional[List[List[int]]] = None,
) -> "data_classes.Crystal":
    """Build a supercell by replicating a unit cell along its lattice vectors.

    Args:
        crystal: The unit cell to replicate. Its ``cell_vect``, ``atoms`` and
            ``energy`` are read; the object itself is not modified.
        cell_range: Inclusive ``[first, last]`` image indices for each of the
            three lattice directions. ``None`` means
            ``[[-1, 1], [-1, 1], [-1, 1]]``, i.e. a 3x3x3 supercell.

    Returns:
        A new Crystal whose lattice vectors are the originals scaled by the
        number of images per direction and whose atoms carry fractional
        coordinates relative to that enlarged cell. The energy is multiplied
        by the number of cells unless it is the string ``"unknown"``.
    """
    if cell_range is None:
        cell_range = [[-1, 1], [-1, 1], [-1, 1]]

    # Number of images along each direction.
    repeats = [bounds[1] - bounds[0] + 1 for bounds in cell_range]

    enlarged_lattice = [
        [count * component for component in crystal.cell_vect[axis]]
        for axis, count in enumerate(repeats)
    ]

    # Every integer lattice shift (h, k, l) inside the requested ranges.
    shifts = [
        [h, k, l]
        for h in range(cell_range[0][0], cell_range[0][1] + 1)
        for k in range(cell_range[1][0], cell_range[1][1] + 1)
        for l in range(cell_range[2][0], cell_range[2][1] + 1)
    ]

    # Shift each atom, then rescale to the fractional frame of the new cell.
    replicated_atoms = [
        data_classes.Atom(
            element=atom.element,
            frac_xyz=[(atom.frac_xyz[i] + shift[i]) / repeats[i] for i in range(3)],
        )
        for atom in crystal.atoms
        for shift in shifts
    ]

    if crystal.energy != "unknown":
        scaled_energy = crystal.energy * (repeats[0] * repeats[1] * repeats[2])
    else:
        scaled_energy = "unknown"

    return data_classes.Crystal(
        cell_vect=enlarged_lattice, energy=scaled_energy, atoms=replicated_atoms
    )
def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule":
    """Orients a molecule along its principal axes of inertia.

    The molecule is translated so that its centre of mass is at the origin
    and rotated so that the eigenvectors of its inertia tensor coincide with
    the Cartesian axes (ascending moment of inertia along x, y, z).
    The molecule's coordinates are modified in-place. For more details, see:
    http://sobereva.com/426

    The rotation is forced to be proper (determinant +1). The eigenvectors
    returned by ``numpy.linalg.eigh`` have arbitrary signs and may form a
    left-handed set, which would mirror the molecule and invert the
    chirality of a chiral conformer.

    Args:
        molecule: The Molecule object to be oriented.

    Returns:
        The same Molecule object with its atoms reoriented. Molecules with
        zero or one atom are returned untouched.
    """
    all_ele, all_cart = molecule.get_ele_and_cart()
    if len(all_cart) <= 1:
        return molecule  # No orientation needed for single atoms or empty molecules.

    masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele])
    relative_position = all_cart - molecule.get_center_of_mass()
    x = relative_position[:, 0]
    y = relative_position[:, 1]
    z = relative_position[:, 2]

    # Moment of inertia tensor about the centre of mass.
    I_xx = np.sum(masses * (y ** 2 + z ** 2))
    I_yy = np.sum(masses * (x ** 2 + z ** 2))
    I_zz = np.sum(masses * (x ** 2 + y ** 2))
    I_xy = -np.sum(masses * x * y)
    I_xz = -np.sum(masses * x * z)
    I_yz = -np.sum(masses * y * z)
    I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]])

    # Columns of `eigenvectors` are the principal axes (eigh: symmetric input,
    # eigenvalues in ascending order).
    _, eigenvectors = np.linalg.eigh(I_matrix)

    # Flip one axis if the set is left-handed so that only a rotation, never
    # a reflection, is applied to the coordinates.
    if np.linalg.det(eigenvectors) < 0:
        eigenvectors[:, 2] *= -1.0

    # Project the relative positions onto the principal axes.
    new_positions = np.dot(relative_position, eigenvectors)
    molecule.put_ele_cart_back(all_ele, new_positions)
    return molecule
def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat:
    """Build a 3x3 rotation matrix from three scalar parameters.

    ``v`` is mapped to a quaternion ``(qx, qy, qz, qw)``::

        qx = sqrt(1 - v0) * sin(2*pi*v1)    qz = sqrt(v0) * sin(2*pi*v2)
        qy = sqrt(1 - v0) * cos(2*pi*v1)    qw = sqrt(v0) * cos(2*pi*v2)

    which has unit length whenever ``0 <= v0 <= 1``; both square-root
    arguments are clamped at zero. The returned matrix carries ``+2*qw*qz``
    in element ``[0, 1]``, i.e. it is the transpose of the usual active
    right-handed quaternion rotation matrix.

    Args:
        v: Three numbers ``[v0, v1, v2]``.

    Returns:
        The 3x3 rotation matrix as a NumPy array.
    """
    radial_w = np.sqrt(max(v[0], 0))
    radial_xy = np.sqrt(max(1.0 - v[0], 0))
    phi = 2.0 * np.pi * v[1]
    psi = 2.0 * np.pi * v[2]

    # Quaternion components (x, y, z, w)
    qx = radial_xy * np.sin(phi)
    qy = radial_xy * np.cos(phi)
    qz = radial_w * np.sin(psi)
    qw = radial_w * np.cos(psi)

    # Doubled products shared between the matrix entries.
    xx2, yy2, zz2 = 2*qx**2, 2*qy**2, 2*qz**2
    xy2, xz2, yz2 = 2*qx*qy, 2*qx*qz, 2*qy*qz
    wx2, wy2, wz2 = 2*qw*qx, 2*qw*qy, 2*qw*qz

    first_row = [1 - yy2 - zz2, xy2 + wz2, xz2 - wy2]
    second_row = [xy2 - wz2, 1 - xx2 - zz2, yz2 + wx2]
    third_row = [xz2 + wy2, yz2 - wx2, 1 - xx2 - yy2]
    return np.array([first_row, second_row, third_row])
def f2c_matrix(
    cell_params: Tuple[List[float], List[float]]
) -> Optional[NDArrayFloat]:
    """Build the fractional-to-Cartesian matrix of a unit cell.

    The rows of the returned matrix are the lattice vectors a, b, c with
    ``a`` along x and ``b`` in the xy-plane, so a fractional row vector ``f``
    converts as ``f @ matrix``.

    Args:
        cell_params: ``[[a, b, c], [alpha, beta, gamma]]`` with the angles
            given in degrees.

    Returns:
        The 3x3 matrix, or None when the angle combination yields a negative
        squared volume term.
    """
    (a, b, c), angles = cell_params
    alpha, beta, gamma = np.deg2rad(angles)
    cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
    sin_g = np.sin(gamma)

    # Squared volume of the cell with unit edge lengths.
    unit_volume_sq = (
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    if unit_volume_sq < 0:
        return None
    volume = a * b * c * np.sqrt(unit_volume_sq)

    # Upper-triangular form (columns are a, b, c); transposed on return so
    # that the rows are the lattice vectors.
    column_form = np.array([
        [a, b * cos_g, c * cos_b],
        [0.0, b * sin_g, c * (cos_a - cos_b * cos_g) / sin_g],
        [0.0, 0.0, volume / (a * b * sin_g)],
    ])
    return column_form.T
def c2f_matrix(
    cell_params: Tuple[List[float], List[float]]
) -> Optional[NDArrayFloat]:
    """Build the Cartesian-to-fractional matrix of a unit cell.

    The result is the matrix inverse of what `f2c_matrix` returns for the
    same parameters.

    Args:
        cell_params: ``[[a, b, c], [alpha, beta, gamma]]`` with the angles
            given in degrees.

    Returns:
        The 3x3 inverse matrix, or None when `f2c_matrix` rejects the
        parameters or the forward matrix is singular.
    """
    forward = f2c_matrix(cell_params)
    if forward is not None:
        try:
            return np.linalg.inv(forward)
        except np.linalg.LinAlgError:
            pass
    return None
def apply_SYMM(
    frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations
) -> NDArrayFloat:
    """Apply every symmetry operation to a set of fractional coordinates.

    Each operation is parsed into a matrix ``R`` and a shift ``t`` and the
    image is computed as ``frac_xyz · Rᵀ + t``.

    Args:
        frac_xyz: Fractional coordinates [x, y, z] as a NumPy array.
        symm_ops: Symmetry operation strings.

    Returns:
        A NumPy array stacking one image per operation, in the order of
        ``symm_ops``.
    """
    rotations, shifts = _parse_symmetry_operations(symm_ops)
    images = []
    for rotation, shift in zip(rotations, shifts):
        images.append(np.dot(frac_xyz, rotation.T) + shift.squeeze())
    return np.array(images)
def apply_SYMM_with_element(
    elements: Union[str, List[str]],
    frac_xyzs: NDArrayFloat,
    symm_ops: SymmetryOperations,
) -> Tuple[NDArrayFloat, NDArrayFloat]:
    """Apply symmetry operations and replicate the element labels to match.

    Args:
        elements: Element symbol(s) belonging to ``frac_xyzs``.
        frac_xyzs: Fractional coordinates as a NumPy array.
        symm_ops: Symmetry operation strings.

    Returns:
        A tuple ``(labels, positions)``:
            - ``labels``: the squeezed element array tiled once per
              symmetry operation along the first axis.
            - ``positions``: the output of `apply_SYMM` for ``frac_xyzs``.
    """
    images = apply_SYMM(frac_xyzs, symm_ops)
    labels = np.array(elements).squeeze()
    tiled_labels = np.tile(labels, (len(images), 1))
    return tiled_labels, images
def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float:
    """Calculates the length of the longest space diagonal of a unit cell.

    A parallelepiped has four body diagonals: ``a+b+c``, ``a+b-c``,
    ``a-b+c`` and ``-a+b+c``. For an orthogonal cell they all have the same
    length, but for an oblique cell the diagonal through the (1,1,1) corner
    is not necessarily the longest, so all four are compared.

    Args:
        cell_vect: The three lattice vectors of the cell.

    Returns:
        The length of the longest body diagonal, in the units of
        ``cell_vect``.
    """
    vec_a, vec_b, vec_c = np.array(cell_vect, dtype=float)
    body_diagonals = (
        vec_a + vec_b + vec_c,
        vec_a + vec_b - vec_c,
        vec_a - vec_b + vec_c,
        -vec_a + vec_b + vec_c,
    )
    return float(max(np.linalg.norm(diagonal) for diagonal in body_diagonals))
def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]:
    """Compute the spacing between opposite faces of the unit cell.

    For each lattice vector the function measures how far its end point lies
    from the plane spanned by the other two vectors, i.e. the perpendicular
    width of the cell along that direction.

    Args:
        cell_vect: The three lattice vectors [a, b, c] of the cell.

    Returns:
        ``[d_a, d_b, d_c]`` where ``d_a`` is the distance between the two
        faces parallel to the b-c plane, ``d_b`` between those parallel to
        the a-c plane and ``d_c`` between those parallel to the a-b plane.
    """
    lattice = [np.array(vector) for vector in cell_vect]

    # (apex, first spanning vector, second spanning vector) for a, b and c.
    face_layout = ((0, 1, 2), (1, 0, 2), (2, 0, 1))

    spacings = []
    for apex, first, second in face_layout:
        face_normal = np.cross(lattice[first], lattice[second])
        # Point-to-plane distance: |N · P| / ||N||
        projected = abs(np.dot(face_normal, lattice[apex]))
        spacings.append(projected / np.linalg.norm(face_normal))
    return spacings
def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool:
    """Detects if a crystal structure forms a connected framework via VdW radii.

    The method involves:
    1. Expanding a deep copy of the crystal to P1 and moving its atoms into
       the cell (the input object is left untouched).
    2. Building a 3x3x3 supercell so that contacts across cell boundaries
       are present.
    3. Constructing a graph where atoms are nodes and an edge joins every
       pair closer than one uniform cutoff, ``2 * tolerance * r_max``, with
       ``r_max`` the largest van der Waals radius among the elements
       present. The cutoff is NOT the per-pair sum of radii.
    4. Checking whether the largest connected component contains more than
       nine times the number of atoms of the P1 cell.

    Args:
        crystal: The Crystal object to analyze.
        tolerance: Factor scaling the uniform distance cutoff.

    Returns:
        True if the structure is a connected framework, False otherwise
        (including for a structure without atoms).
    """
    crystal_temp = copy.deepcopy(crystal)
    crystal_temp.make_p1()
    crystal_temp.move_atom_into_cell()

    # Create a 3x3x3 supercell to check for connectivity across boundaries.
    crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]])
    all_ele, all_carts = crystal_supercell.get_ele_and_cart()

    # Nothing to connect; must be checked before max() below, which raises
    # on an empty element set.
    if len(all_ele) == 0:
        return False

    vdw_radii_map = chemical_knowledge.element_vdw_radii
    vdw_max = max(vdw_radii_map[el] for el in set(all_ele))
    distance_threshold = vdw_max * tolerance * 2

    # KDTree for efficient neighbour-pair search within the cutoff.
    tree = KDTree(all_carts)
    pairs = tree.query_pairs(r=distance_threshold)

    # Build a graph to find connected components.
    graph = nx.Graph()
    graph.add_nodes_from(range(len(all_carts)))
    graph.add_edges_from(pairs)

    largest_cc = max(nx.connected_components(graph), key=len)

    # Empirical heuristic: the supercell holds 27 copies of the P1 cell; a
    # component spanning more than 9 cells' worth of atoms is treated as a
    # percolating framework rather than a set of isolated molecules.
    return len(largest_cc) > 9 * len(crystal_temp.atoms)
\ No newline at end of file
from basic_function import format_parser
from basic_function import CSP_generator_normal
import os
import concurrent.futures
import sys
def process_crystal(seed, sg, molecules,output_path,add_name):
    """Generate one candidate crystal and, on success, write it as a CIF.

    Args:
        seed: Seed passed to ``CrystalGenerator.generate``.
        sg: Space group handed to the generator.
        molecules: List of molecule objects to pack; each exposes ``atoms``.
        output_path: Folder that contains the ``structures`` sub-folder.
        add_name: Prefix of the CIF file name.

    Returns:
        True when a crystal was produced and written to
        ``{output_path}/structures/{add_name}_{sg}_{seed}_z{Z}_{N}.cif``
        (Z = number of molecules, N = total atom count), False when the
        generator returned None.
    """
    generator = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
    atom_total = sum(len(molecule.atoms) for molecule in molecules)
    candidate = generator.generate(seed=seed)
    # Push any worker output through immediately.
    sys.stdout.flush()
    if candidate is None:
        return False

    cif_lines = format_parser.write_cif_file(candidate)
    cif_path = f"{output_path}/structures/{add_name}_{sg}_{seed}_z{len(molecules)}_{atom_total}.cif"
    with open(cif_path, 'w') as target:
        target.writelines(cif_lines)
    return True
def CSP_generater_parallel(molecules,output_path,need_structure = 100, space_group_list=[1],max_workers=8,add_name='',start_seed=1):
    """Generate crystal structures for several space groups in parallel.

    For every space group a process pool runs `process_crystal` with
    consecutive seeds until ``need_structure`` structures have been accepted.
    The number of in-flight tasks is capped so that accepted plus pending
    tasks never exceeds ``need_structure``.

    Args:
        molecules: List of molecule objects forwarded to `process_crystal`.
        output_path: Output folder; CIFs go to ``{output_path}/structures``.
        need_structure: Number of accepted structures wanted per space group.
        space_group_list: Space groups to search (the default list is only
            read, never modified).
        max_workers: Size of the process pool.
        add_name: File-name prefix forwarded to `process_crystal`.
        start_seed: First seed tried for every space group.

    Raises:
        OSError: If the ``structures`` folder cannot be created for a reason
            other than already existing.
    """
    try:
        os.makedirs("{}/structures".format(output_path))
    except FileExistsError:
        # Only the "already there" case is benign; permission problems and
        # the like must surface instead of being reported as this warning.
        print("Warning, these is already an structures folder in this path, skip mkdir")

    for sg in space_group_list:
        accept = 0
        seed = start_seed
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {}
            while accept < need_structure:
                # Top up the pool without over-committing past the target.
                while len(futures) < max_workers and accept + len(futures) < need_structure:
                    future = executor.submit(process_crystal, seed, sg, molecules, output_path, add_name)
                    futures[future] = seed
                    seed += 1

                # Block until at least one task has finished.
                done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
                for future in done:
                    if future.result():
                        accept += 1
                    # Drop the finished task whatever its outcome.
                    del futures[future]

                    # Target reached: cancel whatever has not started yet.
                    if accept >= need_structure:
                        for pending in futures:
                            pending.cancel()
                        break
def CSP_generater_serial(molecules,output_path,need_structure = 100, densely_pack_method=False, space_group_list=[1]):
    """Generate crystal structures one after another in the current process.

    For every space group, seeds 1, 2, 3, ... are tried until
    ``need_structure`` crystals have been produced; each accepted crystal is
    written to ``<output_path>/structures/<space group>_<seed>.cif``.

    :param molecules: a list [molecule1, molecule2, ...]
    :param output_path: a str indicate the path of output folder
    :param need_structure: int, number of accepted structures per space group
    :param densely_pack_method: forwarded to ``CrystalGenerator.generate``
    :param space_group_list: a list indicate the space group need to search
        (only read, never modified)
    :raises OSError: if the structures folder cannot be created for a reason
        other than already existing
    """
    # os.path.join keeps the layout correct on every platform; the previous
    # hard-coded backslashes only formed a directory path on Windows.
    structures_dir = os.path.join(output_path, "structures")
    try:
        os.makedirs(structures_dir)
    except FileExistsError:
        print("Warning, these is already an structures folder in this path, skip mkdir")

    for sg in space_group_list:
        generator = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
        accept = 0
        seed = 1
        while accept < need_structure:
            new_crystal = generator.generate(seed=seed, densely_pack_method=densely_pack_method)
            if new_crystal is not None:
                cif_out = format_parser.write_cif_file(new_crystal)
                cif_path = os.path.join(structures_dir, "{}_{}.cif".format(sg, seed))
                # Context manager closes the file even if the write fails.
                with open(cif_path, 'w') as target:
                    target.writelines(cif_out)
                accept += 1
            seed += 1
def CSP_generater_one_test(molecules,output_path,space_group=1,seed=1):
    """
    used to test the generator for a given space group and seed

    :param molecules: a list [molecule1, molecule2, ...]
    :param output_path: a str indicate the path of output folder
    :param space_group: the single space group to generate in
    :param seed: int
    :return: None; on success writes ``<output_path>/<space_group>_<seed>.cif``
        (with symmetry information, ``sym=True``), otherwise prints a
        failure message
    """
    generator = CSP_generator_normal.CrystalGenerator(molecules, space_group=space_group)
    new_crystal = generator.generate(seed=seed, test=True)
    if new_crystal is None:
        print("Return failed generate")
        return

    cif_out = format_parser.write_cif_file(new_crystal, sym=True)
    # os.path.join instead of a hard-coded backslash, so the file lands
    # inside output_path on POSIX systems as well.
    cif_path = os.path.join(output_path, "{}_{}.cif".format(space_group, seed))
    with open(cif_path, 'w') as target:
        target.writelines(cif_out)
# -*- coding: utf-8 -*-
"""
Provides functions for converting between different representations of a
crystallographic unit cell (cell parameters and lattice vectors) and for
transforming atomic coordinates between fractional and Cartesian systems.
"""
# --- Standard Library Imports ---
from typing import List, Tuple, Union
# --- Third-Party Imports ---
import numpy as np
import numpy.typing as npt
# --- Type Aliases for Clarity ---
NDArrayFloat = npt.NDArray[np.float64]
CellParameters = Tuple[List[float], List[float]]
CellVectors = Union[List[List[float]], NDArrayFloat]
Coordinates = Union[List[float], NDArrayFloat]
def cell_para_to_vect(
    cell_para: CellParameters, check: bool = False
) -> CellVectors:
    """Turn cell lengths and angles into three lattice vectors.

    Vector ``a`` points along the x-axis and vector ``b`` lies in the
    xy-plane. A negative squared volume term is clamped to zero, which gives
    ``c`` a zero z-component instead of raising.

    Args:
        cell_para: ``[[a, b, c], [alpha, beta, gamma]]`` with the angles in
            degrees.
        check: When True, assert that ``cell_para`` has shape (2, 3).

    Returns:
        A 3x3 nested list whose rows are the vectors a, b and c.
    """
    if check:
        assert np.array(cell_para).shape == (2, 3), "Input `cell_para` must have shape (2, 3)."

    a, b, c = cell_para[0]
    alpha, beta, gamma = np.deg2rad(cell_para[1])
    cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
    sin_g = np.sin(gamma)

    # Squared volume of the cell with unit edges; clamp before the root.
    unit_volume_sq = (
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    unit_volume = np.sqrt(max(0, unit_volume_sq))

    lattice = np.array([
        [a, 0.0, 0.0],
        [b * cos_g, b * sin_g, 0.0],
        [c * cos_b, c * (cos_a - cos_b * cos_g) / sin_g, c * unit_volume / sin_g],
    ])
    return lattice.tolist()
def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters:
    """Turn three lattice vectors into cell lengths and angles.

    Args:
        cell_vect: 3x3 array-like whose rows are the vectors a, b and c.
        check: When True, assert that ``cell_vect`` has shape (3, 3).

    Returns:
        ``([a, b, c], [alpha, beta, gamma])`` with the angles in degrees;
        alpha is the angle between b and c, beta between a and c, gamma
        between a and b.
    """
    lattice = np.array(cell_vect)
    if check:
        assert lattice.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."

    vec_a, vec_b, vec_c = lattice
    len_a, len_b, len_c = (np.linalg.norm(vec) for vec in (vec_a, vec_b, vec_c))

    # Vector pairs enclosing alpha, beta and gamma, with their norms.
    enclosing_pairs = (
        (vec_b, vec_c, len_b, len_c),
        (vec_a, vec_c, len_a, len_c),
        (vec_a, vec_b, len_a, len_b),
    )
    angles_rad = []
    for first, second, norm_first, norm_second in enclosing_pairs:
        cosine = np.dot(first, second) / (norm_first * norm_second)
        # Rounding can push the cosine marginally outside [-1, 1].
        angles_rad.append(np.arccos(np.clip(cosine, -1.0, 1.0)))

    return ([len_a, len_b, len_c], np.rad2deg(angles_rad).tolist())
def atom_frac_to_cart_by_cell_vect(
    atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False
) -> List[float]:
    """Convert fractional coordinates to Cartesian ones via lattice vectors.

    The Cartesian position is the linear combination
    ``fx * a + fy * b + fz * c`` of the lattice vectors, evaluated as the
    product of the fractional row vector with the 3x3 lattice matrix.

    Args:
        atom_frac: Three fractional coordinates.
        cell_vect: 3x3 matrix whose rows are the lattice vectors.
        check: When True, assert the shapes (3, 3) and (3,).

    Returns:
        The three Cartesian coordinates as a list.
    """
    fractional = np.array(atom_frac)
    lattice = np.array(cell_vect)
    if check:
        assert lattice.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
        assert fractional.shape == (3,), "Input `atom_frac` must have 3 elements."
    return np.dot(fractional, lattice).tolist()
def atom_frac_to_cart_by_cell_para(
    atom_frac: Coordinates, cell_para: CellParameters, check: bool = False
) -> List[float]:
    """Convert fractional coordinates to Cartesian ones via cell parameters.

    The cell parameters are first expanded to lattice vectors with
    `cell_para_to_vect`; the conversion itself is delegated to
    `atom_frac_to_cart_by_cell_vect`.

    Args:
        atom_frac: Three fractional coordinates.
        cell_para: ``[[a, b, c], [alpha, beta, gamma]]``.
        check: Forwarded to both helper functions to enable shape asserts.

    Returns:
        The three Cartesian coordinates as a list.
    """
    return atom_frac_to_cart_by_cell_vect(
        atom_frac, cell_para_to_vect(cell_para, check=check), check=check
    )
def atom_cart_to_frac_by_cell_vect(
    atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False
) -> List[float]:
    """Convert Cartesian coordinates to fractional ones via lattice vectors.

    The fractional position is the Cartesian row vector multiplied by the
    inverse of the 3x3 lattice matrix.

    Args:
        atom_cart: Three Cartesian coordinates.
        cell_vect: 3x3 matrix whose rows are the lattice vectors.
        check: When True, assert the shapes (3, 3) and (3,).

    Returns:
        The three fractional coordinates as a list.
    """
    cartesian = np.array(atom_cart)
    lattice = np.array(cell_vect)
    if check:
        assert lattice.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
        assert cartesian.shape == (3,), "Input `atom_cart` must have 3 elements."
    inverse_lattice = np.linalg.inv(lattice)
    return np.dot(cartesian, inverse_lattice).tolist()
def atom_cart_to_frac_by_cell_para(
    atom_cart: Coordinates, cell_para: CellParameters, check: bool = False
) -> List[float]:
    """Convert Cartesian coordinates to fractional ones via cell parameters.

    The cell parameters are first expanded to lattice vectors with
    `cell_para_to_vect`; the conversion itself is delegated to
    `atom_cart_to_frac_by_cell_vect`.

    Args:
        atom_cart: Three Cartesian coordinates.
        cell_para: ``[[a, b, c], [alpha, beta, gamma]]``.
        check: Forwarded to both helper functions to enable shape asserts.

    Returns:
        The three fractional coordinates as a list.
    """
    return atom_cart_to_frac_by_cell_vect(
        atom_cart, cell_para_to_vect(cell_para, check=check), check=check
    )
def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float:
    """Compute the unit-cell volume from either cell representation.

    The input kind is decided by its array shape: (3, 3) is treated as
    lattice vectors, (2, 3) as ``[[a, b, c], [alpha, beta, gamma]]`` with
    the angles in degrees.

    Args:
        cell_info: Lattice vectors or cell parameters.

    Returns:
        The cell volume, in the cube of the length unit of the input.

    Raises:
        ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3).
    """
    cell = np.array(cell_info)

    if cell.shape == (3, 3):
        # Lattice vectors: absolute value of the scalar triple product.
        triple_product = np.dot(cell[0], np.cross(cell[1], cell[2]))
        return float(np.abs(triple_product))

    if cell.shape != (2, 3):
        raise ValueError(f"Cannot understand input shape {cell.shape} for `cell_info`.")

    # Cell parameters: standard closed-form volume expression.
    (a, b, c), angles_deg = cell
    alpha, beta, gamma = np.deg2rad(angles_deg)
    cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
    volume_sq = (
        a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g)
    )
    # Clamp so that rounding noise cannot produce the root of a negative.
    return float(np.sqrt(max(0, volume_sq)))
\ No newline at end of file
#!/bin/bash
# Example pipeline: generate candidate crystal structures from a SMILES
# string, relax them with an ML potential, then post-process the results.
# Run it from the repository root; all output goes to ./test.

TOP_DIR=$(pwd)
TAR_DIR="${TOP_DIR}/test"

mkdir -p "${TAR_DIR}"
# Abort if the working directory cannot be entered, otherwise the logs and
# structures below would be written to the wrong place.
cd "${TAR_DIR}" || exit 1

# generate structures
python "${TOP_DIR}/main.py" --path "${TAR_DIR}" --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \
    --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16 \
    --num_generation 100 --generate_conformers 20 --use_conformers 4 > generate.log 2>&1

# opt structures using mace, --batch_size 0 means auto batch size only for mace
mkdir -p "${TAR_DIR}/mace_opt"
cd "${TAR_DIR}/mace_opt" || exit 1
python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
    --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \
    --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
    --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \
    --use_ordered_files true --model mace > opt.log 2>&1

# opt structures using 7net
# mkdir -p "${TAR_DIR}/7net_opt"
# cd "${TAR_DIR}/7net_opt"
# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
#     --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \
#     --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
#     --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \
#     --use_ordered_files true --model sevennet > opt.log 2>&1

# Postprocess the opt structures (runs inside the optimisation folder entered above)
python "${TOP_DIR}/post_process/clean_table.py"

## Make sure you have installed csd-python-api in the current env before executing the following commands
# conda activate ccdc
# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs"
# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80
from basic_function import format_parser
from basic_function import packaged_function
from basic_function import conformer_search
import time
import argparse
import os
import itertools
if __name__ == '__main__':
    # Script entry point: parse the command line, run (or reload) the
    # conformer search for every molecule in the SMILES string, then call the
    # structure generator once per packing and per conformer combination.
    time_start = time.time()

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default="./", help='Path to process')
    parser.add_argument('--smiles', type=str, required=True,
                        help='SMILES string of the molecules, split by . if multiple molecules are used')
    parser.add_argument('--generate_conformers', type=int, default=20,
                        help='Number of conformers to generate. When it is <=0, only load existing conformers to generate structures')
    parser.add_argument('--use_conformers', type=int, default=4,
                        help='Number of conformers used to generate structure. When it is <=0, no structure generation would be done')
    parser.add_argument('--molecule_num_in_cell', type=str, nargs='+', default=['1'],
                        help='number of molecules in a unit cell, split by comma for multiple molecules, and split by space for multiple packings')
    parser.add_argument('--num_generation', type=int, nargs='+', default=[100],
                        help='number of structures to generate, split by space for multiple packings')
    parser.add_argument('--space_group_list', type=str, nargs='+', default=["2,14"],
                        help='Space group list for structure generation, spilt by comma to add mutiple groups, split by space for multiple packings')
    parser.add_argument('--add_name', type=str, nargs='+', default=["CRYSTAL"],
                        help='Add name for the generated structures, split by space for multiple packings')
    parser.add_argument('--max_workers', type=int, default=8,
                        help='Maximum number of workers for parallel processing')
    args = parser.parse_args()

    target_folder = args.path
    smiles_list = args.smiles.split('.')
    generate_conformers = args.generate_conformers
    use_conformers = args.use_conformers
    # Each packing is given as a comma-separated string, e.g. "1,1" -> [1, 1].
    molecule_num_in_cell = [[int(token) for token in spec.split(',')]
                            for spec in args.molecule_num_in_cell]
    num_generation = args.num_generation
    space_group_list = [[int(token) for token in spec.split(',')]
                        for spec in args.space_group_list]
    add_name = args.add_name
    max_workers = args.max_workers

    num_molecules = len(smiles_list)
    num_packings = max(len(molecule_num_in_cell), len(space_group_list))

    # Make every per-packing count list exactly one entry per molecule:
    # missing entries default to 1, surplus entries are dropped.
    for packing_idx, counts in enumerate(molecule_num_in_cell):
        shortfall = num_molecules - len(counts)
        if shortfall > 0:
            counts.extend([1] * shortfall)
        elif shortfall < 0:
            molecule_num_in_cell[packing_idx] = counts[:num_molecules]

    # Per-packing options shorter than the number of packings are padded by
    # repeating their last value.
    for option_values in (molecule_num_in_cell, space_group_list, add_name, num_generation):
        while len(option_values) < num_packings:
            option_values.append(option_values[-1])

    # ------------------------------------------------------------------
    # Step 1: conformer search
    # ------------------------------------------------------------------
    molecule_data = []
    for mol_idx in range(num_molecules):
        molecule_folder = os.path.join(target_folder, f"molecule_{mol_idx+1}")
        loaded = []
        molecule_data.append(loaded)
        if generate_conformers > 0:
            conformer_search.conformer_search(smiles_list[mol_idx], molecule_folder,
                                              num_conformers=generate_conformers,
                                              max_attempts=10000, rms_thresh=0.1)
            # Record which SMILES the folder belongs to.
            with open(os.path.join(molecule_folder, "info.txt"), "w") as smiles_file:
                smiles_file.write(f"SMILES: {smiles_list[mol_idx]}")
        # Load conformer_0.xyz, conformer_1.xyz, ... and stop at the first
        # file that does not exist.
        for conf_idx in range(use_conformers):
            temp_path = os.path.join(molecule_folder, "conformers", f"conformer_{conf_idx}.xyz")
            if not os.path.exists(temp_path):
                break
            loaded.append(format_parser.read_xyz_file(temp_path))
        if not loaded:
            print(f"No conformer loaded for molecule_{mol_idx+1}. Check configurations!")
            break

    # Every way of picking one loaded conformer per molecule; empty as soon
    # as one molecule has no conformer.
    idx_data = [list(range(len(conformers))) for conformers in molecule_data]
    combinations = list(itertools.product(*idx_data))

    # ------------------------------------------------------------------
    # Step 2: structure generation
    # ------------------------------------------------------------------
    for packing_idx in range(num_packings):
        for combination in combinations:
            # Repeat each molecule's chosen conformer as many times as the
            # packing requests for that molecule.
            molecule_list = [
                molecule_data[mol_idx][combination[mol_idx]]
                for mol_idx in range(num_molecules)
                for _ in range(molecule_num_in_cell[packing_idx][mol_idx])
            ]
            c_name = "".join(str(conf_idx) for conf_idx in combination)
            packaged_function.CSP_generater_parallel(
                molecule_list,
                target_folder,
                need_structure=num_generation[packing_idx],
                space_group_list=space_group_list[packing_idx],
                add_name=f"{add_name[packing_idx]}_C{c_name}",
                max_workers=max_workers,
                start_seed=1,
            )

    time_end = time.time()
    print('time cost', time_end - time_start, 's')
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from ccdc.crystal import PackingSimilarity
from ccdc.io import CrystalReader
import glob
import os
import sys
import random
import pandas as pd
from multiprocessing import Pool, cpu_count, TimeoutError as mpTimeoutError # import TimeoutError
import argparse
# --- Global Configuration ---
# Minimum value of the comparison result's ``nmatched_molecules`` for a
# candidate structure to be reported as a match to the reference.
REPORT_TARGET = 15
# When True, the per-file "FAIL: ..." message is not printed for candidates
# whose comparison raised an exception; they are still returned as "failed".
LARGE_CONFORMER_DIFF = True
# --- Worker Process Initializer ---
def init_worker(ref_path, engine_settings):
    """
    Pool initializer: prepare the per-process comparison state.

    Reads the first crystal from ``ref_path`` and builds a configured
    ``PackingSimilarity`` engine, storing both in process-level globals so
    that every task handled by this worker can reuse them.

    Args:
        ref_path: Path of the reference structure file.
        engine_settings: Mapping holding a value for each similarity setting
            applied below (a missing key raises ``KeyError``).
    """
    global worker_ref_crystal
    global worker_similarity_engine
    worker_ref_crystal = CrystalReader(ref_path)[0]
    worker_similarity_engine = PackingSimilarity()
    # Copy the settings across one by one, in a fixed order.
    for option in (
        'allow_molecular_differences',
        'distance_tolerance',
        'angle_tolerance',
        'packing_shell_size',
        'ignore_hydrogen_positions',
        'ignore_bond_counts',
        'ignore_hydrogen_counts',
    ):
        setattr(worker_similarity_engine.settings, option, engine_settings[option])
# --- Single Task Processing Function ---
def process_single_cif(csp_file_path):
    """
    Compare one candidate structure with the worker's reference crystal.

    Relies on the globals prepared by ``init_worker``.

    Args:
        csp_file_path: Path of the candidate structure file.

    Returns:
        ``("matched", csp_file_path)`` when at least ``REPORT_TARGET``
        molecules matched, ``("failed", csp_file_path)`` when reading or
        comparing raised an exception, and ``None`` otherwise.
    """
    global worker_ref_crystal, worker_similarity_engine
    file_label = os.path.basename(csp_file_path)
    try:
        candidate = CrystalReader(csp_file_path)[0]
        comparison = worker_similarity_engine.compare(candidate, worker_ref_crystal)
        if comparison.nmatched_molecules < REPORT_TARGET:
            # Compared fine, but not similar enough to report.
            return None
        print(f"MATCH: {file_label} | Matched Molecules: {comparison.nmatched_molecules}, RMSD: {comparison.rmsd:.3f}")
        sys.stdout.flush()
        return ("matched", csp_file_path)
    except Exception as exc:
        # Failures are only reported when LARGE_CONFORMER_DIFF is off.
        if not LARGE_CONFORMER_DIFF:
            print(f"FAIL: {file_label} | Reason: {exc}")
            sys.stdout.flush()
        return ("failed", csp_file_path)
# --- Main Execution Block ---
if __name__ == '__main__':
    # Command-line driver: compare every candidate CIF in
    # <path>/cif_result_final with each reference CIF in --ref_path and write
    # a per-reference summary to match_results.csv in the working directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default="./", help='Path to process')
    parser.add_argument('--ref_path', type=str, default="../refs", help='Path to find reference structrues')
    parser.add_argument('--workers', type=int, default=80, help='Max worker number limit')
    parser.add_argument('--timeout', type=int, default=20, help='Timeout for each task in seconds')
    args = parser.parse_args()

    base_path = args.path
    refs_dir = args.ref_path
    # Never start more processes than the machine reports CPUs.
    PROCESS_NUM = min(args.workers, cpu_count())
    # Maximum time to wait for each individual result.
    TIMEOUT_SECONDS = args.timeout
    all_results = []
    print(f"Starting checking match using up to {PROCESS_NUM} processes with a {TIMEOUT_SECONDS}s timeout per task...")

    # Only process the folder pair when both directories exist.
    csp_dir = os.path.join(base_path, "cif_result_final")
    if os.path.exists(csp_dir) and os.path.exists(refs_dir):
        folders_to_process = [(csp_dir, refs_dir)]
    else:
        folders_to_process = []

    # Settings handed to every worker's similarity engine.
    engine_settings = {
        'allow_molecular_differences': False,
        'distance_tolerance': 0.2,
        'angle_tolerance': 20,
        'packing_shell_size': 15,
        'ignore_hydrogen_positions': True,
        'ignore_bond_counts': True,
        'ignore_hydrogen_counts': True
    }

    for csp_dir, refs_dir in folders_to_process:
        for ref_filename in os.listdir(refs_dir):
            if not ref_filename.endswith(".cif"):
                continue
            ref_full_path = os.path.join(refs_dir, ref_filename)
            print(f"\n--- Processing Reference File: {ref_full_path} ---")

            csp_files = glob.glob(os.path.join(csp_dir, '*.cif'))
            random.shuffle(csp_files)
            if not csp_files:
                print("No candidate .cif files found, skipping.")
                continue

            # A fresh pool per reference: the initializer loads that
            # reference once in every worker process.
            with Pool(processes=PROCESS_NUM, initializer=init_worker,
                      initargs=(ref_full_path, engine_settings)) as pool:
                pending = [pool.apply_async(process_single_cif, args=(candidate,))
                           for candidate in csp_files]
                outcomes = []
                # Collect in submission order so a timeout can be attributed
                # to the right file.
                for candidate, handle in zip(csp_files, pending):
                    try:
                        outcomes.append(handle.get(timeout=TIMEOUT_SECONDS))
                    except mpTimeoutError:
                        print(f"TIMEOUT: {candidate} | Task exceeded {TIMEOUT_SECONDS}s limit.")
                        sys.stdout.flush()
                        outcomes.append(("failed", candidate))

            # Drop the "no match" results (None) and split the rest by status.
            labelled = [outcome for outcome in outcomes if outcome]
            matched_structures = [os.path.basename(path)
                                  for status, path in labelled if status == "matched"]
            failed_structures = [os.path.basename(path)
                                 for status, path in labelled if status == "failed"]

            all_results.append({
                "ref_name": ref_filename,
                "matched_count": len(matched_structures),
                "matched_structures": ";".join(matched_structures),
                "failed_count": len(failed_structures),
                "failed_structures": ";".join(failed_structures)
            })
            print(f"--- Finished {ref_filename}. Matched: {len(matched_structures)}, Failed: {len(failed_structures)} ---")

    if not all_results:
        print("\nNo valid data processed. No output file generated.")
    else:
        output_filename = "match_results.csv"
        pd.DataFrame(all_results).to_csv(output_filename, index=False)
        print(f"\nAll processing finished. Results saved to {output_filename}")
\ No newline at end of file
import os
import sys
import pandas as pd
target_folder = os.getcwd()
def find_invalid_rows(table, cif_names):
    """
    Return the index labels of the rows of *table* that must be discarded.

    A row is invalid when any of the following holds:
      * ``stage2_steps`` is greater than 2990;
      * ``stage1_steps`` is at least 2999 and ``stage2_steps`` is 0;
      * the absolute ``stage2_energy`` exceeds 1e15;
      * ``<file>.cif`` is not contained in *cif_names*.

    Args:
        table: DataFrame with the columns ``file``, ``stage1_steps``,
            ``stage2_steps`` and ``stage2_energy``.
        cif_names: Container of the CIF file names that exist on disk
            (a ``set`` keeps the membership test O(1) per row).

    Returns:
        list: Index labels of the invalid rows, in table order.
    """
    invalid = []
    for label, row in table.iterrows():
        # Evaluated left to right with short-circuiting, so later checks are
        # skipped once a row is known to be invalid.
        if (int(row['stage2_steps']) > 2990
                or (row['stage1_steps'] >= 2999 and row["stage2_steps"] == 0)
                or abs(float(row['stage2_energy'])) > 1e15
                or row['file'] + ".cif" not in cif_names):
            invalid.append(label)
    return invalid


if __name__ == '__main__':
    # Clean the optimisation result table of the working directory: drop
    # invalid rows, add a ``relative_energy`` column, move CIF files that are
    # no longer listed into ``error_result`` and overwrite the CSV.
    print(f"Cleaning table in folder: {target_folder}")
    error_folder = os.path.join(target_folder, "error_result")
    os.makedirs(error_folder, exist_ok=True)
    cif_folder = os.path.join(target_folder, "cif_result_final")
    csv_file = os.path.join(target_folder, "results_scheduler.csv")
    if not os.path.exists(csv_file):
        print(f"CSV file not found in {target_folder}, skipping...")
        sys.exit(0)
    if not os.path.exists(cif_folder):
        print(f"CIF folder not found in {target_folder}, skipping...")
        sys.exit(0)

    all_crystals = pd.read_csv(csv_file)
    cifs = os.listdir(cif_folder)

    invalid_list = find_invalid_rows(all_crystals, set(cifs))
    if invalid_list:
        all_crystals.drop(index=invalid_list, inplace=True)

    all_crystals.sort_values(by=["stage2_energy"], ascending=True, inplace=True)
    all_crystals.reset_index(drop=True, inplace=True)
    # ``.min()`` rather than ``[0]``: identical on the sorted table, but it
    # yields NaN instead of raising KeyError when every row was discarded.
    min_energy = all_crystals['stage2_energy'].min()
    all_crystals['relative_energy'] = all_crystals['stage2_energy'] - min_energy

    # Move every file whose name (minus the 4-character extension) is not a
    # surviving table entry into the error folder.
    kept_names = set(all_crystals['file'])
    for cif in cifs:
        if cif[:-4] not in kept_names:
            src_path = os.path.join(cif_folder, cif)
            dest_path = os.path.join(error_folder, cif)
            os.rename(src_path, dest_path)

    # Save the cleaned DataFrame back to CSV
    all_crystals.to_csv(csv_file, index=False)
\ No newline at end of file
import sys
import os
import glob
import pandas as pd
import warnings
from pathlib import Path
from tqdm import tqdm
import multiprocessing as mp
from ccdc.crystal import PackingSimilarity
from ccdc.io import CrystalReader
import shutil
import argparse
warnings.filterwarnings("ignore", category=DeprecationWarning)
#########################
# Global Settings
#########################
# The number of matching molecules required in a packing shell to consider two crystals similar.
# The same value is used below as the engine's packing shell size and, in
# compare_sim_pack_pair, as the minimum ``nmatched_molecules`` for a duplicate.
molecule_shell_size = 15
# Initialize the packing similarity engine from the CCDC API.
# It is created at import time and shared by every comparison in this module.
similarity_engine = PackingSimilarity()
# Configure the similarity engine settings.
similarity_engine.settings.allow_molecular_differences = False
similarity_engine.settings.distance_tolerance = 0.2
similarity_engine.settings.angle_tolerance = 20
similarity_engine.settings.packing_shell_size = molecule_shell_size
# Settings to make the comparison less strict regarding hydrogen atoms and bond counts.
similarity_engine.settings.ignore_hydrogen_positions = True
similarity_engine.settings.ignore_bond_counts = True
similarity_engine.settings.ignore_hydrogen_counts = True
# Pre-filtering thresholds to avoid expensive comparisons for vastly different structures.
# A candidate is only compared with a unique structure when BOTH the energy
# difference and the density difference are strictly below these values.
# NOTE(review): units follow the CSV columns (relative_energy / stage2_density) — confirm.
ENERGY_DIFF = 5.0
DENSITY_DIFF = 0.2
def compare_sim_pack_pair(args):
    """
    Decide whether two crystal structure files describe the same packing.

    Args:
        args: Tuple ``(candidate_fp, ufp)`` with the paths of the candidate
            and of an already accepted unique structure.

    Returns:
        bool: True when the similarity engine matched at least
        ``molecule_shell_size`` molecules; False otherwise, including when
        reading or comparing the structures raised any exception.
    """
    candidate_path, unique_path = args
    try:
        # Read the first crystal of each file, candidate first.
        crystals = [CrystalReader(path)[0] for path in (candidate_path, unique_path)]
        outcome = similarity_engine.compare(*crystals)
        return outcome.nmatched_molecules >= molecule_shell_size
    except Exception:
        # Best effort: an unreadable or incomparable pair counts as "different".
        return False
def find_top_unique(struct_list, target_count, n_workers):
    """
    Collect up to ``target_count`` mutually distinct structures.

    Candidates are visited in the order given (the caller passes them sorted
    by energy). Each one is compared, in parallel, with the already accepted
    structures whose density and energy are both close enough; the first
    accepted structure that matches claims the candidate as its duplicate.

    Args:
        struct_list (list): Tuples ``(density, energy, filepath)``.
        target_count (int): Number of unique structures after which to stop.
        n_workers (int): Size of the multiprocessing pool used for comparisons.

    Returns:
        tuple: ``(unique_paths, duplicate_map)`` where ``unique_paths`` lists
        the accepted file paths in acceptance order and ``duplicate_map`` maps
        each accepted path to the list of paths found to duplicate it.
    """
    accepted = []  # (density, energy, filepath) of every unique structure so far
    duplicate_map = {}  # {unique_fp: [dup_fp1, dup_fp2, ...]}
    pool = mp.Pool(processes=n_workers)
    try:
        for density, energy, candidate in struct_list:
            if len(accepted) >= target_count:
                break

            # Only accepted structures inside both pre-filter windows are
            # worth an expensive packing comparison.
            pairs = []
            for known_density, known_energy, known_path in accepted:
                if abs(known_density - density) < DENSITY_DIFF and abs(known_energy - energy) < ENERGY_DIFF:
                    pairs.append((candidate, known_path))

            owner = None
            if pairs:
                progress = tqdm(pool.imap(compare_sim_pack_pair, pairs),
                                total=len(pairs),
                                desc=f"Comparing {os.path.basename(candidate)}",
                                leave=False)
                # imap preserves task order, so verdicts line up with pairs.
                verdicts = list(progress)
                # The first matching accepted structure owns the duplicate.
                owner = next((pair[1] for verdict, pair in zip(verdicts, pairs) if verdict), None)

            if owner is not None:
                duplicate_map[owner].append(candidate)
                continue

            # No match anywhere: the candidate becomes a new unique structure.
            accepted.append((density, energy, candidate))
            duplicate_map[candidate] = []
            print(f">>> Unique count: {len(accepted)} (added {os.path.basename(candidate)})")
    finally:
        # Let outstanding work finish and release the worker processes.
        pool.close()
        pool.join()
    return [entry[2] for entry in accepted], duplicate_map
def clean_structure_name(path):
    """
    Return the file stem of *path* with a single trailing ``_opt`` removed.

    Only the suffix is stripped, so a name that merely contains ``_opt``
    elsewhere (for example ``a_optimal_3_opt.cif`` -> ``a_optimal_3``) is
    left intact.
    """
    stem = Path(path).stem
    return stem[:-4] if stem.endswith("_opt") else stem


def process_folder(folder_path):
    """
    Find the unique structures of one result folder and export them.

    Reads ``results_scheduler.csv`` and the CIF files of ``cif_result_final``
    inside *folder_path*, keeps the rows whose ``relative_energy`` does not
    exceed ``ENERGY_THRESHOLD``, and passes them (sorted by energy) to
    ``find_top_unique``. The unique CIF files are copied to
    ``unique_<TARGET_UNIQUE_COUNT>``, their duplicates to
    ``unique_<TARGET_UNIQUE_COUNT>_duplicates`` (both folders are recreated
    from scratch), and ``unique_structures.csv`` lists the unique rows with a
    ``duplicates`` column.

    Uses the module globals ``TARGET_UNIQUE_COUNT``, ``ENERGY_THRESHOLD`` and
    ``WORKER_COUNT``, which the ``__main__`` section sets from the command line.

    Args:
        folder_path: Folder containing the CSV file and the CIF folder.
    """
    base_path = folder_path
    # Define output directories for unique structures and their duplicates.
    t_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}")
    d_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}_duplicates")
    # Load the summary data from the CSV file.
    df = pd.read_csv(os.path.join(base_path, "results_scheduler.csv"))
    cif_folder = os.path.join(base_path, "cif_result_final")

    # Map the cleaned structure name to its CIF path for quick lookup.
    file_map = {
        clean_structure_name(fp): fp
        for fp in glob.glob(os.path.join(cif_folder, "*.cif"))
    }

    # Build the list of candidate structures, filtering by the energy threshold.
    struct_list = []
    for _, row in df.iterrows():
        name = row["file"]
        dens = row["stage2_density"]
        ene = row["relative_energy"]
        if ene > ENERGY_THRESHOLD:
            continue
        fp = file_map.get(name)
        if fp:
            struct_list.append((dens, ene, fp))
        else:
            print(f"Warning: no CIF for {name}")
    # Lowest energy first, so the survivor of each duplicate group is its
    # lowest-energy member.
    struct_list.sort(key=lambda entry: entry[1])

    unique_paths, duplicate_map = find_top_unique(
        struct_list,
        target_count=TARGET_UNIQUE_COUNT,
        n_workers=WORKER_COUNT
    )

    # Prepare output directories, removing them first if they already exist.
    for folder in (t_folder, d_folder):
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)
    print(f"\nFound {len(unique_paths)} unique structures "
          f"(energy <= {ENERGY_THRESHOLD}, target {TARGET_UNIQUE_COUNT}).")

    # Copy the unique structure files to the target folder.
    for p in unique_paths:
        shutil.copy(p, t_folder)

    # Table restricted to the unique structures. The names are derived with
    # the same helper that built ``file_map``, so they always agree with the
    # CSV ``file`` column.
    unique_names = [clean_structure_name(p) for p in unique_paths]
    unique_df = df[df['file'].isin(unique_names)].copy()
    unique_df['duplicates'] = ''  # Filled in below.
    listed_names = set(unique_df['file'])

    # Populate the 'duplicates' column and copy duplicate files to their folder.
    for unique_fp, duplicates_list in duplicate_map.items():
        unique_name = clean_structure_name(unique_fp)
        if unique_name not in listed_names:
            continue
        du_name = [clean_structure_name(p) for p in duplicates_list]
        # Add the list of duplicate names as a comma-separated string.
        unique_df.loc[unique_df['file'] == unique_name, 'duplicates'] = ', '.join(du_name)
        # Copy duplicate files to the duplicates folder.
        for p in duplicates_list:
            shutil.copy(p, d_folder)

    # Save the final DataFrame with unique structures and their duplicates to a new CSV file.
    unique_csv_path = os.path.join(base_path, 'unique_structures.csv')
    unique_df.to_csv(unique_csv_path, index=False)
if __name__ == '__main__':
    # Needed so that frozen executables using multiprocessing start correctly.
    mp.freeze_support()

    # Command-line options: (flag, type, default, help text).
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in (
        ('--path', str, "./", 'Path to process'),
        ('--energy', float, 30, 'Energy threshold for filtering structures'),
        ('--count', int, 200, 'Target number of unique structures to find'),
        ('--workers', int, 80, 'Max worker number limit'),
    ):
        parser.add_argument(flag, type=value_type, default=fallback, help=description)
    args = parser.parse_args()

    # process_folder reads these three as module globals.
    target_folder = args.path
    ENERGY_THRESHOLD, TARGET_UNIQUE_COUNT, WORKER_COUNT = args.energy, args.count, args.workers

    # Nothing is done when the folder does not exist.
    if os.path.exists(target_folder):
        print(f"Processing folder: {target_folder}")
        process_folder(target_folder)
\ No newline at end of file
#!/bin/sh -l
# Example SLURM submission script for a GPU job: runs the duplicate-removal
# script and reports the wall-clock time it took.
#SBATCH -D ./
#SBATCH --export=ALL
#SBATCH -J csp_test
#SBATCH -o job-%j.log
#SBATCH -e job-%j.err
#SBATCH -p GPU-8A100
#SBATCH -N 1 -n 8
#SBATCH --gres=gpu:1
#SBATCH --qos=gpu_8a100
#SBATCH -t 3-00:00:00
banner="========================================================="
echo "$banner"
echo SLURM job: submitted date = $(date)
date_start=$(date +%s)
echo "$banner"
echo Job output begins
echo -----------------
echo
python duplicate_remove3.py
echo
echo ---------------
echo Job output ends
date_end=$(date +%s)
# Split the elapsed seconds into hours, minutes and seconds.
elapsed=$((date_end - date_start))
hours=$((elapsed / 3600))
minutes=$(((elapsed / 60) % 60))
seconds=$((elapsed % 60))
echo "$banner"
echo SLURM job: finished date = $(date)
echo Total run time : $hours Hours $minutes Minutes $seconds Seconds
echo "$banner"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment