Unverified Commit fa84b16c authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent 09624897
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import ase.db.sqlite
import ase.io.trajectory
import numpy as np
import torch
from ase.geometry import wrap_positions
from torch_geometric.data import Data
from batchopt.utils import collate
if TYPE_CHECKING:
from collections.abc import Sequence
try:
from pymatgen.io.ase import AseAtomsAdaptor
except ImportError:
AseAtomsAdaptor = None
from tqdm import tqdm
class AtomsToGraphs:
"""A class to help convert periodic atomic structures to graphs.
The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts
them into graph representations for use in PyTorch. The primary purpose of this class is to determine the
nearest neighbors within some radius around each individual atom, taking into account PBC, and set the
pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information
are put into a PyTorch geometric data object for use with PyTorch.
Args:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstroms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
Attributes:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstoms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
"""
def __init__(
self,
max_neigh: int = 200,
radius: int = 6,
r_energy: bool = False,
r_forces: bool = False,
r_distances: bool = False,
r_edges: bool = True,
r_fixed: bool = True,
r_pbc: bool = False,
r_stress: bool = False,
r_data_keys: Sequence[str] | None = None,
) -> None:
self.max_neigh = max_neigh
self.radius = radius
self.r_energy = r_energy
self.r_forces = r_forces
self.r_stress = r_stress
self.r_distances = r_distances
self.r_fixed = r_fixed
self.r_edges = r_edges
self.r_pbc = r_pbc
self.r_data_keys = r_data_keys
def _get_neighbors_pymatgen(self, atoms: ase.Atoms):
"""Preforms nearest neighbor search and returns edge index, distances,
and cell offsets"""
if AseAtomsAdaptor is None:
raise RuntimeError(
"Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed."
)
struct = AseAtomsAdaptor.get_structure(atoms)
_c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(
r=self.radius, numerical_tol=0, exclude_self=True
)
_nonmax_idx = []
for i in range(len(atoms)):
idx_i = (_c_index == i).nonzero()[0]
# sort neighbors by distance, remove edges larger than max_neighbors
idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh]
_nonmax_idx.append(idx_i[idx_sorted])
_nonmax_idx = np.concatenate(_nonmax_idx)
_c_index = _c_index[_nonmax_idx]
_n_index = _n_index[_nonmax_idx]
n_distance = n_distance[_nonmax_idx]
_offsets = _offsets[_nonmax_idx]
return _c_index, _n_index, n_distance, _offsets
def _reshape_features(self, c_index, n_index, n_distance, offsets):
"""Stack center and neighbor index and reshapes distances,
takes in np.arrays and returns torch tensors"""
edge_index = torch.LongTensor(np.vstack((n_index, c_index)))
edge_distances = torch.FloatTensor(n_distance)
cell_offsets = torch.LongTensor(offsets)
# remove distances smaller than a tolerance ~ 0. The small tolerance is
# needed to correct for pymatgen's neighbor_list returning self atoms
# in a few edge cases.
nonzero = torch.where(edge_distances >= 1e-8)[0]
edge_index = edge_index[:, nonzero]
edge_distances = edge_distances[nonzero]
cell_offsets = cell_offsets[nonzero]
return edge_index, edge_distances, cell_offsets
def get_edge_distance_vec(
self,
pos,
edge_index,
cell,
cell_offsets,
):
row, col = edge_index
distance_vectors = pos[row] - pos[col]
# correct for pbc
cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0)
offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3)
distance_vectors += offsets
return distance_vectors
def convert(self, atoms: ase.Atoms, sid=None):
"""Convert a single atomic structure to a graph.
Args:
atoms (ase.atoms.Atoms): An ASE atoms object.
sid (uniquely identifying object): An identifier that can be used to track the structure in downstream
tasks. Common sids used in OCP datasets include unique strings or integers.
Returns:
data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags,
and optionally, energy, forces, distances, edges, and periodic boundary conditions.
Optional properties can included by setting r_property=True when constructing the class.
"""
# set the atomic numbers, positions, and cell
positions = np.array(atoms.get_positions(), copy=True)
pbc = np.array(atoms.pbc, copy=True)
cell = np.array(atoms.get_cell(complete=True), copy=True)
# TODO: change this back &&& ^^^
# positions = wrap_positions(positions, cell, pbc=pbc, eps=0)
atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8)
positions = torch.from_numpy(positions).float()
cell = torch.from_numpy(cell).view(1, 3, 3).float()
natoms = positions.shape[0]
# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags = torch.tensor(atoms.get_tags(), dtype=torch.int)
# put the minimum data in torch geometric data object
data = Data(
cell=cell,
pos=positions,
atomic_numbers=atomic_numbers,
natoms=natoms,
tags=tags,
)
# Optionally add a systemid (sid) to the object
if sid is not None:
data.sid = sid
# optionally include other properties
if self.r_edges:
# run internal functions to get padded indices and distances
atoms_copy = atoms.copy()
atoms_copy.set_positions(positions)
split_idx_dist = self._get_neighbors_pymatgen(atoms_copy)
edge_index, edge_distances, cell_offsets = self._reshape_features(
*split_idx_dist
)
data.edge_index = edge_index
data.cell_offsets = cell_offsets
data.edge_distance_vec = self.get_edge_distance_vec(
positions, edge_index, cell, cell_offsets
)
del atoms_copy
if self.r_energy:
energy = atoms.get_potential_energy(apply_constraint=False)
data.energy = energy
if self.r_forces:
forces = torch.tensor(
atoms.get_forces(apply_constraint=False), dtype=torch.float32
)
data.forces = forces
if self.r_stress:
stress = torch.tensor(
atoms.get_stress(apply_constraint=False, voigt=False),
dtype=torch.float32,
)
data.stress = stress
if self.r_distances and self.r_edges:
data.distances = edge_distances
if self.r_fixed:
fixed_idx = torch.zeros(natoms, dtype=torch.int)
if hasattr(atoms, "constraints"):
from ase.constraints import FixAtoms
for constraint in atoms.constraints:
if isinstance(constraint, FixAtoms):
fixed_idx[constraint.index] = 1
data.fixed = fixed_idx
if self.r_pbc:
data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool)
if self.r_data_keys is not None:
for data_key in self.r_data_keys:
data[data_key] = (
atoms.info[data_key]
if isinstance(atoms.info[data_key], (int, float, str))
else torch.tensor(atoms.info[data_key])
)
return data
def convert_all(
self,
atoms_collection,
processed_file_path: str | None = None,
collate_and_save=False,
disable_tqdm=False,
):
"""Convert all atoms objects in a list or in an ase.db to graphs.
Args:
atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database):
Either a list of ASE atoms objects or an ASE database.
processed_file_path (str):
A string of the path to where the processed file will be written. Default is None.
collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file.
Returns:
data_list (list of torch_geometric.data.Data):
A list of torch geometric data objects containing molecular graph info and properties.
"""
# list for all data
data_list = []
if isinstance(atoms_collection, list):
atoms_iter = atoms_collection
elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database):
atoms_iter = atoms_collection.select()
elif isinstance(
atoms_collection,
(ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader),
):
atoms_iter = atoms_collection
else:
raise NotImplementedError
for atoms in tqdm(
atoms_iter,
desc="converting ASE atoms collection to graphs",
total=len(atoms_collection),
unit=" systems",
disable=disable_tqdm,
):
# check if atoms is an ASE Atoms object this for the ase.db case
data = self.convert(
atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms()
)
data_list.append(data)
if collate_and_save:
data, slices = collate(data_list)
torch.save((data, slices), processed_file_path)
return data_list
"""
Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from ase.io import read
import logging
from joblib import Parallel, delayed
from ase.optimize import LBFGS as ASE_LBFGS
from ase.optimize import QuasiNewton as ASE_QuasiNewton
from ase.optimize import BFGS as ASE_BFGS
import time
import csv
import os
try:
from mace.calculators import mace_off
except ImportError:
logging.warning("Failed to import MACE modules")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def ensure_directory(directory):
"""Create directory if it doesn't exist."""
if not os.path.exists(directory):
os.makedirs(directory)
logging.info(f"Created directory: {directory}")
def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Starting baseline optimization for file {file} on device {device}.")
start_time = time.perf_counter()
crystal = read(file)
# calc = mace_off(model="small", device=device)
calc = mace_off(model="small", device="cuda")
crystal.calc = calc
first_optimizer_class ={
"LBFGS": ASE_LBFGS,
"QuasiNewton": ASE_QuasiNewton,
"BFGS": ASE_BFGS
}.get(first_optimizer, ASE_LBFGS)
# First optimization stage
if filter1 == "UnitCellFilter":
from ase.filters import UnitCellFilter
atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure)
first_optimizer_instance = first_optimizer_class(atoms_with_filter)
elif filter1 == "FrechetCellFilter":
from ase.filters import FrechetCellFilter
atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure)
first_optimizer_instance = first_optimizer_class(atoms_with_filter)
else:
first_optimizer_instance = first_optimizer_class(crystal)
start_time1 = time.perf_counter()
first_optimizer_instance.run(fmax=0.01, steps=max_steps)
end_time1 = time.perf_counter()
# Save intermediate result
output_dir_press = "./cif_result_press"
output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif"))
crystal.write(output_file_press)
elapsed_time1 = end_time1 - start_time1
steps1 = first_optimizer_instance.nsteps
if skip_second_stage:
ret_result = {
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": 0.0,
"stage2_steps": 0,
"total_time": elapsed_time1,
"total_steps": steps1
}
else:
# Second optimization stage
crystal = read(output_file_press)
crystal.calc = calc
second_optimizer_class = {
"LBFGS": ASE_LBFGS,
"QuasiNewton": ASE_QuasiNewton,
"BFGS": ASE_BFGS
}.get(second_optimizer, ASE_LBFGS)
if filter2 == "UnitCellFilter":
from ase.filters import UnitCellFilter
atoms_with_filter2 = UnitCellFilter(crystal)
second_optimizer_instance = second_optimizer_class(atoms_with_filter2)
elif filter2 == "FrechetCellFilter":
from ase.filters import FrechetCellFilter
atoms_with_filter2 = FrechetCellFilter(crystal)
second_optimizer_instance = second_optimizer_class(atoms_with_filter2)
else:
second_optimizer_instance = second_optimizer_class(crystal)
start_time2 = time.perf_counter()
second_optimizer_instance.run(fmax=0.01, steps=max_steps)
end_time2 = time.perf_counter()
# Save final result
output_dir_final = "./cif_result_final"
output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif"))
crystal.write(output_file_final)
# Collect metrics
elapsed_time2 = end_time2 - start_time2
total_time = elapsed_time1 + elapsed_time2
steps2 = second_optimizer_instance.nsteps
ret_result = {
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": elapsed_time2,
"stage2_steps": steps2,
"total_time": total_time,
"total_steps": steps1 + steps2
}
logging.info(f"Baseline optimization completed for file {file}.")
return ret_result
def run_baseline(files, num_workers, devices, max_steps,
filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006,
optimizer1=None, optimizer2=None):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
logging.info(f"Starting baseline optimization with {num_workers} workers.")
start_time = time.perf_counter()
results = Parallel(n_jobs=num_workers)(
delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2)
for i, file in enumerate(files)
)
end_time = time.perf_counter()
csv_file = "results_baseline.csv"
with open(csv_file, mode='w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"])
writer.writeheader()
for result in results:
writer.writerow(result)
logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.")
final_elapsed_time = end_time - start_time
summary_csv_file = "summary_baseline.csv"
with open(summary_csv_file, mode='w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"])
writer.writeheader()
writer.writerow({
"elapsed_time": final_elapsed_time,
"num_workers": num_workers,
"batch_size": 1
})
logging.info(f"Summary results written to {summary_csv_file}.")
\ No newline at end of file
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations.
This module provides optimized implementations of common operations using
torch.utils.cpp_extension for JIT compilation.
"""
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA Extension wrapper for vector addition and PBC graph operations.
"""
import torch
from torch.utils.cpp_extension import load
import os
def load_cuda_extension():
"""Load the CUDA extension for vector addition."""
# Check if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. Cannot load CUDA extension.")
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Path to the CUDA source file
cuda_file = os.path.join(current_dir, "vector_add.cu")
# Load the extension
return load(
name="vector_add_cuda",
sources=[cuda_file],
verbose=True,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3', '--use_fast_math'],
)
def load_pbc_graph_cuda_extension():
"""Load the CUDA extension for PBC graph operations."""
# Check if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. Cannot load CUDA extension.")
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Path to the CUDA source file
cuda_file = os.path.join(current_dir, "pbc_graph.cu")
# Load the extension
return load(
name="pbc_graph_cuda",
sources=[cuda_file],
verbose=True,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3', '--use_fast_math'],
)
# Global variable to store loaded extension
_cuda_extension = None
_pbc_graph_cuda_extension = None
def get_cuda_extension():
"""Get or load the CUDA extension."""
global _cuda_extension
if _cuda_extension is None:
_cuda_extension = load_cuda_extension()
return _cuda_extension
def get_pbc_graph_cuda_extension():
"""Get or load the PBC graph CUDA extension."""
global _pbc_graph_cuda_extension
if _pbc_graph_cuda_extension is None:
_pbc_graph_cuda_extension = load_pbc_graph_cuda_extension()
return _pbc_graph_cuda_extension
def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Perform vector addition using CUDA implementation.
Args:
a: First input tensor (must be on CUDA device)
b: Second input tensor (must be on CUDA device)
Returns:
Result tensor of element-wise addition
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available.")
if not (a.is_cuda and b.is_cuda):
raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.")
extension = get_cuda_extension()
return extension.vector_add(a.float(), b.float())
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <type_traits>
// Template function to get appropriate epsilon for different floating point types
template<typename T>
__device__ __forceinline__ T get_epsilon() {
if constexpr (std::is_same_v<T, float>) {
return static_cast<T>(1e-8);
} else if constexpr (std::is_same_v<T, double>) {
return static_cast<T>(1e-12);
} else {
return static_cast<T>(1e-8); // fallback
}
}
// Templated CUDA kernel for computing pairwise distances with PBC offsets
// This version avoids repeat_interleave by computing offsets directly in the kernel
template<typename T>
__global__ void pbc_distance_kernel_optimized(
const T* pos1,
const T* pos2,
const T* pbc_offsets, // [batch_size, 3]
const int64_t* num_atoms_per_image_sqr, // [batch_size]
const int64_t* batch_offsets, // [batch_size] - cumulative offsets for each batch
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_pairs) {
// Find which batch this pair belongs to
int batch_idx = 0;
while (batch_idx < num_pairs && idx >= batch_offsets[batch_idx + 1]) {
batch_idx++;
}
// Get PBC offset for this batch
T offset_x = pbc_offsets[batch_idx * 3];
T offset_y = pbc_offsets[batch_idx * 3 + 1];
T offset_z = pbc_offsets[batch_idx * 3 + 2];
// Get positions for this atom pair with PBC offset
T dx = pos2[idx * 3] - pos1[idx * 3] + offset_x;
T dy = pos2[idx * 3 + 1] - pos1[idx * 3 + 1] + offset_y;
T dz = pos2[idx * 3 + 2] - pos1[idx * 3 + 2] + offset_z;
// Compute squared distance
T dist_sq = dx * dx + dy * dy + dz * dz;
distances_squared[idx] = dist_sq;
// Check if within radius
valid_mask[idx] = (dist_sq <= radius_squared) && (dist_sq > get_epsilon<T>());
}
}
// Original kernel for fallback
template<typename T>
__global__ void pbc_distance_kernel(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_pairs) {
// Get positions for this atom pair
T dx = pos2[idx * 3] - pos1[idx * 3] + pbc_offsets[idx * 3];
T dy = pos2[idx * 3 + 1] - pos1[idx * 3 + 1] + pbc_offsets[idx * 3 + 1];
T dz = pos2[idx * 3 + 2] - pos1[idx * 3 + 2] + pbc_offsets[idx * 3 + 2];
// Compute squared distance
T dist_sq = dx * dx + dy * dy + dz * dz;
distances_squared[idx] = dist_sq;
// Check if within radius
valid_mask[idx] = (dist_sq <= radius_squared) && (dist_sq > get_epsilon<T>());
}
}
// Template helper function to launch the appropriate optimized kernel
template<typename T>
inline void launch_pbc_distance_kernel_optimized(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
const int64_t* num_atoms_per_image_sqr,
const int64_t* batch_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared,
int blocks,
int threads_per_block
) {
pbc_distance_kernel_optimized<T><<<blocks, threads_per_block>>>(
pos1, pos2, pbc_offsets, num_atoms_per_image_sqr, batch_offsets,
distances_squared, valid_mask, num_pairs, radius_squared
);
}
// Template helper function to launch the appropriate kernel (fallback)
template<typename T>
void launch_pbc_distance_kernel(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared,
int blocks,
int threads_per_block
) {
pbc_distance_kernel<T><<<blocks, threads_per_block>>>(
pos1, pos2, pbc_offsets, distances_squared, valid_mask, num_pairs, radius_squared
);
}
// CUDA function to compute distances for all unit cell offsets
std::vector<torch::Tensor> pbc_distance_cuda(
torch::Tensor pos1,
torch::Tensor pos2,
torch::Tensor data_cell,
torch::Tensor num_atoms_per_image_sqr,
int batch_size,
std::vector<int> max_rep,
float radius,
torch::Device device
) {
// Convert tensors to CUDA if not already, but preserve original dtype
pos1 = pos1.to(device).contiguous();
pos2 = pos2.to(device).contiguous();
data_cell = data_cell.to(device).contiguous();
num_atoms_per_image_sqr = num_atoms_per_image_sqr.to(device);
// Check that all position tensors have the same dtype
TORCH_CHECK(pos1.dtype() == pos2.dtype(), "pos1 and pos2 must have the same dtype");
TORCH_CHECK(pos1.dtype() == data_cell.dtype(), "pos1 and data_cell must have the same dtype");
// Determine if we're working with float32 or float64
bool is_float64 = pos1.dtype() == torch::kFloat64;
int num_pairs = pos1.size(0);
// Storage for all results across unit cells
std::vector<torch::Tensor> all_index1, all_index2, all_unit_cell, all_distances_sq;
// Create base indices for original atom pairs
torch::Tensor base_indices = torch::arange(num_pairs, torch::dtype(torch::kLong).device(device));
// Launch parameters
int threads_per_block = 512;
int blocks = (num_pairs + threads_per_block - 1) / threads_per_block;
// Pre-allocate tensors outside the loop for reuse
torch::Tensor distances_squared = torch::zeros({num_pairs},
torch::dtype(pos1.dtype()).device(device));
torch::Tensor valid_mask = torch::zeros({num_pairs},
torch::dtype(torch::kBool).device(device));
torch::Tensor unit_cell_offset = torch::zeros({3},
torch::dtype(pos1.dtype()).device(device));
torch::Tensor unit_cell_offset_batch = torch::zeros({batch_size, 3, 1},
torch::dtype(pos1.dtype()).device(device));
// Pre-compute batch offsets for optimized kernel
torch::Tensor batch_offsets = torch::zeros({batch_size + 1},
torch::dtype(torch::kLong).device(device));
torch::Tensor cumsum = torch::cumsum(num_atoms_per_image_sqr, 0);
batch_offsets.slice(0, 1, batch_size + 1) = cumsum;
// Iterate over unit cell offsets (triple loop)
// NOTE: for i, j, k loop can not be flatten, as we need to limit the device memory usage
#pragma unroll
for (int i = -max_rep[0]; i <= max_rep[0]; i++) {
#pragma unroll
for (int j = -max_rep[1]; j <= max_rep[1]; j++) {
#pragma unroll
for (int k = -max_rep[2]; k <= max_rep[2]; k++) {
// Reuse pre-allocated unit cell offset tensor
unit_cell_offset[0] = static_cast<float>(i);
unit_cell_offset[1] = static_cast<float>(j);
unit_cell_offset[2] = static_cast<float>(k);
// Compute PBC offsets for this unit cell
// unit_cell_offset_batch.fill_(0);
unit_cell_offset_batch.select(2, 0) = unit_cell_offset.unsqueeze(0).expand({batch_size, -1});
torch::Tensor pbc_offsets = torch::bmm(data_cell, unit_cell_offset_batch).squeeze(-1);
// // Optimized: Use index_select instead of repeat_interleave
// // Create index tensor for selecting pbc_offsets based on atom pairs
// int64_t offset = 0;
// for (int b = 0; b < batch_size; b++) {
// int64_t num_pairs_in_batch = num_atoms_per_image_sqr[b].item<int64_t>();
// auto batch_indices = torch::full({num_pairs_in_batch}, b,
// torch::dtype(torch::kLong).device(device));
// pbc_offsets_per_atom.slice(0, offset, offset + num_pairs_in_batch) =
// pbc_offsets.index_select(0, batch_indices);
// offset += num_pairs_in_batch;
// }
// Reset output tensors for reuse
// distances_squared.fill_(0);
// valid_mask.fill_(false);
// Launch templated CUDA kernel
if (is_float64) {
double radius_squared = static_cast<double>(radius) * static_cast<double>(radius);
launch_pbc_distance_kernel_optimized<double>(
pos1.data_ptr<double>(),
pos2.data_ptr<double>(),
// pbc_offsets_per_atom.data_ptr<double>(),
pbc_offsets.data_ptr<double>(),
num_atoms_per_image_sqr.data_ptr<int64_t>(),
batch_offsets.data_ptr<int64_t>(),
distances_squared.data_ptr<double>(),
valid_mask.data_ptr<bool>(),
num_pairs,
radius_squared,
blocks,
threads_per_block
);
} else {
float radius_squared = radius * radius;
launch_pbc_distance_kernel_optimized<float>(
pos1.data_ptr<float>(),
pos2.data_ptr<float>(),
// pbc_offsets_per_atom.data_ptr<float>(),
pbc_offsets.data_ptr<float>(),
num_atoms_per_image_sqr.data_ptr<int64_t>(),
batch_offsets.data_ptr<int64_t>(),
distances_squared.data_ptr<float>(),
valid_mask.data_ptr<bool>(),
num_pairs,
radius_squared,
blocks,
threads_per_block
);
}
// Filter valid pairs
torch::Tensor valid_indices = torch::nonzero(valid_mask).squeeze(-1);
if (valid_indices.numel() > 0) {
torch::Tensor valid_base_indices = base_indices.index_select(0, valid_indices);
torch::Tensor valid_distances = distances_squared.index_select(0, valid_indices);
torch::Tensor valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat({valid_indices.size(0), 1});
all_index1.push_back(valid_base_indices);
all_unit_cell.push_back(valid_unit_cell);
all_distances_sq.push_back(valid_distances);
}
}
}
}
// Single synchronization after all kernel launches
cudaDeviceSynchronize();
// Concatenate results
torch::Tensor final_indices, final_unit_cell, final_distances;
if (all_index1.size() > 0) {
final_indices = torch::cat(all_index1);
final_unit_cell = torch::cat(all_unit_cell);
final_distances = torch::cat(all_distances_sq);
} else {
final_indices = torch::empty({0}, torch::dtype(torch::kLong).device(device));
final_unit_cell = torch::empty({0, 3}, torch::dtype(pos1.dtype()).device(device));
final_distances = torch::empty({0}, torch::dtype(pos1.dtype()).device(device));
}
return {final_indices, final_unit_cell, final_distances};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("pbc_distance_cuda", &pbc_distance_cuda, "PBC distance computation with CUDA");
}
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA-accelerated PBC graph operations for atomic systems.
"""
import torch
from typing import Optional, List
from .pbc_graph_legacy import get_max_neighbors_mask
from .extensions.cuda_ops import get_pbc_graph_cuda_extension
def radius_graph_pbc_cuda(
data,
radius,
max_num_neighbors_threshold,
enforce_max_neighbors_strictly: bool = False,
pbc=None,
dtype=torch.float64,
):
"""
Memory-efficient CUDA-accelerated version of radius_graph_pbc.
This implementation follows the memory-efficient approach with triple loops
but accelerates the distance computation using CUDA kernels.
"""
if pbc is None:
pbc = [True, True, True]
device = data.pos.device
batch_size = len(data.natoms)
# Handle PBC settings
if hasattr(data, "pbc"):
data.pbc = torch.atleast_2d(data.pbc)
for i in range(3):
if not torch.any(data.pbc[:, i]).item():
pbc[i] = False
elif torch.all(data.pbc[:, i]).item():
pbc[i] = True
else:
raise RuntimeError(
"Different structures in the batch have different PBC configurations."
)
# position of the atoms
atom_pos = data.pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image = data.natoms
num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
# index offset between images
index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
num_atoms_per_image_expand = torch.repeat_interleave(
num_atoms_per_image, num_atoms_per_image_sqr
)
# Compute atom pair indices
num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
index_sqr_offset = (
torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
)
index_sqr_offset = torch.repeat_interleave(
index_sqr_offset, num_atoms_per_image_sqr
)
atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
index1 = (
torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
) + index_offset_expand
index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
# Get the positions for each atom
pos1 = torch.index_select(atom_pos, 0, index1)
pos2 = torch.index_select(atom_pos, 0, index2)
# Calculate required number of unit cells in each direction for PBC
cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
if pbc[0]:
inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
rep_a1 = torch.ceil(radius * inv_min_dist_a1)
else:
rep_a1 = data.cell.new_zeros(1)
if pbc[1]:
cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
rep_a2 = torch.ceil(radius * inv_min_dist_a2)
else:
rep_a2 = data.cell.new_zeros(1)
if pbc[2]:
cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
rep_a3 = torch.ceil(radius * inv_min_dist_a3)
else:
rep_a3 = data.cell.new_zeros(1)
# Take the max over all images for uniformity
max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())]
# Pre-transpose data_cell for efficiency
data_cell = torch.transpose(data.cell, 1, 2)
# Use CUDA kernel for the triple loop computation
# try:
pbc_graph_cuda = get_pbc_graph_cuda_extension()
# Call the CUDA implementation
valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda(
pos1, pos2, data_cell,
num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device
)
# Map back to original index1 and index2
if len(valid_pair_indices) > 0:
index1 = index1.index_select(0, valid_pair_indices.long())
index2 = index2.index_select(0, valid_pair_indices.long())
else:
index1 = torch.empty(0, dtype=torch.long, device=device)
index2 = torch.empty(0, dtype=torch.long, device=device)
unit_cell = torch.empty(0, 3, dtype=dtype, device=device)
atom_distance_sqr = torch.empty(0, dtype=dtype, device=device)
# Sort index1 in ascending order and rearrange other arrays correspondingly
if len(index1) > 0:
sort_indices = torch.argsort(index1)
index1 = index1[sort_indices]
index2 = index2[sort_indices]
unit_cell = unit_cell[sort_indices]
atom_distance_sqr = atom_distance_sqr[sort_indices]
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
natoms=data.natoms,
index=index1,
atom_distance=atom_distance_sqr,
max_num_neighbors_threshold=max_num_neighbors_threshold,
enforce_max_strictly=enforce_max_neighbors_strictly,
)
if not torch.all(mask_num_neighbors):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1 = torch.masked_select(index1, mask_num_neighbors)
index2 = torch.masked_select(index2, mask_num_neighbors)
unit_cell = torch.masked_select(
unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
)
unit_cell = unit_cell.view(-1, 3)
edge_index = torch.stack((index2, index1))
return edge_index, unit_cell, num_neighbors_image
\ No newline at end of file
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
import torch_geometric
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch_geometric.data import Data
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter, segment_coo, segment_csr
if TYPE_CHECKING:
from collections.abc import Mapping
from torch.nn.modules.module import _IncompatibleKeys
DEFAULT_ENV_VARS = {
# Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes).
# see https://pytorch.org/docs/stable/notes/cuda.html.
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
def get_pbc_distances(
pos,
edge_index,
cell,
cell_offsets,
neighbors,
return_offsets: bool = False,
return_distance_vec: bool = False,
):
row, col = edge_index
distance_vectors = pos[row] - pos[col]
# correct for pbc
neighbors = neighbors.to(cell.device)
cell = torch.repeat_interleave(cell, neighbors, dim=0)
offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3)
distance_vectors += offsets
# compute distances
distances = distance_vectors.norm(dim=-1)
# redundancy: remove zero distances
nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0]
edge_index = edge_index[:, nonzero_idx]
distances = distances[nonzero_idx]
out = {
"edge_index": edge_index,
"distances": distances,
}
if return_distance_vec:
out["distance_vec"] = distance_vectors[nonzero_idx]
if return_offsets:
out["offsets"] = offsets[nonzero_idx]
return out
def radius_graph_pbc_mem_effi(
data,
radius,
max_num_neighbors_threshold,
enforce_max_neighbors_strictly: bool = False,
pbc=None,
dtype=torch.float64,
):
if pbc is None:
pbc = [True, True, True]
device = data.pos.device
batch_size = len(data.natoms)
if hasattr(data, "pbc"):
data.pbc = torch.atleast_2d(data.pbc)
for i in range(3):
if not torch.any(data.pbc[:, i]).item():
pbc[i] = False
elif torch.all(data.pbc[:, i]).item():
pbc[i] = True
else:
raise RuntimeError(
"Different structures in the batch have different PBC configurations. This is not currently supported."
)
# position of the atoms
atom_pos = data.pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image = data.natoms
num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
# index offset between images
index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
num_atoms_per_image_expand = torch.repeat_interleave(
num_atoms_per_image, num_atoms_per_image_sqr
)
# Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
# that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
# the following (but 10x faster since it removes the for loop)
# for batch_idx in range(batch_size):
# batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
index_sqr_offset = (
torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
)
index_sqr_offset = torch.repeat_interleave(
index_sqr_offset, num_atoms_per_image_sqr
)
atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
# If the systems get too large this apporach could run into numerical precision issues
index1 = (
torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
) + index_offset_expand
index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
# Get the positions for each atom
pos1 = torch.index_select(atom_pos, 0, index1)
pos2 = torch.index_select(atom_pos, 0, index2)
# Calculate required number of unit cells in each direction.
# Smallest distance between planes separated by a1 is
# 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
# Note that the unit cell volume V = a1 * (a2 x a3) and that
# (a2 x a3) / V is also the reciprocal primitive vector
# (crystallographer's definition).
cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
if pbc[0]:
inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
rep_a1 = torch.ceil(radius * inv_min_dist_a1)
else:
rep_a1 = data.cell.new_zeros(1)
if pbc[1]:
cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
rep_a2 = torch.ceil(radius * inv_min_dist_a2)
else:
rep_a2 = data.cell.new_zeros(1)
if pbc[2]:
cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
rep_a3 = torch.ceil(radius * inv_min_dist_a3)
else:
rep_a3 = data.cell.new_zeros(1)
# Take the max over all images for uniformity. This is essentially padding.
# Note that this can significantly increase the number of computed distances
# if the required repetitions are very different between images
# (which they usually are). Changing this to sparse (scatter) operations
# might be worth the effort if this function becomes a bottleneck.
max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())]
# Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once
# This reduces memory usage by avoiding the creation of large tensor products
all_index1 = []
all_index2 = []
all_unit_cell = []
all_atom_distance_sqr = []
# Pre-transpose data_cell for efficiency
data_cell = torch.transpose(data.cell, 1, 2)
# Iterate over each unit cell offset combination
for i in range(-max_rep[0], max_rep[0] + 1):
for j in range(-max_rep[1], max_rep[1] + 1):
for k in range(-max_rep[2], max_rep[2] + 1):
# Create unit cell offset
unit_cell_offset = torch.tensor([i, j, k], device=device, dtype=dtype)
# Compute the x, y, z positional offsets for this specific cell in each image
# unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size)
unit_cell_offset_batch = unit_cell_offset.view(1,3,1).expand(batch_size, -1, -1)
pbc_offsets = torch.bmm(data_cell, unit_cell_offset_batch).squeeze(-1)
pbc_offsets_per_atom = torch.repeat_interleave(
pbc_offsets, num_atoms_per_image_sqr, dim=0
)
# Apply PBC offsets to the second atom positions
pos2_offset = pos2 + pbc_offsets_per_atom
# Compute the squared distance between atoms
atom_distance_sqr = torch.sum((pos1 - pos2_offset) ** 2, dim=1)
# Remove pairs that are too far apart
mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
# Remove pairs with the same atoms (distance = 0.0)
mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
mask = torch.logical_and(mask_within_radius, mask_not_same)
# Only keep valid pairs for this unit cell offset
if torch.any(mask):
valid_index1 = torch.masked_select(index1, mask)
valid_index2 = torch.masked_select(index2, mask)
valid_distances = torch.masked_select(atom_distance_sqr, mask)
valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat(valid_index1.shape[0], 1)
all_index1.append(valid_index1)
all_index2.append(valid_index2)
all_unit_cell.append(valid_unit_cell)
all_atom_distance_sqr.append(valid_distances)
# Concatenate all results
if len(all_index1) > 0:
index1 = torch.cat(all_index1)
index2 = torch.cat(all_index2)
unit_cell = torch.cat(all_unit_cell)
atom_distance_sqr = torch.cat(all_atom_distance_sqr)
# Sort index1 in ascending order and rearrange other arrays correspondingly
sort_indices = torch.argsort(index1)
index1 = index1[sort_indices]
index2 = index2[sort_indices]
unit_cell = unit_cell[sort_indices]
atom_distance_sqr = atom_distance_sqr[sort_indices]
else:
# No valid pairs found
index1 = torch.empty(0, dtype=torch.long, device=device)
index2 = torch.empty(0, dtype=torch.long, device=device)
unit_cell = torch.empty(0, 3, dtype=dtype, device=device)
atom_distance_sqr = torch.empty(0, dtype=dtype, device=device)
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
natoms=data.natoms,
index=index1,
atom_distance=atom_distance_sqr,
max_num_neighbors_threshold=max_num_neighbors_threshold,
enforce_max_strictly=enforce_max_neighbors_strictly,
)
if not torch.all(mask_num_neighbors):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1 = torch.masked_select(index1, mask_num_neighbors)
index2 = torch.masked_select(index2, mask_num_neighbors)
unit_cell = torch.masked_select(
unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
)
unit_cell = unit_cell.view(-1, 3)
edge_index = torch.stack((index2, index1))
return edge_index, unit_cell, num_neighbors_image
def radius_graph_pbc(
data,
radius,
max_num_neighbors_threshold,
enforce_max_neighbors_strictly: bool = False,
pbc=None,
dtype=torch.float64,
):
if pbc is None:
pbc = [True, True, True]
device = data.pos.device
batch_size = len(data.natoms)
if hasattr(data, "pbc"):
data.pbc = torch.atleast_2d(data.pbc)
for i in range(3):
if not torch.any(data.pbc[:, i]).item():
pbc[i] = False
elif torch.all(data.pbc[:, i]).item():
pbc[i] = True
else:
raise RuntimeError(
"Different structures in the batch have different PBC configurations. This is not currently supported."
)
# position of the atoms
atom_pos = data.pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image = data.natoms
num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
# index offset between images
index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
num_atoms_per_image_expand = torch.repeat_interleave(
num_atoms_per_image, num_atoms_per_image_sqr
)
# Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
# that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
# the following (but 10x faster since it removes the for loop)
# for batch_idx in range(batch_size):
# batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
index_sqr_offset = (
torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
)
index_sqr_offset = torch.repeat_interleave(
index_sqr_offset, num_atoms_per_image_sqr
)
atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
# If the systems get too large this apporach could run into numerical precision issues
index1 = (
torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
) + index_offset_expand
index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
# Get the positions for each atom
pos1 = torch.index_select(atom_pos, 0, index1)
pos2 = torch.index_select(atom_pos, 0, index2)
# Calculate required number of unit cells in each direction.
# Smallest distance between planes separated by a1 is
# 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
# Note that the unit cell volume V = a1 * (a2 x a3) and that
# (a2 x a3) / V is also the reciprocal primitive vector
# (crystallographer's definition).
cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
if pbc[0]:
inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
rep_a1 = torch.ceil(radius * inv_min_dist_a1)
else:
rep_a1 = data.cell.new_zeros(1)
if pbc[1]:
cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
rep_a2 = torch.ceil(radius * inv_min_dist_a2)
else:
rep_a2 = data.cell.new_zeros(1)
if pbc[2]:
cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
rep_a3 = torch.ceil(radius * inv_min_dist_a3)
else:
rep_a3 = data.cell.new_zeros(1)
# Take the max over all images for uniformity. This is essentially padding.
# Note that this can significantly increase the number of computed distances
# if the required repetitions are very different between images
# (which they usually are). Changing this to sparse (scatter) operations
# might be worth the effort if this function becomes a bottleneck.
max_rep = [2*rep_a1.max(), 2*rep_a2.max(), 2*rep_a3.max()]
# max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]
# max_rep = [torch.tensor(1, device=device)] * 3
# logging.info(f"&&& max_rep: {max_rep}")
# Tensor of unit cells
cells_per_dim = [
torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=dtype)
for rep in max_rep
]
unit_cell = torch.cartesian_prod(*cells_per_dim)
num_cells = len(unit_cell)
unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1)
unit_cell = torch.transpose(unit_cell, 0, 1)
unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1)
# Compute the x, y, z positional offsets for each cell in each image
# data_cell = torch.transpose(data.cell, 1, 2)
data_cell = torch.transpose(data.cell, 1, 2)
pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
pbc_offsets_per_atom = torch.repeat_interleave(
pbc_offsets, num_atoms_per_image_sqr, dim=0
)
# Expand the positions and indices for the 9 cells
pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
# Add the PBC offsets for the second atom
pos2 = pos2 + pbc_offsets_per_atom
# Compute the squared distance between atoms
atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1)
atom_distance_sqr = atom_distance_sqr.view(-1)
# Remove pairs that are too far apart
mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
# Remove pairs with the same atoms (distance = 0.0)
mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
mask = torch.logical_and(mask_within_radius, mask_not_same)
index1 = torch.masked_select(index1, mask)
index2 = torch.masked_select(index2, mask)
unit_cell = torch.masked_select(
unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3)
)
unit_cell = unit_cell.view(-1, 3)
atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask)
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
natoms=data.natoms,
index=index1,
atom_distance=atom_distance_sqr,
max_num_neighbors_threshold=max_num_neighbors_threshold,
enforce_max_strictly=enforce_max_neighbors_strictly,
)
if not torch.all(mask_num_neighbors):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1 = torch.masked_select(index1, mask_num_neighbors)
index2 = torch.masked_select(index2, mask_num_neighbors)
unit_cell = torch.masked_select(
unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
)
unit_cell = unit_cell.view(-1, 3)
edge_index = torch.stack((index2, index1))
return edge_index, unit_cell, num_neighbors_image
@torch.compiler.disable
def get_max_neighbors_mask(
natoms,
index,
atom_distance,
max_num_neighbors_threshold,
degeneracy_tolerance: float = 0.01,
enforce_max_strictly: bool = False,
):
"""
Give a mask that filters out edges so that each atom has at most
`max_num_neighbors_threshold` neighbors.
Assumes that `index` is sorted.
Enforcing the max strictly can force the arbitrary choice between
degenerate edges. This can lead to undesired behaviors; for
example, bulk formation energies which are not invariant to
unit cell choice.
A degeneracy tolerance can help prevent sudden changes in edge
existence from small changes in atom position, for example,
rounding errors, slab relaxation, temperature, etc.
"""
device = natoms.device
num_atoms = natoms.sum()
# Get number of neighbors
# segment_coo assumes sorted index
ones = index.new_ones(1).expand_as(index)
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
max_num_neighbors = num_neighbors.max()
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)
# Get number of (thresholded) neighbors per image
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
image_indptr[1:] = torch.cumsum(natoms, dim=0)
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
# If max_num_neighbors is below the threshold, return early
if (
max_num_neighbors <= max_num_neighbors_threshold
or max_num_neighbors_threshold <= 0
):
mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as(
index
)
return mask_num_neighbors, num_neighbors_image
# Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
# Fill with infinity so we can easily remove unused distances later.
distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device)
# Create an index map to map distances from atom_distance to distance_sort
# index_sort_map assumes index to be sorted
index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
index_neighbor_offset_expand = torch.repeat_interleave(
index_neighbor_offset, num_neighbors
)
index_sort_map = (
index * max_num_neighbors
+ torch.arange(len(index), device=device)
- index_neighbor_offset_expand
)
distance_sort.index_copy_(0, index_sort_map, atom_distance)
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
# Sort neighboring atoms based on distance
distance_sort, index_sort = torch.sort(distance_sort, dim=1)
# Select the max_num_neighbors_threshold neighbors that are closest
if enforce_max_strictly:
distance_sort = distance_sort[:, :max_num_neighbors_threshold]
index_sort = index_sort[:, :max_num_neighbors_threshold]
max_num_included = max_num_neighbors_threshold
else:
effective_cutoff = (
distance_sort[:, max_num_neighbors_threshold] + degeneracy_tolerance
)
is_included = torch.le(distance_sort.T, effective_cutoff)
# Set all undesired edges to infinite length to be removed later
distance_sort[~is_included.T] = np.inf
# Subselect tensors for efficiency
num_included_per_atom = torch.sum(is_included, dim=0)
max_num_included = torch.max(num_included_per_atom)
distance_sort = distance_sort[:, :max_num_included]
index_sort = index_sort[:, :max_num_included]
# Recompute the number of neighbors
num_neighbors_thresholded = num_neighbors.clamp(max=num_included_per_atom)
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
# Offset index_sort so that it indexes into index
index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(
-1, max_num_included
)
# Remove "unused pairs" with infinite distances
mask_finite = torch.isfinite(distance_sort)
index_sort = torch.masked_select(index_sort, mask_finite)
# At this point index_sort contains the index into index of the
# closest max_num_neighbors_threshold neighbors per atom
# Create a mask to remove all pairs not in index_sort
mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
mask_num_neighbors.index_fill_(0, index_sort, True)
return mask_num_neighbors, num_neighbors_image
def get_pruned_edge_idx(
edge_index, num_atoms: int, max_neigh: float = 1e9
) -> torch.Tensor:
assert num_atoms is not None # TODO: Shouldn't be necessary
# removes neighbors > max_neigh
# assumes neighbors are sorted in increasing distance
_nonmax_idx_list = []
for i in range(num_atoms):
idx_i = torch.arange(len(edge_index[1]))[(edge_index[1] == i)][:max_neigh]
_nonmax_idx_list.append(idx_i)
return torch.cat(_nonmax_idx_list)
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Utilities to interface OCP models/trainers with the Atomic Simulation
Environment (ASE)
"""
from __future__ import annotations
from types import MappingProxyType
from typing import TYPE_CHECKING
import torch
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
if TYPE_CHECKING:
from torch_geometric.data import Batch
# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)
def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}
positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()
# TODO take pbc from data
# TODO: &&& ^^^ change this back !!!
# if wrap_pos:
# pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)
atoms = Atoms(
numbers=numbers[idx].tolist(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)
atoms_objects.append(atoms)
return atoms_objects
"""
Copyright (c) Meta, Inc. and its affiliates.
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Modified from original Meta implementation.
"""
from __future__ import annotations
from functools import cached_property
from types import SimpleNamespace
from typing import TYPE_CHECKING, ClassVar, Any, Generator
import numpy as np
import torch
import logging
from ase.calculators.calculator import PropertyNotImplementedError
from ase.stress import voigt_6_to_full_3x3_stress
from torch_scatter import scatter
from batchopt.relaxation.ase_utils import batch_to_atoms
# Define dummy classes for when imports fail
class _DummyCalculator:
pass
try:
from mace.calculators import MACECalculator
except ImportError:
logging.warning("Unable to import MACECalculator.")
MACECalculator = _DummyCalculator
try:
from chgnet.model.dynamics import CHGNetCalculator
except ImportError:
logging.warning("Unable to import CHGNetCalculator.")
CHGNetCalculator = _DummyCalculator
try:
from sevenn.calculator import (
SevenNetCalculator,
SevenNetD3Calculator,
D3Calculator,
)
except ImportError:
logging.warning("Unable to import SevenNetCalculator.")
SevenNetCalculator = _DummyCalculator
SevenNetD3Calculator = _DummyCalculator
D3Calculator = _DummyCalculator
try:
from fairchem.core import pretrained_mlip, FAIRChemCalculator
except ImportError:
logging.warning("Unable to import FAIRChemCalculator.")
FAIRChemCalculator = _DummyCalculator
# this can be removed after pinning ASE dependency >= 3.23
try:
from ase.optimize.optimize import Optimizable
except ImportError:
class Optimizable:
pass
if TYPE_CHECKING:
from collections.abc import Sequence
from ase import Atoms
from numpy.typing import NDArray
from torch_geometric.data import Batch
ALL_CHANGES: set[str] = {
"pos",
"atomic_numbers",
"cell",
"pbc",
}
# @torch.compile
def compare_batches(
batch1: Batch | None,
batch2: Batch,
tol: float = 1e-6,
excluded_properties: set[str] | None = None,
) -> list[str]:
"""Compare properties between two batches
Args:
batch1: atoms batch
batch2: atoms batch
tol: tolerance used to compare equility of floating point properties
excluded_properties: list of properties to exclude from comparison
Returns:
list of system changes, property names that are differente between batch1 and batch2
"""
system_changes = []
if batch1 is None:
system_changes = ALL_CHANGES
else:
properties_to_check = set(ALL_CHANGES)
if excluded_properties:
properties_to_check -= set(excluded_properties)
# Check properties that aren't
for prop in ALL_CHANGES:
if prop in properties_to_check:
properties_to_check.remove(prop)
if not torch.allclose(
getattr(batch1, prop), getattr(batch2, prop), atol=tol
):
system_changes.append(prop)
return system_changes
class OptimizableBatch(Optimizable):
"""A Batch version of ase Optimizable Atoms
This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation
or in ase relaxations classes, i.e. ase.optimize.lbfgs
"""
ignored_changes: ClassVar[set[str]] = set()
def __init__(
self,
batch: Batch,
trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator)
transform: torch.nn.Module | None = None,
mask_converged: bool = True,
numpy: bool = False,
masked_eps: float = 1e-8,
compute_stress: bool = False,
use_fast_predict: bool = True,
dtype: torch.dtype = torch.float64,
):
"""Initialize Optimizable Batch
Args:
batch: A batch of atoms graph data
model: An instance of a BaseTrainer derived class
transform: graph transform
mask_converged: if true will mask systems in batch that are already converged
numpy: whether to cast results to numpy arrays
masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero
from zero differences in masked positions at future steps, we add a small number to prevent this.
compute_stress: whether to compute stress during prediction
use_fast_predict: use fast prediction method when available
dtype: data type for tensor operations (torch.float32 or torch.float64)
"""
self.batch = batch.to(trainer.device)
self.trainer = trainer
self.transform = transform
self.numpy = numpy
self.mask_converged = mask_converged
self._cached_batch = None
self._update_mask = None
self.torch_results = {}
self.results = {}
self._eps = masked_eps
self.dtype = dtype
self.otf_graph = True # trainer._unwrapped_model.otf_graph
if not self.otf_graph and "edge_index" not in self.batch:
self.update_graph()
self.batch.pos = self.batch.pos.to(dtype=self.dtype)
self.batch.cell = self.batch.cell.to(dtype=self.dtype)
self.compute_stress = compute_stress
self.use_fast_predict = use_fast_predict
# Determine calculator type once during initialization for efficiency
self._calculator_type = self._determine_calculator_type()
logging.info(
f"OptimizableBatch initialized with calculator type: {self._calculator_type}"
)
def _determine_calculator_type(self) -> str:
"""Determine the type of calculator to avoid repeated isinstance checks."""
# Check against actual imported classes, not dummy classes
trainer_class_name = type(self.trainer).__name__
trainer_module = type(self.trainer).__module__
if (
"mace" in trainer_module.lower()
or trainer_class_name == "MACECalculator"
):
return "mace"
elif (
"chgnet" in trainer_module.lower()
or trainer_class_name == "CHGNetCalculator"
):
return "chgnet"
elif "sevenn" in trainer_module.lower() or trainer_class_name in [
"SevenNetCalculator",
"SevenNetD3Calculator",
"D3Calculator",
]:
return "sevennet"
elif (
"fairchem" in trainer_module.lower()
or trainer_class_name == "FAIRChemCalculator"
):
return "fairchem"
else:
return "default"
@property
def device(self):
return self.trainer.device
@property
def batch_indices(self):
"""Get the batch indices specifying which position/force corresponds to which batch."""
return self.batch.batch
@property
def converged_mask(self):
if self._update_mask is not None:
return torch.logical_not(self._update_mask)
return None
@property
def update_mask(self):
if self._update_mask is None:
return torch.ones(len(self.batch), dtype=bool)
return self._update_mask
@property
def converge_indices_list(self):
return torch.where(~self.update_mask)[0].tolist()
@property
def elem_per_group(self):
# This return value actually represents the number of elements
# in a group within a batch. Each group corresponds to batch_indices.
# It will count the number of CELL elements in each group.
return torch.bincount(self.batch_indices)
@property
def batch_size(self):
return len(torch.unique(self.batch_indices))
def check_state(self, batch: Batch, tol: float = 1e-12) -> bool:
"""Check for any system changes since last calculation."""
return compare_batches(
self._cached_batch,
batch,
tol=tol,
excluded_properties=set(self.ignored_changes),
)
def _predict(self) -> None:
"""Run prediction if batch has any changes."""
# TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented.
system_changes = self.check_state(self.batch)
if len(system_changes) > 0:
if self._calculator_type == "mace":
# FIXME: &&&
# for key, val in self.batch.to_dict().items():
# print(f'&&& key: {key}, val: {val}')
# self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress)
# self.torch_results = self.trainer.predict(self.config_batch)
if self.use_fast_predict:
self.torch_results = self.trainer.fast_predict(
self.batch, compute_stress=self.compute_stress
)
self.batch.pos = self.batch.pos.to(self.dtype)
self.batch.cell = self.batch.cell.to(self.dtype)
else:
atoms_list = batch_to_atoms(
self.batch, results=None, wrap_pos=False, eps=1e-17
)
self.torch_results = self.trainer.predict(
atoms_list, compute_stress=self.compute_stress
)
elif self._calculator_type == "fairchem":
# TODO: FAIRChemCalculator does not support batch prediction yet
atoms_list = batch_to_atoms(
self.batch, results=None, wrap_pos=False, eps=1e-17
)
self.torch_results = self.trainer.predict(atoms_list=atoms_list)
elif self._calculator_type == "chgnet":
atoms_list = batch_to_atoms(
self.batch, results=None, wrap_pos=False, eps=1e-17
)
model_prediction = self.trainer.predict(
atoms_list=atoms_list, task="efs"
)
results = {
"energy": torch.tensor(
[pred["e"].item() for pred in model_prediction],
device=self.device,
dtype=self.dtype,
),
"forces": torch.vstack(
[
torch.from_numpy(pred["f"]).to(
device=self.device, dtype=self.dtype
)
for pred in model_prediction
]
),
"stress": torch.vstack(
[
torch.from_numpy(pred["s"]).to(
device=self.device, dtype=self.dtype
)
for pred in model_prediction
]
).view(-1, 3, 3),
}
self.torch_results = results
elif self._calculator_type == "sevennet":
atoms_list = batch_to_atoms(
self.batch, results=None, wrap_pos=False, eps=1e-17
)
self.torch_results = self.trainer.predict(atoms_list=atoms_list)
else: # default case
self.torch_results = self.trainer.predict(
self.batch, per_image=False, disable_tqdm=True
)
# save only subset of props in simple namespace instead of cloning the whole batch to save memory
changes = ALL_CHANGES - set(self.ignored_changes)
self._cached_batch = SimpleNamespace(
**{prop: self.batch[prop].clone() for prop in changes}
)
def get_property(
self, name, no_numpy: bool = False
) -> torch.Tensor | NDArray:
"""Get a predicted property by name."""
self._predict()
if self.numpy:
self.results = {
key: pred.item() if pred.numel() == 1 else pred.cpu().numpy()
for key, pred in self.torch_results.items()
}
else:
self.results = self.torch_results
if name not in self.results:
raise PropertyNotImplementedError(
f"{name} not present in this calculation"
)
return (
self.results[name]
if no_numpy is False
else self.torch_results[name]
)
def get_positions(self) -> torch.Tensor | NDArray:
"""Get the batch positions"""
pos = self.batch.pos.clone()
if self.numpy:
if self.mask_converged:
pos[~self.update_mask[self.batch.batch]] = self._eps
pos = pos.cpu().numpy()
return pos
def set_positions(self, positions: torch.Tensor | NDArray) -> None:
"""Set the atom positions in the batch."""
if isinstance(positions, np.ndarray):
positions = torch.tensor(
positions, dtype=self.dtype, device=self.device
)
else:
positions = positions.to(dtype=self.dtype, device=self.device)
if self.mask_converged and self._update_mask is not None:
mask = self.update_mask[self.batch.batch]
self.batch.pos[mask] = positions[mask]
else:
self.batch.pos = positions
if not self.otf_graph:
self.update_graph()
def get_forces(
self, apply_constraint: bool = False, no_numpy: bool = False
) -> torch.Tensor | NDArray:
"""Get predicted batch forces."""
forces = self.get_property("forces", no_numpy=no_numpy)
if apply_constraint:
fixed_idx = torch.where(self.batch.fixed == 1)[0]
if isinstance(forces, np.ndarray):
fixed_idx = fixed_idx.tolist()
forces[fixed_idx] = 0.0
return forces.view(-1, 3)
def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray:
"""Get predicted energy as the sum of all batch energies."""
# ASE 3.22.1 expects a check for force_consistent calculations
if kwargs.get("force_consistent", False) is True:
raise PropertyNotImplementedError(
"force_consistent calculations are not implemented"
)
if (
len(self.batch) == 1
): # unfortunately batch size 1 returns a float, not a tensor
return self.get_property("energy")
return self.get_property("energy").sum()
def get_potential_energies(self) -> torch.Tensor | NDArray:
"""Get the predicted energy for each system in batch."""
return self.get_property("energy")
def get_cells(self) -> torch.Tensor:
"""Get batch crystallographic cells."""
return self.batch.cell
def set_cells(
self, cells: torch.Tensor | NDArray, scale_atoms=False
) -> None:
"""Set batch cells."""
assert self.batch.cell.shape == cells.shape, "Cell shape mismatch"
if isinstance(cells, np.ndarray):
cells = torch.tensor(cells, dtype=self.dtype, device=self.device)
cells = cells.to(dtype=self.dtype, device=self.device)
if scale_atoms:
from ase.geometry.cell import complete_cell
# M = torch.linalg.solve(
# self.batch.cell.view(-1, 3, 3),
# cells.view(-1, 3, 3),
# )
# TODO: need to implement a sparse version.
# tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3))
for i in range(self.batch_size):
if not self.update_mask[i]:
continue
M = np.linalg.solve(
complete_cell(self.batch.cell[i].cpu().detach().numpy()),
complete_cell(cells[i].cpu().detach().numpy()),
)
pos_update_mask = self.batch.batch == i
self.batch.pos[pos_update_mask] = torch.matmul(
self.batch.pos[pos_update_mask],
torch.from_numpy(M).to(self.device).reshape(-1, 3),
)
self.batch.cell[self.update_mask] = cells[self.update_mask]
def get_volumes(self) -> torch.Tensor:
"""Get a tensor of volumes for each cell in batch"""
cells = self.get_cells()
return torch.linalg.det(cells)
def iterimages(self) -> Generator[Batch, None, None]:
# XXX document purpose of iterimages - this is just needed to work with ASE optimizers
yield self.batch
def get_max_forces(
self, forces: torch.Tensor | None = None, apply_constraint: bool = False
) -> torch.Tensor:
"""Get the maximum forces per structure in batch"""
if forces is None:
forces = self.get_forces(
apply_constraint=apply_constraint, no_numpy=True
)
return scatter(
(forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max"
)
def converged(
self,
forces: torch.Tensor | NDArray | None,
fmax: float,
max_forces: torch.Tensor | None = None,
f_upper_limit: float = 1e20,
) -> bool:
"""Check if norm of all predicted forces are below fmax"""
if forces is not None:
if isinstance(forces, np.ndarray):
forces = torch.tensor(
forces, device=self.device, dtype=self.dtype
)
max_forces = self.get_max_forces(forces)
elif max_forces is None:
max_forces = self.get_max_forces()
# Update mask is True for forces that are greater than fmax AND less than f_upper_limit
update_mask = torch.logical_and(
max_forces.ge(fmax), max_forces.le(f_upper_limit)
)
# update cached mask
if self.mask_converged:
if self._update_mask is None:
self._update_mask = update_mask
else:
# some models can have random noise in their predictions, so the mask is updated by
# keeping all previously converged structures masked even if new force predictions
# push it slightly above threshold
self._update_mask = torch.logical_and(
self._update_mask, update_mask
)
update_mask = self._update_mask
return not torch.any(update_mask).item()
def get_atoms_list(self) -> list[Atoms]:
"""Get ase Atoms objects corresponding to the batch"""
self._predict() # in case no predictions have been run
return batch_to_atoms(self.batch, results=self.torch_results)
def update_graph(self):
"""Update the graph if model does not use otf_graph."""
graph = self.trainer._unwrapped_model.generate_graph(self.batch)
self.batch.edge_index = graph.edge_index
self.batch.cell_offsets = graph.cell_offsets
self.batch.neighbors = graph.neighbors
if self.transform is not None:
self.batch = self.transform(self.batch)
def __len__(self) -> int:
# TODO: this might be changed in ASE to be 3 * len(self.atoms)
return len(self.batch.pos)
class OptimizableUnitCellBatch(OptimizableBatch):
"""Modify the supercell and the atom positions in relaxations.
Based on ase UnitCellFilter to work on data batches
"""
def __init__(
self,
batch: Batch,
trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator)
transform: torch.nn.Module | None = None,
numpy: bool = False,
mask_converged: bool = True,
mask: Sequence[bool] | None = None,
cell_factor: float | torch.Tensor | None = None,
hydrostatic_strain: bool = False,
constant_volume: bool = False,
scalar_pressure: float = 0.0,
masked_eps: float = 1e-8,
use_fast_predict: bool = True,
dtype: torch.dtype = torch.float64,
):
"""Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization.
For full details see:
E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras,
Phys. Rev. B 59, 235 (1999)
Args:
batch: A batch of atoms graph data
model: An instance of a BaseTrainer derived class
transform: graph transform
numpy: whether to cast results to numpy arrays
mask_converged: if true will mask systems in batch that are already converged
mask: a boolean mask specifying which strain components are allowed to relax
cell_factor:
Factor by which deformation gradient is multiplied to put
it on the same scale as the positions when assembling
the combined position/cell vector. The stress contribution to
the forces is scaled down by the same factor. This can be thought
of as a very simple preconditioner. Default is number of atoms
which gives approximately the correct scaling.
hydrostatic_strain:
Constrain the cell by only allowing hydrostatic deformation.
The virial tensor is replaced by np.diag([np.trace(virial)]*3).
constant_volume:
Project out the diagonal elements of the virial tensor to allow
relaxations at constant volume, e.g. for mapping out an
energy-volume curve. Note: this only approximately conserves
the volume and breaks energy/force consistency so can only be
used with optimizers that do require a line minimisation
(e.g. FIRE).
scalar_pressure:
Applied pressure to use for enthalpy pV term. As above, this
breaks energy/force consistency.
masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero
from zero differences in masked positions at future steps, we add a small number to prevent this.
dtype: data type for tensor operations (torch.float32 or torch.float64)
"""
super().__init__(
batch=batch,
trainer=trainer,
transform=transform,
numpy=numpy,
mask_converged=mask_converged,
masked_eps=masked_eps,
compute_stress=True,
use_fast_predict=use_fast_predict,
dtype=dtype,
)
self.orig_cells = self.get_cells().clone()
self.stress = None
if mask is None:
# mask = torch.eye(3, device=self.device)
mask = torch.ones(6, device=self.device)
# TODO make sure mask is on GPU
if mask.shape == (6,):
self.mask = torch.tensor(
voigt_6_to_full_3x3_stress(mask.detach().cpu()),
device=self.device,
)
elif mask.shape == (3, 3):
self.mask = mask
else:
raise ValueError("shape of mask should be (3,3) or (6,)")
if isinstance(cell_factor, float):
cell_factor = cell_factor * torch.ones(
(3 * len(batch), 1), requires_grad=False
)
if cell_factor is None:
cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze(
dim=1
)
self.hydrostatic_strain = hydrostatic_strain
self.constant_volume = constant_volume
self.pressure = scalar_pressure * torch.eye(3, device=self.device)
self.cell_factor = cell_factor
self.stress = None
self._batch_trace = torch.vmap(torch.trace)
self._batch_diag = torch.vmap(
lambda x: x * torch.eye(3, device=x.device)
)
@cached_property
def batch_indices(self):
"""Get the batch indices specifying which position/force corresponds to which batch.
We augment this to specify the batch indices for augmented positions and forces.
"""
augmented_batch = torch.repeat_interleave(
torch.arange(
len(self.batch),
dtype=self.batch.batch.dtype,
device=self.device,
),
3,
)
return torch.cat([self.batch.batch, augmented_batch])
def deform_grad(self):
"""Get the cell deformation matrix"""
return torch.transpose(
torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2
)
def get_positions(self):
"""Get positions and cell deformation gradient."""
cur_deform_grad = self.deform_grad()
natoms = self.batch.num_nodes
pos = torch.zeros(
(natoms + 3 * len(self.get_cells()), 3),
dtype=self.batch.pos.dtype,
device=self.device,
)
# Augmented positions are the self.atoms.positions but without the applied deformation gradient
pos[:natoms] = torch.linalg.solve(
cur_deform_grad[self.batch.batch, :, :],
self.batch.pos.view(-1, 3, 1),
).view(-1, 3)
# cell DOFs are the deformation gradient times a scaling factor
pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3)
return pos.cpu().numpy() if self.numpy else pos
def set_positions(self, positions: torch.Tensor | NDArray) -> None:
"""Set positions and cell.
positions has shape (natoms + ncells * 3, 3).
the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor
for each cell.
"""
if isinstance(positions, np.ndarray):
positions = torch.tensor(
positions, dtype=self.dtype, device=self.device
)
else:
positions = positions.to(dtype=self.dtype, device=self.device)
natoms = self.batch.num_nodes
new_atom_positions = positions[:natoms]
new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3)
# TODO check that in fact symmetry is preserved setting cells and positions
# Set the new cell from the original cell and the new deformation gradient. Both current and final structures
# should preserve symmetry.
new_cells = torch.bmm(
self.orig_cells, torch.transpose(new_deform_grad, 1, 2)
)
self.set_cells(new_cells)
# Set the positions from the ones passed in (which are without the deformation gradient applied) and the new
# deformation gradient. This should also preserve symmetry
new_atom_positions = torch.bmm(
new_atom_positions.view(-1, 1, 3),
torch.transpose(
new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2
),
)
super().set_positions(new_atom_positions.view(-1, 3))
def get_potential_energy(self, **kwargs):
"""
returns potential energy including enthalpy PV term.
"""
atoms_energy = super().get_potential_energy(**kwargs)
return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum()
def get_forces(
self, apply_constraint: bool = False, no_numpy: bool = False
) -> torch.Tensor | NDArray:
"""Get forces and unit cell stress."""
stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3)
atom_forces = self.get_property("forces", no_numpy=True)
if apply_constraint:
fixed_idx = torch.where(self.batch.fixed == 1)[0]
atom_forces[fixed_idx] = 0.0
volumes = self.get_volumes().view(-1, 1, 1)
# virial = -volumes * stress + self.pressure.view(-1, 3, 3)
virial = -volumes * (stress + self.pressure.view(-1, 3, 3))
# print(f'&&& virial0: {virial}')
cur_deform_grad = self.deform_grad()
atom_forces = torch.bmm(
atom_forces.view(-1, 1, 3),
cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3),
)
virial = torch.linalg.solve(
cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2)
)
virial = torch.transpose(virial, dim0=1, dim1=2)
# print(f'&&& virial1: {virial}')
# TODO this does not work yet! maybe _batch_trace gives an issue
if self.hydrostatic_strain:
virial = self._batch_diag(self._batch_trace(virial) / 3.0)
# Zero out components corresponding to fixed lattice elements
if (self.mask != 1.0).any():
virial *= self.mask.view(-1, 3, 3)
if self.constant_volume:
virial[:, range(3), range(3)] -= (
self._batch_trace(virial).view(3, -1) / 3.0
)
natoms = self.batch.num_nodes
augmented_forces = torch.zeros(
(natoms + 3 * len(self.get_cells()), 3),
device=self.device,
dtype=atom_forces.dtype,
)
# print(f'&&& atom_forces: {atom_forces}')
# print(f'&&& virial2: {virial}')
augmented_forces[:natoms] = atom_forces.view(-1, 3)
augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor
self.stress = -virial.view(-1, 9) / volumes.view(-1, 1)
if self.numpy and not no_numpy:
augmented_forces = augmented_forces.cpu().numpy()
# print(f'&&& augmented_forces: {augmented_forces}')
return augmented_forces
def __len__(self):
return len(self.batch.pos) + 3 * len(self.batch)
def get_potential_energies(self) -> torch.Tensor:
"""Get the predicted energy for each system in batch."""
return (
self.get_property("energy").view(-1)
+ self.pressure[0, 0] * self.get_volumes()
)
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from .bfgs_torch import BFGS
from .bfgsfusedls import BFGSFusedLS
__all__ = ["BFGS", "BFGSFusedLS"]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment