Unverified Commit fa84b16c authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent 09624897
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import ase.db.sqlite
import ase.io.trajectory
import numpy as np
import torch
from ase.geometry import wrap_positions
from torch_geometric.data import Data
from batchopt.utils import collate
if TYPE_CHECKING:
from collections.abc import Sequence
try:
from pymatgen.io.ase import AseAtomsAdaptor
except ImportError:
AseAtomsAdaptor = None
from tqdm import tqdm
class AtomsToGraphs:
"""A class to help convert periodic atomic structures to graphs.
The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts
them into graph representations for use in PyTorch. The primary purpose of this class is to determine the
nearest neighbors within some radius around each individual atom, taking into account PBC, and set the
pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information
are put into a PyTorch geometric data object for use with PyTorch.
Args:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstroms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
Attributes:
max_neigh (int): Maximum number of neighbors to consider.
radius (int or float): Cutoff radius in Angstoms to search for neighbors.
r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
r_distances (bool): Return the distances with other properties.
Default is False, so the distances will not be returned.
r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms.
Default is True, so the fixed indices will be returned.
r_pbc (bool): Return the periodic boundary conditions with other properties.
Default is False, so the periodic boundary conditions will not be returned.
r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other
properties. Default is None, so no data will be returned as properties.
"""
def __init__(
self,
max_neigh: int = 200,
radius: int = 6,
r_energy: bool = False,
r_forces: bool = False,
r_distances: bool = False,
r_edges: bool = True,
r_fixed: bool = True,
r_pbc: bool = False,
r_stress: bool = False,
r_data_keys: Sequence[str] | None = None,
) -> None:
self.max_neigh = max_neigh
self.radius = radius
self.r_energy = r_energy
self.r_forces = r_forces
self.r_stress = r_stress
self.r_distances = r_distances
self.r_fixed = r_fixed
self.r_edges = r_edges
self.r_pbc = r_pbc
self.r_data_keys = r_data_keys
def _get_neighbors_pymatgen(self, atoms: ase.Atoms):
"""Preforms nearest neighbor search and returns edge index, distances,
and cell offsets"""
if AseAtomsAdaptor is None:
raise RuntimeError(
"Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed."
)
struct = AseAtomsAdaptor.get_structure(atoms)
_c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(
r=self.radius, numerical_tol=0, exclude_self=True
)
_nonmax_idx = []
for i in range(len(atoms)):
idx_i = (_c_index == i).nonzero()[0]
# sort neighbors by distance, remove edges larger than max_neighbors
idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh]
_nonmax_idx.append(idx_i[idx_sorted])
_nonmax_idx = np.concatenate(_nonmax_idx)
_c_index = _c_index[_nonmax_idx]
_n_index = _n_index[_nonmax_idx]
n_distance = n_distance[_nonmax_idx]
_offsets = _offsets[_nonmax_idx]
return _c_index, _n_index, n_distance, _offsets
def _reshape_features(self, c_index, n_index, n_distance, offsets):
"""Stack center and neighbor index and reshapes distances,
takes in np.arrays and returns torch tensors"""
edge_index = torch.LongTensor(np.vstack((n_index, c_index)))
edge_distances = torch.FloatTensor(n_distance)
cell_offsets = torch.LongTensor(offsets)
# remove distances smaller than a tolerance ~ 0. The small tolerance is
# needed to correct for pymatgen's neighbor_list returning self atoms
# in a few edge cases.
nonzero = torch.where(edge_distances >= 1e-8)[0]
edge_index = edge_index[:, nonzero]
edge_distances = edge_distances[nonzero]
cell_offsets = cell_offsets[nonzero]
return edge_index, edge_distances, cell_offsets
def get_edge_distance_vec(
self,
pos,
edge_index,
cell,
cell_offsets,
):
row, col = edge_index
distance_vectors = pos[row] - pos[col]
# correct for pbc
cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0)
offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3)
distance_vectors += offsets
return distance_vectors
def convert(self, atoms: ase.Atoms, sid=None):
"""Convert a single atomic structure to a graph.
Args:
atoms (ase.atoms.Atoms): An ASE atoms object.
sid (uniquely identifying object): An identifier that can be used to track the structure in downstream
tasks. Common sids used in OCP datasets include unique strings or integers.
Returns:
data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags,
and optionally, energy, forces, distances, edges, and periodic boundary conditions.
Optional properties can included by setting r_property=True when constructing the class.
"""
# set the atomic numbers, positions, and cell
positions = np.array(atoms.get_positions(), copy=True)
pbc = np.array(atoms.pbc, copy=True)
cell = np.array(atoms.get_cell(complete=True), copy=True)
# TODO: change this back &&& ^^^
# positions = wrap_positions(positions, cell, pbc=pbc, eps=0)
atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8)
positions = torch.from_numpy(positions).float()
cell = torch.from_numpy(cell).view(1, 3, 3).float()
natoms = positions.shape[0]
# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags = torch.tensor(atoms.get_tags(), dtype=torch.int)
# put the minimum data in torch geometric data object
data = Data(
cell=cell,
pos=positions,
atomic_numbers=atomic_numbers,
natoms=natoms,
tags=tags,
)
# Optionally add a systemid (sid) to the object
if sid is not None:
data.sid = sid
# optionally include other properties
if self.r_edges:
# run internal functions to get padded indices and distances
atoms_copy = atoms.copy()
atoms_copy.set_positions(positions)
split_idx_dist = self._get_neighbors_pymatgen(atoms_copy)
edge_index, edge_distances, cell_offsets = self._reshape_features(
*split_idx_dist
)
data.edge_index = edge_index
data.cell_offsets = cell_offsets
data.edge_distance_vec = self.get_edge_distance_vec(
positions, edge_index, cell, cell_offsets
)
del atoms_copy
if self.r_energy:
energy = atoms.get_potential_energy(apply_constraint=False)
data.energy = energy
if self.r_forces:
forces = torch.tensor(
atoms.get_forces(apply_constraint=False), dtype=torch.float32
)
data.forces = forces
if self.r_stress:
stress = torch.tensor(
atoms.get_stress(apply_constraint=False, voigt=False),
dtype=torch.float32,
)
data.stress = stress
if self.r_distances and self.r_edges:
data.distances = edge_distances
if self.r_fixed:
fixed_idx = torch.zeros(natoms, dtype=torch.int)
if hasattr(atoms, "constraints"):
from ase.constraints import FixAtoms
for constraint in atoms.constraints:
if isinstance(constraint, FixAtoms):
fixed_idx[constraint.index] = 1
data.fixed = fixed_idx
if self.r_pbc:
data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool)
if self.r_data_keys is not None:
for data_key in self.r_data_keys:
data[data_key] = (
atoms.info[data_key]
if isinstance(atoms.info[data_key], (int, float, str))
else torch.tensor(atoms.info[data_key])
)
return data
def convert_all(
self,
atoms_collection,
processed_file_path: str | None = None,
collate_and_save=False,
disable_tqdm=False,
):
"""Convert all atoms objects in a list or in an ase.db to graphs.
Args:
atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database):
Either a list of ASE atoms objects or an ASE database.
processed_file_path (str):
A string of the path to where the processed file will be written. Default is None.
collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file.
Returns:
data_list (list of torch_geometric.data.Data):
A list of torch geometric data objects containing molecular graph info and properties.
"""
# list for all data
data_list = []
if isinstance(atoms_collection, list):
atoms_iter = atoms_collection
elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database):
atoms_iter = atoms_collection.select()
elif isinstance(
atoms_collection,
(ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader),
):
atoms_iter = atoms_collection
else:
raise NotImplementedError
for atoms in tqdm(
atoms_iter,
desc="converting ASE atoms collection to graphs",
total=len(atoms_collection),
unit=" systems",
disable=disable_tqdm,
):
# check if atoms is an ASE Atoms object this for the ase.db case
data = self.convert(
atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms()
)
data_list.append(data)
if collate_and_save:
data, slices = collate(data_list)
torch.save((data, slices), processed_file_path)
return data_list
"""
Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from ase.io import read
import logging
from joblib import Parallel, delayed
from ase.optimize import LBFGS as ASE_LBFGS
from ase.optimize import QuasiNewton as ASE_QuasiNewton
from ase.optimize import BFGS as ASE_BFGS
import time
import csv
import os
try:
from mace.calculators import mace_off
except ImportError:
logging.warning("Failed to import MACE modules")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def ensure_directory(directory):
"""Create directory if it doesn't exist."""
if not os.path.exists(directory):
os.makedirs(directory)
logging.info(f"Created directory: {directory}")
def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Starting baseline optimization for file {file} on device {device}.")
start_time = time.perf_counter()
crystal = read(file)
# calc = mace_off(model="small", device=device)
calc = mace_off(model="small", device="cuda")
crystal.calc = calc
first_optimizer_class ={
"LBFGS": ASE_LBFGS,
"QuasiNewton": ASE_QuasiNewton,
"BFGS": ASE_BFGS
}.get(first_optimizer, ASE_LBFGS)
# First optimization stage
if filter1 == "UnitCellFilter":
from ase.filters import UnitCellFilter
atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure)
first_optimizer_instance = first_optimizer_class(atoms_with_filter)
elif filter1 == "FrechetCellFilter":
from ase.filters import FrechetCellFilter
atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure)
first_optimizer_instance = first_optimizer_class(atoms_with_filter)
else:
first_optimizer_instance = first_optimizer_class(crystal)
start_time1 = time.perf_counter()
first_optimizer_instance.run(fmax=0.01, steps=max_steps)
end_time1 = time.perf_counter()
# Save intermediate result
output_dir_press = "./cif_result_press"
output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif"))
crystal.write(output_file_press)
elapsed_time1 = end_time1 - start_time1
steps1 = first_optimizer_instance.nsteps
if skip_second_stage:
ret_result = {
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": 0.0,
"stage2_steps": 0,
"total_time": elapsed_time1,
"total_steps": steps1
}
else:
# Second optimization stage
crystal = read(output_file_press)
crystal.calc = calc
second_optimizer_class = {
"LBFGS": ASE_LBFGS,
"QuasiNewton": ASE_QuasiNewton,
"BFGS": ASE_BFGS
}.get(second_optimizer, ASE_LBFGS)
if filter2 == "UnitCellFilter":
from ase.filters import UnitCellFilter
atoms_with_filter2 = UnitCellFilter(crystal)
second_optimizer_instance = second_optimizer_class(atoms_with_filter2)
elif filter2 == "FrechetCellFilter":
from ase.filters import FrechetCellFilter
atoms_with_filter2 = FrechetCellFilter(crystal)
second_optimizer_instance = second_optimizer_class(atoms_with_filter2)
else:
second_optimizer_instance = second_optimizer_class(crystal)
start_time2 = time.perf_counter()
second_optimizer_instance.run(fmax=0.01, steps=max_steps)
end_time2 = time.perf_counter()
# Save final result
output_dir_final = "./cif_result_final"
output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif"))
crystal.write(output_file_final)
# Collect metrics
elapsed_time2 = end_time2 - start_time2
total_time = elapsed_time1 + elapsed_time2
steps2 = second_optimizer_instance.nsteps
ret_result = {
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": elapsed_time2,
"stage2_steps": steps2,
"total_time": total_time,
"total_steps": steps1 + steps2
}
logging.info(f"Baseline optimization completed for file {file}.")
return ret_result
def run_baseline(files, num_workers, devices, max_steps,
filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006,
optimizer1=None, optimizer2=None):
"""
Runs the baseline optimization using LBFGS from ase.optimize.
"""
logging.info(f"Starting baseline optimization with {num_workers} workers.")
start_time = time.perf_counter()
results = Parallel(n_jobs=num_workers)(
delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2)
for i, file in enumerate(files)
)
end_time = time.perf_counter()
csv_file = "results_baseline.csv"
with open(csv_file, mode='w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"])
writer.writeheader()
for result in results:
writer.writerow(result)
logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.")
final_elapsed_time = end_time - start_time
summary_csv_file = "summary_baseline.csv"
with open(summary_csv_file, mode='w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"])
writer.writeheader()
writer.writerow({
"elapsed_time": final_elapsed_time,
"num_workers": num_workers,
"batch_size": 1
})
logging.info(f"Summary results written to {summary_csv_file}.")
\ No newline at end of file
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations.
This module provides optimized implementations of common operations using
torch.utils.cpp_extension for JIT compilation.
"""
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA Extension wrapper for vector addition and PBC graph operations.
"""
import torch
from torch.utils.cpp_extension import load
import os
def load_cuda_extension():
"""Load the CUDA extension for vector addition."""
# Check if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. Cannot load CUDA extension.")
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Path to the CUDA source file
cuda_file = os.path.join(current_dir, "vector_add.cu")
# Load the extension
return load(
name="vector_add_cuda",
sources=[cuda_file],
verbose=True,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3', '--use_fast_math'],
)
def load_pbc_graph_cuda_extension():
"""Load the CUDA extension for PBC graph operations."""
# Check if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. Cannot load CUDA extension.")
# Get the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Path to the CUDA source file
cuda_file = os.path.join(current_dir, "pbc_graph.cu")
# Load the extension
return load(
name="pbc_graph_cuda",
sources=[cuda_file],
verbose=True,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3', '--use_fast_math'],
)
# Global variable to store loaded extension
_cuda_extension = None
_pbc_graph_cuda_extension = None
def get_cuda_extension():
"""Get or load the CUDA extension."""
global _cuda_extension
if _cuda_extension is None:
_cuda_extension = load_cuda_extension()
return _cuda_extension
def get_pbc_graph_cuda_extension():
"""Get or load the PBC graph CUDA extension."""
global _pbc_graph_cuda_extension
if _pbc_graph_cuda_extension is None:
_pbc_graph_cuda_extension = load_pbc_graph_cuda_extension()
return _pbc_graph_cuda_extension
def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Perform vector addition using CUDA implementation.
Args:
a: First input tensor (must be on CUDA device)
b: Second input tensor (must be on CUDA device)
Returns:
Result tensor of element-wise addition
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available.")
if not (a.is_cuda and b.is_cuda):
raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.")
extension = get_cuda_extension()
return extension.vector_add(a.float(), b.float())
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <type_traits>
// Template function to get appropriate epsilon for different floating point types
template<typename T>
__device__ __forceinline__ T get_epsilon() {
if constexpr (std::is_same_v<T, float>) {
return static_cast<T>(1e-8);
} else if constexpr (std::is_same_v<T, double>) {
return static_cast<T>(1e-12);
} else {
return static_cast<T>(1e-8); // fallback
}
}
// Templated CUDA kernel for computing pairwise distances with PBC offsets
// This version avoids repeat_interleave by computing offsets directly in the kernel
template<typename T>
__global__ void pbc_distance_kernel_optimized(
const T* pos1,
const T* pos2,
const T* pbc_offsets, // [batch_size, 3]
const int64_t* num_atoms_per_image_sqr, // [batch_size]
const int64_t* batch_offsets, // [batch_size] - cumulative offsets for each batch
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_pairs) {
// Find which batch this pair belongs to
int batch_idx = 0;
while (batch_idx < num_pairs && idx >= batch_offsets[batch_idx + 1]) {
batch_idx++;
}
// Get PBC offset for this batch
T offset_x = pbc_offsets[batch_idx * 3];
T offset_y = pbc_offsets[batch_idx * 3 + 1];
T offset_z = pbc_offsets[batch_idx * 3 + 2];
// Get positions for this atom pair with PBC offset
T dx = pos2[idx * 3] - pos1[idx * 3] + offset_x;
T dy = pos2[idx * 3 + 1] - pos1[idx * 3 + 1] + offset_y;
T dz = pos2[idx * 3 + 2] - pos1[idx * 3 + 2] + offset_z;
// Compute squared distance
T dist_sq = dx * dx + dy * dy + dz * dz;
distances_squared[idx] = dist_sq;
// Check if within radius
valid_mask[idx] = (dist_sq <= radius_squared) && (dist_sq > get_epsilon<T>());
}
}
// Original kernel for fallback
template<typename T>
__global__ void pbc_distance_kernel(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_pairs) {
// Get positions for this atom pair
T dx = pos2[idx * 3] - pos1[idx * 3] + pbc_offsets[idx * 3];
T dy = pos2[idx * 3 + 1] - pos1[idx * 3 + 1] + pbc_offsets[idx * 3 + 1];
T dz = pos2[idx * 3 + 2] - pos1[idx * 3 + 2] + pbc_offsets[idx * 3 + 2];
// Compute squared distance
T dist_sq = dx * dx + dy * dy + dz * dz;
distances_squared[idx] = dist_sq;
// Check if within radius
valid_mask[idx] = (dist_sq <= radius_squared) && (dist_sq > get_epsilon<T>());
}
}
// Template helper function to launch the appropriate optimized kernel
template<typename T>
inline void launch_pbc_distance_kernel_optimized(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
const int64_t* num_atoms_per_image_sqr,
const int64_t* batch_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared,
int blocks,
int threads_per_block
) {
pbc_distance_kernel_optimized<T><<<blocks, threads_per_block>>>(
pos1, pos2, pbc_offsets, num_atoms_per_image_sqr, batch_offsets,
distances_squared, valid_mask, num_pairs, radius_squared
);
}
// Template helper function to launch the appropriate kernel (fallback)
template<typename T>
void launch_pbc_distance_kernel(
const T* pos1,
const T* pos2,
const T* pbc_offsets,
T* distances_squared,
bool* valid_mask,
int num_pairs,
T radius_squared,
int blocks,
int threads_per_block
) {
pbc_distance_kernel<T><<<blocks, threads_per_block>>>(
pos1, pos2, pbc_offsets, distances_squared, valid_mask, num_pairs, radius_squared
);
}
// CUDA function to compute distances for all unit cell offsets
std::vector<torch::Tensor> pbc_distance_cuda(
torch::Tensor pos1,
torch::Tensor pos2,
torch::Tensor data_cell,
torch::Tensor num_atoms_per_image_sqr,
int batch_size,
std::vector<int> max_rep,
float radius,
torch::Device device
) {
// Convert tensors to CUDA if not already, but preserve original dtype
pos1 = pos1.to(device).contiguous();
pos2 = pos2.to(device).contiguous();
data_cell = data_cell.to(device).contiguous();
num_atoms_per_image_sqr = num_atoms_per_image_sqr.to(device);
// Check that all position tensors have the same dtype
TORCH_CHECK(pos1.dtype() == pos2.dtype(), "pos1 and pos2 must have the same dtype");
TORCH_CHECK(pos1.dtype() == data_cell.dtype(), "pos1 and data_cell must have the same dtype");
// Determine if we're working with float32 or float64
bool is_float64 = pos1.dtype() == torch::kFloat64;
int num_pairs = pos1.size(0);
// Storage for all results across unit cells
std::vector<torch::Tensor> all_index1, all_index2, all_unit_cell, all_distances_sq;
// Create base indices for original atom pairs
torch::Tensor base_indices = torch::arange(num_pairs, torch::dtype(torch::kLong).device(device));
// Launch parameters
int threads_per_block = 512;
int blocks = (num_pairs + threads_per_block - 1) / threads_per_block;
// Pre-allocate tensors outside the loop for reuse
torch::Tensor distances_squared = torch::zeros({num_pairs},
torch::dtype(pos1.dtype()).device(device));
torch::Tensor valid_mask = torch::zeros({num_pairs},
torch::dtype(torch::kBool).device(device));
torch::Tensor unit_cell_offset = torch::zeros({3},
torch::dtype(pos1.dtype()).device(device));
torch::Tensor unit_cell_offset_batch = torch::zeros({batch_size, 3, 1},
torch::dtype(pos1.dtype()).device(device));
// Pre-compute batch offsets for optimized kernel
torch::Tensor batch_offsets = torch::zeros({batch_size + 1},
torch::dtype(torch::kLong).device(device));
torch::Tensor cumsum = torch::cumsum(num_atoms_per_image_sqr, 0);
batch_offsets.slice(0, 1, batch_size + 1) = cumsum;
// Iterate over unit cell offsets (triple loop)
// NOTE: for i, j, k loop can not be flatten, as we need to limit the device memory usage
#pragma unroll
for (int i = -max_rep[0]; i <= max_rep[0]; i++) {
#pragma unroll
for (int j = -max_rep[1]; j <= max_rep[1]; j++) {
#pragma unroll
for (int k = -max_rep[2]; k <= max_rep[2]; k++) {
// Reuse pre-allocated unit cell offset tensor
unit_cell_offset[0] = static_cast<float>(i);
unit_cell_offset[1] = static_cast<float>(j);
unit_cell_offset[2] = static_cast<float>(k);
// Compute PBC offsets for this unit cell
// unit_cell_offset_batch.fill_(0);
unit_cell_offset_batch.select(2, 0) = unit_cell_offset.unsqueeze(0).expand({batch_size, -1});
torch::Tensor pbc_offsets = torch::bmm(data_cell, unit_cell_offset_batch).squeeze(-1);
// // Optimized: Use index_select instead of repeat_interleave
// // Create index tensor for selecting pbc_offsets based on atom pairs
// int64_t offset = 0;
// for (int b = 0; b < batch_size; b++) {
// int64_t num_pairs_in_batch = num_atoms_per_image_sqr[b].item<int64_t>();
// auto batch_indices = torch::full({num_pairs_in_batch}, b,
// torch::dtype(torch::kLong).device(device));
// pbc_offsets_per_atom.slice(0, offset, offset + num_pairs_in_batch) =
// pbc_offsets.index_select(0, batch_indices);
// offset += num_pairs_in_batch;
// }
// Reset output tensors for reuse
// distances_squared.fill_(0);
// valid_mask.fill_(false);
// Launch templated CUDA kernel
if (is_float64) {
double radius_squared = static_cast<double>(radius) * static_cast<double>(radius);
launch_pbc_distance_kernel_optimized<double>(
pos1.data_ptr<double>(),
pos2.data_ptr<double>(),
// pbc_offsets_per_atom.data_ptr<double>(),
pbc_offsets.data_ptr<double>(),
num_atoms_per_image_sqr.data_ptr<int64_t>(),
batch_offsets.data_ptr<int64_t>(),
distances_squared.data_ptr<double>(),
valid_mask.data_ptr<bool>(),
num_pairs,
radius_squared,
blocks,
threads_per_block
);
} else {
float radius_squared = radius * radius;
launch_pbc_distance_kernel_optimized<float>(
pos1.data_ptr<float>(),
pos2.data_ptr<float>(),
// pbc_offsets_per_atom.data_ptr<float>(),
pbc_offsets.data_ptr<float>(),
num_atoms_per_image_sqr.data_ptr<int64_t>(),
batch_offsets.data_ptr<int64_t>(),
distances_squared.data_ptr<float>(),
valid_mask.data_ptr<bool>(),
num_pairs,
radius_squared,
blocks,
threads_per_block
);
}
// Filter valid pairs
torch::Tensor valid_indices = torch::nonzero(valid_mask).squeeze(-1);
if (valid_indices.numel() > 0) {
torch::Tensor valid_base_indices = base_indices.index_select(0, valid_indices);
torch::Tensor valid_distances = distances_squared.index_select(0, valid_indices);
torch::Tensor valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat({valid_indices.size(0), 1});
all_index1.push_back(valid_base_indices);
all_unit_cell.push_back(valid_unit_cell);
all_distances_sq.push_back(valid_distances);
}
}
}
}
// Single synchronization after all kernel launches
cudaDeviceSynchronize();
// Concatenate results
torch::Tensor final_indices, final_unit_cell, final_distances;
if (all_index1.size() > 0) {
final_indices = torch::cat(all_index1);
final_unit_cell = torch::cat(all_unit_cell);
final_distances = torch::cat(all_distances_sq);
} else {
final_indices = torch::empty({0}, torch::dtype(torch::kLong).device(device));
final_unit_cell = torch::empty({0, 3}, torch::dtype(pos1.dtype()).device(device));
final_distances = torch::empty({0}, torch::dtype(pos1.dtype()).device(device));
}
return {final_indices, final_unit_cell, final_distances};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("pbc_distance_cuda", &pbc_distance_cuda, "PBC distance computation with CUDA");
}
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
CUDA-accelerated PBC graph operations for atomic systems.
"""
import torch
from typing import Optional, List
from .pbc_graph_legacy import get_max_neighbors_mask
from .extensions.cuda_ops import get_pbc_graph_cuda_extension
def radius_graph_pbc_cuda(
data,
radius,
max_num_neighbors_threshold,
enforce_max_neighbors_strictly: bool = False,
pbc=None,
dtype=torch.float64,
):
"""
Memory-efficient CUDA-accelerated version of radius_graph_pbc.
This implementation follows the memory-efficient approach with triple loops
but accelerates the distance computation using CUDA kernels.
"""
if pbc is None:
pbc = [True, True, True]
device = data.pos.device
batch_size = len(data.natoms)
# Handle PBC settings
if hasattr(data, "pbc"):
data.pbc = torch.atleast_2d(data.pbc)
for i in range(3):
if not torch.any(data.pbc[:, i]).item():
pbc[i] = False
elif torch.all(data.pbc[:, i]).item():
pbc[i] = True
else:
raise RuntimeError(
"Different structures in the batch have different PBC configurations."
)
# position of the atoms
atom_pos = data.pos
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
num_atoms_per_image = data.natoms
num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
# index offset between images
index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
num_atoms_per_image_expand = torch.repeat_interleave(
num_atoms_per_image, num_atoms_per_image_sqr
)
# Compute atom pair indices
num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
index_sqr_offset = (
torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
)
index_sqr_offset = torch.repeat_interleave(
index_sqr_offset, num_atoms_per_image_sqr
)
atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset
# Compute the indices for the pairs of atoms (using division and mod)
index1 = (
torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")
) + index_offset_expand
index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
# Get the positions for each atom
pos1 = torch.index_select(atom_pos, 0, index1)
pos2 = torch.index_select(atom_pos, 0, index2)
# Calculate required number of unit cells in each direction for PBC
cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
if pbc[0]:
inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
rep_a1 = torch.ceil(radius * inv_min_dist_a1)
else:
rep_a1 = data.cell.new_zeros(1)
if pbc[1]:
cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
rep_a2 = torch.ceil(radius * inv_min_dist_a2)
else:
rep_a2 = data.cell.new_zeros(1)
if pbc[2]:
cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
rep_a3 = torch.ceil(radius * inv_min_dist_a3)
else:
rep_a3 = data.cell.new_zeros(1)
# Take the max over all images for uniformity
max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())]
# Pre-transpose data_cell for efficiency
data_cell = torch.transpose(data.cell, 1, 2)
# Use CUDA kernel for the triple loop computation
# try:
pbc_graph_cuda = get_pbc_graph_cuda_extension()
# Call the CUDA implementation
valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda(
pos1, pos2, data_cell,
num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device
)
# Map back to original index1 and index2
if len(valid_pair_indices) > 0:
index1 = index1.index_select(0, valid_pair_indices.long())
index2 = index2.index_select(0, valid_pair_indices.long())
else:
index1 = torch.empty(0, dtype=torch.long, device=device)
index2 = torch.empty(0, dtype=torch.long, device=device)
unit_cell = torch.empty(0, 3, dtype=dtype, device=device)
atom_distance_sqr = torch.empty(0, dtype=dtype, device=device)
# Sort index1 in ascending order and rearrange other arrays correspondingly
if len(index1) > 0:
sort_indices = torch.argsort(index1)
index1 = index1[sort_indices]
index2 = index2[sort_indices]
unit_cell = unit_cell[sort_indices]
atom_distance_sqr = atom_distance_sqr[sort_indices]
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
natoms=data.natoms,
index=index1,
atom_distance=atom_distance_sqr,
max_num_neighbors_threshold=max_num_neighbors_threshold,
enforce_max_strictly=enforce_max_neighbors_strictly,
)
if not torch.all(mask_num_neighbors):
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
index1 = torch.masked_select(index1, mask_num_neighbors)
index2 = torch.masked_select(index2, mask_num_neighbors)
unit_cell = torch.masked_select(
unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
)
unit_cell = unit_cell.view(-1, 3)
edge_index = torch.stack((index2, index1))
return edge_index, unit_cell, num_neighbors_image
\ No newline at end of file
This diff is collapsed.
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Utilities to interface OCP models/trainers with the Atomic Simulation
Environment (ASE)
"""
from __future__ import annotations
from types import MappingProxyType
from typing import TYPE_CHECKING
import torch
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
if TYPE_CHECKING:
from torch_geometric.data import Batch
# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)
def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}
positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()
# TODO take pbc from data
# TODO: &&& ^^^ change this back !!!
# if wrap_pos:
# pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)
atoms = Atoms(
numbers=numbers[idx].tolist(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)
atoms_objects.append(atoms)
return atoms_objects
This diff is collapsed.
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
from .bfgs_torch import BFGS
from .bfgsfusedls import BFGSFusedLS
__all__ = ["BFGS", "BFGSFusedLS"]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment