Unverified commit fa84b16c by zcxzcx1 (parent 09624897): Add files via upload
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
import logging
import torch
from torch_scatter import scatter
from ..optimizable import OptimizableBatch
class BFGS:
def __init__(
self,
optimizable_batch: OptimizableBatch,
maxstep: float = 0.2,
alpha: float = 70.0,
early_stop: bool = False,
) -> None:
"""
Args:
"""
self.optimizable = optimizable_batch
self.maxstep = maxstep
self.alpha = alpha
# self.H0 = 1.0 / self.alpha
self.trajectories = None
self.device = self.optimizable.device
self.fmax = None
self.steps = None
self.initialize()
self.early_stop = early_stop
def initialize(self):
# Initial Hessian: one dense (3N x 3N) scaled identity per structure; alpha sets the initial curvature.
self.H0 = [
torch.eye(3 * size, device=self.optimizable.device, dtype=torch.float64) * self.alpha
for size in self.optimizable.elem_per_group
]
self.H = [None] * self.optimizable.batch_size
self.pos0 = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64)
self.forces0 = torch.zeros_like(self.pos0, device=self.device, dtype=torch.float64)
def restart_from_earlystop(self, restart_indices, old_batch_indices):
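# When early stopping removes converged structures, the survivors are
# re-indexed from 0; remap the stored Hessians, reference positions, and
# reference forces to the new indices, then pad the slots for newly added
# structures with fresh (None) state.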
H_new = []
pos0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64)
forces0_new = torch.zeros_like(pos0_new, device=self.device, dtype=torch.float64)
# collect the preserved historical data by old_batch_indices
for i, idx in enumerate(restart_indices):
mask_old = (idx==old_batch_indices.repeat_interleave(3))
mask = (i==self.optimizable.batch_indices.repeat_interleave(3))
H_new.append(self.H[idx])
pos0_new[mask] = self.pos0[mask_old]
forces0_new[mask] = self.forces0[mask_old]
# append new info for the new batch
for i in range(len(H_new), self.optimizable.batch_size):
H_new.append(None)
self.H = H_new
self.pos0 = pos0_new
self.forces0 = forces0_new
def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None):
logging.info("Enter bfgs's main program.")
self.fmax = fmax
self.max_iter = maxstep  # note: here `maxstep` is the iteration limit, not the displacement cap
if is_restart_earlystop:
self.restart_from_earlystop(restart_indices, old_batch_indices)
iteration = 0
max_forces = self.optimizable.get_max_forces(apply_constraint=True)
logging.info("Step Fmax(eV/A)")
while iteration < self.max_iter and not self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25,
):
if self.early_stop and iteration > 0:
converge_indices = self.optimizable.converge_indices_list
if len(converge_indices) > 0:
logging.info(f"Early stopping at iteration {iteration}")
break
logging.info(
f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
)
self.step()
max_forces = self.optimizable.get_max_forces(apply_constraint=True)
iteration += 1
logging.info(
f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
)
# GPU memory usage as per nvidia-smi seems to gradually build up as
# batches are processed. This releases unoccupied cached memory.
torch.cuda.empty_cache()
# set predicted values to batch
for name, value in self.optimizable.results.items():
setattr(self.optimizable.batch, name, value)
self.nsteps = iteration
if self.early_stop:
converge_indices_list = self.optimizable.converge_indices_list
return converge_indices_list
else:
return self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces
)
def step(self):
forces = self.optimizable.get_forces(apply_constraint=True).to(
dtype=torch.float64
)
pos = self.optimizable.get_positions().to(dtype=torch.float64)
dpos, steplengths = self.prepare_step(pos, forces)
dpos = self.determine_step(dpos, steplengths)
self.optimizable.set_positions(pos+dpos)
def prepare_step(self, pos, forces):
forces = forces.reshape(-1)
pos = pos.view(-1)
self.update(pos, forces, self.pos0, self.forces0)
cur_indices = self.optimizable.batch_indices.repeat_interleave(3)
# Pre-initialize the result list
dpos_list = [None] * len(self.H)
# Split the work: create a CUDA stream only for the H matrices that need computing
calc_indices = [i for i, need_update in enumerate(self.optimizable.update_mask) if need_update]
streams = [torch.cuda.Stream() for _ in calc_indices]
# Run the actual computations in parallel
for i, stream in zip(calc_indices, streams):
with torch.cuda.stream(stream):
omega, V = torch.linalg.eigh(self.H[i])
dpos_list[i] = (V @ (forces[cur_indices==i].t() @ V / torch.abs(omega)).t())
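# This solves H @ dpos = forces per structure via the eigendecomposition
# H = V diag(omega) V^T, using |omega| so that negative-curvature modes
# are also stepped downhill: dpos = V diag(1/|omega|) V^T f.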
# Synchronize the current stream with the compute streams
torch.cuda.current_stream().synchronize()
# Handle the zero tensors on the main stream
for i in range(len(self.H)):
if not self.optimizable.update_mask[i]:
dpos_list[i] = torch.zeros_like(forces[cur_indices==i])
# Synchronize all streams
for stream in streams:
stream.synchronize()
# dpos = torch.vstack(dpos_list)
dpos = torch.zeros_like(forces)
for i in torch.unique(cur_indices):
mask = (cur_indices == i)
dpos[mask] = dpos_list[i]
dpos = dpos.reshape(-1, 3)
steplengths = (dpos ** 2).sum(dim=-1).sqrt()
self.pos0 = pos
self.forces0 = forces
return dpos, steplengths
def determine_step(self, dpos, steplengths):
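# Cap the longest per-atom displacement within each structure at maxstep,
# scaling all atoms of that structure by the same factor so the step
# direction is preserved.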
longest_steps = scatter(
steplengths, self.optimizable.batch_indices, reduce="max"
)
longest_steps = longest_steps[self.optimizable.batch_indices]
maxstep = longest_steps.new_tensor(self.maxstep)
scale = (longest_steps).reciprocal() * torch.min(longest_steps, maxstep)
dpos *= scale.unsqueeze(1)
return dpos
def update(self, pos, forces, pos0, forces0):
if self.H is None:
self.H = self.H0
return
dpos = pos - pos0
dforces = forces - forces0
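# BFGS update in Hessian form. With s = dpos and y = -dforces (forces are
# negative gradients), the loops below compute dg = H s, a = dforces . s
# = -y^T s, and b = s^T H s, so that
#     H <- H - (dforces dforces^T)/a - (H s)(H s)^T/b
#       =  H + (y y^T)/(y^T s) - (H s s^T H)/(s^T H s),
# which is the standard BFGS formula.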
batch_indices_flatten = self.optimizable.batch_indices.repeat_interleave(3)
dg = torch.zeros_like(dforces)
all_size = self.optimizable.elem_per_group
for i in range(self.optimizable.batch_size):
if self.H[i] is None:
continue
mask = (i==batch_indices_flatten)
if torch.abs(dpos[mask]).max() < 1e-7:
continue
dg[mask] = self.H[i] @ dpos[mask]
a = self._batched_dot_1d(dforces, dpos)
b = self._batched_dot_1d(dpos, dg)
for i in range(self.optimizable.batch_size):
if self.H[i] is None:
self.H[i] = torch.eye(3*all_size[i], device=self.device, dtype=torch.float64) * self.alpha
continue
mask = (i==batch_indices_flatten)
if not self.optimizable.update_mask[i]:
continue
if torch.abs(dpos[mask]).max() < 1e-7:
continue
outer_force = torch.outer(dforces[mask], dforces[mask])
outer_dg = torch.outer(dg[mask], dg[mask])
self.H[i] -= outer_force / a[i] + outer_dg / b[i]
def update_parallel(self, pos, forces, pos0, forces0):
if self.H is None:
self.H = self.H0
return
dpos = pos - pos0
if torch.abs(dpos).max() < 1e-7:
return
dforces = forces - forces0
cur_indices = self.optimizable.batch_indices.repeat_interleave(3)
a = self._batched_dot_1d(dforces, dpos)
# DONE: There is a bug using hstack.
# dg = torch.hstack([self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))])
# DONE: parallel this part
# dg_list = [self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))]
dg_list = [None] * len(self.H)
streams = [torch.cuda.Stream() for _ in dg_list]
for i, stream in zip(range(len(dg_list)), streams):
with torch.cuda.stream(stream):
dg_list[i] = self.H[i] @ dpos[cur_indices == i]
torch.cuda.current_stream().synchronize()
for stream in streams:
stream.synchronize()
dg = torch.zeros_like(dforces)
for i in torch.unique(cur_indices):
mask = (cur_indices == i)
dg[mask] = dg_list[i]
b = self._batched_dot_1d(dpos, dg)
# DONE: parallel this part
for i, stream in zip(range(len(self.H)), streams):
if not self.optimizable.update_mask[i]:
continue
with torch.cuda.stream(stream):
outer_force = torch.outer(dforces[cur_indices==i], dforces[cur_indices==i])
outer_dg = torch.outer(dg[cur_indices==i], dg[cur_indices==i])
self.H[i] -= outer_force / a[i] + outer_dg / b[i]
torch.cuda.current_stream().synchronize()
for stream in streams:
stream.synchronize()
def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor):
return scatter(
(x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum"
)
def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor):
return scatter(
(x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum"
)
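# A minimal usage sketch (hedged: assumes an OptimizableBatch `obatch` built
# elsewhere, e.g. as in the Worker class further below):
#
#     optimizer = BFGS(obatch, alpha=70.0, maxstep=0.2)
#     converged = optimizer.run(fmax=0.01, maxstep=500)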
"""
Copyright (c) 2025 Ma Zhaojia
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
import logging
import torch
from torch_scatter import scatter
# from .linesearch_torch import LineSearchBatch
from ..optimizable import OptimizableBatch
from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler
from datetime import datetime
import os
import math
import gc
class BFGSFusedLS:
"""
Port of BFGSLineSearch from bfgslinesearch.py, adapted to PyTorch
and batched operations, mirroring lbfgs_torch.py structure.
"""
def __init__(
self,
optimizable_batch: OptimizableBatch,
maxstep: float = 0.2,
c1: float = 0.23,
c2: float = 0.46,
alpha: float = 10.0,
stpmax: float = 50.0,
device = 'cpu',
early_stop: bool = False,
use_profiler: bool = False,
profiler_log_dir: str = './log',
profiler_schedule_config: dict = None,
dtype: torch.dtype = torch.float64,
):
self.optimizable = optimizable_batch
self.maxstep = maxstep
self.c1 = c1
self.c2 = c2
self.alpha = alpha
self.stpmax = stpmax
self.nsteps = 0
self.device = device
self.force_calls = 0
self.early_stop = early_stop
self.use_profiler = use_profiler
self.profiler_log_dir = profiler_log_dir
self.profiler_schedule_config = profiler_schedule_config or {"wait": 48, "warmup": 1, "active": 1, "repeat": 8}
self.dtype = dtype
self.converge_indices_list = None
# The information from the previous round is reused in the current round's calculations.
## These variables need to be updated according to the new input when early stop is triggered.
self.Hs = None
self.r0 = None
self.g0 = None
self.p_list = [None] * self.optimizable.batch_size
self.no_update_list = [False] * self.optimizable.batch_size
self.ls_completed = [True] * self.optimizable.batch_size
self.ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu", dtype=self.dtype)
## need to be recalculated when early stop is triggered
self.forces = None
self.energies = None
def restart_from_earlystop(self, restart_indices, old_batch_indices):
Hs_new = []
r0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device)
g0_new = torch.zeros_like(r0_new, device=self.device)
p_list_new = []
no_update_list_new = []
ls_completed_new = []
# collect the preserved historical info by old_batch_indices
for i, idx in enumerate(restart_indices):
mask_old = (idx==old_batch_indices.repeat_interleave(3))
mask = (i==self.optimizable.batch_indices.repeat_interleave(3))
Hs_new.append(self.Hs[idx])
p_list_new.append(self.p_list[idx])
no_update_list_new.append(self.no_update_list[idx])
ls_completed_new.append(self.ls_completed[idx])
r0_new[mask] = self.r0[mask_old]
g0_new[mask] = self.g0[mask_old]
# append new info for new element in batch
for i in range(len(Hs_new), self.optimizable.batch_size):
# Hs_new.append(torch.eye(3 * self.optimizable.elem_per_group[i], device=self.device, dtype=torch.float64))
Hs_new.append(None)
p_list_new.append(None)
no_update_list_new.append(False)
ls_completed_new.append(True)
self.Hs = Hs_new
self.r0 = r0_new
self.g0 = g0_new
self.p_list = p_list_new
self.no_update_list = no_update_list_new
self.ls_completed = ls_completed_new
self.forces = None
self.energies = None
self.ls_batch.restart_from_earlystop(restart_indices=restart_indices, batch_indices_new=self.optimizable.batch_indices)
def step(self):
optimizable = self.optimizable
if self.forces is None:
self.forces = optimizable.get_forces().to(self.device)
r = optimizable.get_positions().reshape(-1).to(self.device)
g = -self.forces.reshape(-1) / self.alpha
p0_list = self.p_list
self.update(r, g, self.r0, self.g0, p0_list)
if self.energies is None:
self.energies = self.func(r)
for i in range(self.optimizable.batch_size):
if self.ls_completed[i]:
p = -torch.matmul(self.Hs[i], g[i==self.optimizable.batch_indices.repeat_interleave(3)])
# Scale p for numerical stability, using a simplified calculation
p_size = torch.sqrt((p**2).sum())
min_size = torch.sqrt(self.optimizable.elem_per_group[i] * 1e-10)
if p_size <= min_size:
p = p * (min_size / p_size)
self.p_list[i] = p
# ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu")
continue_search = [not elem for elem in self.ls_completed]
self.alpha_k_list, self.e_list, self.e0_list, self.no_update_list, self.ls_completed = self.ls_batch._linesearch_batch(
self.func, self.fprime, r, self.p_list, g, self.energies, None,
maxstep=self.maxstep, c1=self.c1, c2=self.c2, stpmax=self.stpmax, continue_search=continue_search
)
# reset device for linesearch result
for i in range(self.optimizable.batch_size):
if self.ls_completed[i]:
self.alpha_k_list[i] = self.alpha_k_list[i].to(self.device)
self.p_list[i] = self.p_list[i].to(self.device)
dr_tensor = torch.zeros_like(r)
for i in range(self.optimizable.batch_size):
# if check_cache:
# mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
# dr_tensor_all[mask] = self.alpha_k_list[i].to(self.device) * self.p_list[i].to(self.device)
if not self.ls_completed[i]:
continue
if self.alpha_k_list[i] is None:
raise RuntimeError("LineSearch failed!")
mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
dr_tensor[mask] = self.alpha_k_list[i] * self.p_list[i]
# if check_cache:
# cached_pos = optimizable.get_positions().reshape(-1).to(self.device)
# update_pos = r + dr_tensor_all
# assert torch.allclose(update_pos, cached_pos), "dr_tensor_cached should be equal to dr_tensor"
# TODO: get_forces/get_potential_energies will trigger compare_batch which is time-consuming
forces_cache = optimizable.get_forces()
energies_cache = self.optimizable.get_potential_energies() / self.alpha
# update self.forces
for i in range(self.optimizable.batch_size):
if not self.ls_completed[i]:
continue
mask = (i == self.optimizable.batch_indices)
self.forces[mask] = forces_cache[mask]
self.energies[i] = energies_cache[i]
optimizable.set_positions((r + dr_tensor).reshape(-1, 3))
self.r0 = r
self.g0 = g
# @torch.compile
def update(self, r, g, r0, g0, p0_list):
all_sizes = self.optimizable.elem_per_group
if self.Hs is None:
self.Hs = [
torch.eye(3 * sz, device=self.device, dtype=self.dtype)
for sz in all_sizes
]
return
dr = r - r0
dg = g - g0
for i in range(self.optimizable.batch_size):
if self.Hs[i] is None:
self.Hs[i] = torch.eye(3 * all_sizes[i], device=self.optimizable.device, dtype=self.dtype)
continue
if not self.ls_completed[i]:
continue
if self.no_update_list[i]:
logging.info('skip update')
continue
cur_mask = (i == self.optimizable.batch_indices.repeat_interleave(3))
cur_g = g[cur_mask]
cur_p0 = p0_list[i]
cur_g0 = g0[cur_mask]
cur_dg = dg[cur_mask]
cur_dr = dr[cur_mask]
if not ((self.alpha_k_list[i] or 0) > 0 and
abs(torch.dot(cur_g, cur_p0)) < abs(torch.dot(cur_g0, cur_p0))):
continue
try:
rhok = 1.0 / torch.dot(cur_dg, cur_dr)
except ZeroDivisionError:
rhok = 1000.0
logging.warning("Divide-by-zero encountered: rhok assumed large")
if torch.is_tensor(rhok) and torch.isinf(rhok):
rhok = 1000.0
logging.warning("Divide-by-zero encountered: rhok assumed large")
I = torch.eye(all_sizes[i]*3, device=self.device, dtype=self.dtype)
A1 = I - cur_dr[:, None] * cur_dg[None, :] * rhok
A2 = I - cur_dg[:, None] * cur_dr[None, :] * rhok
self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) +
rhok * cur_dr[:, None] * cur_dr[None, :])
# TODO: after the BFGS loop is interrupted by early stop, does self.Hs need to
# be rebuilt? The r, g, r0, g0 saved from the previous round are discarded too.
def func(self, x):
self.optimizable.set_positions(x.reshape(-1, 3).to(self.device))
return self.optimizable.get_potential_energies() / self.alpha
def fprime(self, x):
self.optimizable.set_positions(x.reshape(-1, 3).to(self.device))
self.force_calls += 1
forces = self.optimizable.get_forces().reshape(-1)
return - forces / self.alpha
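# Note: func() and fprime() return E/alpha and -F/alpha, i.e. the line search
# operates on a preconditioned objective whose initial inverse Hessian is the
# identity (see update()); alpha plays the role of the initial curvature.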
def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None):
logging.info("Enter bfgsfusedlinesearch's main program.")
self.fmax = fmax
self.max_iter = maxstep  # note: here `maxstep` is the iteration limit, not the displacement cap
if is_restart_earlystop:
self.restart_from_earlystop(restart_indices, old_batch_indices)
iteration = 0
max_forces = self.optimizable.get_max_forces(apply_constraint=True)
logging.info("Step Fmax(eV/A)")
# Run with profiler if enabled
if self.use_profiler:
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(ProfilerActivity.CUDA)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
pid = os.getpid()
with torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=self.profiler_schedule_config["wait"],
warmup=self.profiler_schedule_config["warmup"],
active=self.profiler_schedule_config["active"],
repeat=self.profiler_schedule_config["repeat"]
),
on_trace_ready=tensorboard_trace_handler(self.profiler_log_dir, worker_name=f"BFGSLS_{pid}"),
with_stack=True,
profile_memory=True,
) as prof:
# Main optimization loop with profiling
while iteration < self.max_iter and not self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25,
):
if self.early_stop and iteration > 0:
self.converge_indices_list = self.optimizable.converge_indices_list
if len(self.converge_indices_list) > 0:
logging.info(f"Early stopping at iteration {iteration}")
break
logging.info(
f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
)
self.step()
max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces)
iteration += 1
# Step the profiler in each iteration
prof.step()
else:
# Original optimization loop without profiling
while iteration < self.max_iter and not self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25,
):
if self.early_stop and iteration > 0:
self.converge_indices_list = self.optimizable.converge_indices_list
if len(self.converge_indices_list) > 0:
logging.info(f"Early stopping at iteration {iteration}")
break
logging.info(
f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
)
self.step()
max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces)
iteration += 1
logging.info(
f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist())
)
# GPU memory usage as per nvidia-smi seems to gradually build up as
# batches are processed. This releases unoccupied cached memory.
torch.cuda.empty_cache()
gc.collect()
# set predicted values to batch
for name, value in self.optimizable.results.items():
setattr(self.optimizable.batch, name, value)
self.nsteps = iteration
if self.early_stop:
self.converge_indices_list = self.optimizable.converge_indices_list
return self.converge_indices_list
else:
return self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces
)
def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor):
return scatter(
(x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum"
)
def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor):
return scatter(
(x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum"
)
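# A minimal usage sketch (hedged: assumes an OptimizableBatch `obatch` built
# elsewhere, as in the Worker class further below):
#
#     optimizer = BFGSFusedLS(obatch, device="cuda", maxstep=0.2)
#     converged = optimizer.run(fmax=0.01, maxstep=500)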
# flake8: noqa
import math
import torch
import logging
pymin = min
pymax = max
class LineSearch:
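# Note: this appears to follow the More-Thuente line search (MINPACK's dcsrch),
# as ported in ASE's linesearch module: step() drives a task-based state
# machine ('START' -> 'FG' -> 'CONVERGENCE'/'WARNING'/'ERROR'), and update()
# chooses the next trial step from the four cubic/secant interpolation cases.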
def __init__(self, xtol=1e-14, device='cpu', dtype=torch.float64):
self.xtol = xtol
self.task = 'START'
self.device = device
self.dtype = dtype
self.isave = torch.zeros(2, dtype=torch.int64, device=self.device)
self.dsave = torch.zeros(13, dtype=self.dtype, device=self.device)
self.fc = 0
self.gc = 0
self.case = 0
self.old_stp = 0
def initialize(self, xk, pk, gfk, old_fval, old_old_fval,
maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
stpmax=50., stpmin=1e-8):
# Scalar parameters can stay as Python scalars
self.stpmin = stpmin
self.stpmax = stpmax
self.xtrapl = xtrapl
self.xtrapu = xtrapu
self.maxstep = maxstep
# Move tensors to the device
self.pk = pk.to(self.device)
xk = xk.to(self.device)
gfk = gfk.to(self.device)
phi0 = old_fval
# This dot product needs tensors
derphi0 = torch.dot(gfk, self.pk).item()
# Use Python math for scalar calculations
self.dim = len(pk)
self.gms = math.sqrt(self.dim) * maxstep
alpha1 = 1.0
self.no_update = False
self.gradient = True
self.steps = []
return alpha1, phi0, derphi0
def prologue(self, fval, gval, pk_tensor, alpha1):
phi0 = fval
derphi0 = torch.dot(gval, pk_tensor)
self.old_stp = alpha1
# TODO: the "if self.no_update: break" logic still needs to be reimplemented here.
return phi0, derphi0
def epilogue(self):
pass
def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
stpmax=50., stpmin=1e-8, args=()):
self.stpmin = stpmin
self.pk = pk.to(self.device)
self.stpmax = stpmax
self.xtrapl = xtrapl
self.xtrapu = xtrapu
self.maxstep = maxstep
xk = xk.to(self.device)
# Convert inputs to torch tensors if they're not already
if not isinstance(old_fval, torch.Tensor):
phi0 = torch.tensor(old_fval, dtype=self.dtype, device=self.device)
else:
phi0 = old_fval.to(self.device)
# Ensure pk and gfk are torch tensors
pk_tensor = torch.tensor(pk, dtype=self.dtype, device=self.device) if not isinstance(pk, torch.Tensor) else pk.to(self.device)
gfk_tensor = torch.tensor(gfk, dtype=self.dtype, device=self.device) if not isinstance(gfk, torch.Tensor) else gfk.to(self.device)
derphi0 = torch.dot(gfk_tensor, pk_tensor)
self.dim = len(pk)
self.gms = torch.sqrt(torch.tensor(self.dim, dtype=self.dtype, device=self.device)) * maxstep
alpha1 = 1.
self.no_update = False
if isinstance(myfprime, tuple):
fprime = myfprime[0]
gradient = False
else:
fprime = myfprime
newargs = args
gradient = True
fval = phi0
gval = gfk_tensor
self.steps = []
while True:
stp = self.step(alpha1, phi0, derphi0, c1, c2,
self.xtol,
self.isave, self.dsave)
if self.task[:2] == 'FG':
alpha1 = stp
# Get function value and gradient
x_new = xk + stp * pk_tensor
fval = func(x_new).to(self.device)
self.fc += 1
gval = fprime(x_new).to(self.device)
if gradient:
self.gc += 1
else:
self.fc += len(xk) + 1
phi0 = fval
derphi0 = torch.dot(gval, pk_tensor)
self.old_stp = alpha1
if self.no_update:
break
else:
break
if self.task[:5] == 'ERROR' or self.task[:4] == 'WARN':
stp = None # failed
return stp, fval.item(), old_fval.item() if isinstance(old_fval, torch.Tensor) else old_fval, self.no_update
def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
if self.task[:5] == 'START':
# Check the input arguments for errors.
if stp < self.stpmin:
self.task = 'ERROR: STP .LT. minstep'
if stp > self.stpmax:
self.task = 'ERROR: STP .GT. maxstep'
if g >= 0:
self.task = 'ERROR: INITIAL G >= 0'
if c1 < 0:
self.task = 'ERROR: c1 .LT. 0'
if c2 < 0:
self.task = 'ERROR: c2 .LT. 0'
if xtol < 0:
self.task = 'ERROR: XTOL .LT. 0'
if self.stpmin < 0:
self.task = 'ERROR: minstep .LT. 0'
if self.stpmax < self.stpmin:
self.task = 'ERROR: maxstep .LT. minstep'
if self.task[:5] == 'ERROR':
return stp
# Initialize local variables.
self.bracket = False
stage = 1
finit = f
ginit = g
gtest = c1 * ginit
width = self.stpmax - self.stpmin
width1 = width / .5
# The variables stx, fx, gx contain the values of the step,
# function, and derivative at the best step.
# The variables sty, fy, gy contain the values of the step,
# function, and derivative at sty.
# The variables stp, f, g contain the values of the step,
# function, and derivative at stp.
stx = 0.0
fx = finit
gx = ginit
sty = 0.0
fy = finit
gy = ginit
stmin = 0.0
stmax = stp + self.xtrapu * stp
self.task = 'FG'
self.save((stage, ginit, gtest, gx,
gy, finit, fx, fy, stx, sty,
stmin, stmax, width, width1))
stp = self.determine_step(stp)
return stp
else:
if self.isave[0] == 1:
self.bracket = True
else:
self.bracket = False
stage = self.isave[1]
(ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
width, width1) = self.dsave
# If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
# algorithm enters the second stage.
ftest = finit + stp * gtest
if stage == 1 and f < ftest and g >= 0.:
stage = 2
# Test for warnings.
if self.bracket and (stp <= stmin or stp >= stmax):
self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
if self.bracket and stmax - stmin <= self.xtol * stmax:
self.task = 'WARNING: XTOL TEST SATISFIED'
if stp == self.stpmax and f <= ftest and g <= gtest:
self.task = 'WARNING: STP = maxstep'
if stp == self.stpmin and (f > ftest or g >= gtest):
self.task = 'WARNING: STP = minstep'
# Test for convergence.
# if f <= ftest and abs(g) <= c2 * (- ginit):
# self.task = 'CONVERGENCE'
if (f < ftest or math.isclose(f, ftest, rel_tol=1e-6, abs_tol=1e-5)) and (abs(g) < c2 * (- ginit) or math.isclose(abs(g), c2 * (- ginit), rel_tol=1e-6, abs_tol=1e-5)):
self.task = 'CONVERGENCE'
# Test for termination.
if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
self.save((stage, ginit, gtest, gx,
gy, finit, fx, fy, stx, sty,
stmin, stmax, width, width1))
return stp
stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
fy, gy, stp, f, g,
stmin, stmax)
# Decide if a bisection step is needed.
if self.bracket:
if abs(sty - stx) >= .66 * width1:
stp = stx + .5 * (sty - stx)
width1 = width
width = abs(sty - stx)
# Set the minimum and maximum steps allowed for stp.
if self.bracket:
stmin = min(stx, sty)
stmax = max(stx, sty)
else:
stmin = stp + self.xtrapl * (stp - stx)
stmax = stp + self.xtrapu * (stp - stx)
# Force the step to be within the bounds maxstep and minstep.
stp = max(stp, self.stpmin)
stp = min(stp, self.stpmax)
if (stx == stp and stp == self.stpmax and stmin > self.stpmax):
self.no_update = True
# If further progress is not possible, let stp be the best
# point obtained during the search.
if (self.bracket and stp < stmin or stp >= stmax) \
or (self.bracket and stmax - stmin < self.xtol * stmax):
stp = stx
# Obtain another function and derivative.
self.task = 'FG'
self.save((stage, ginit, gtest, gx,
gy, finit, fx, fy, stx, sty,
stmin, stmax, width, width1))
return stp
def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
stpmin, stpmax):
sign = gp * (gx / abs(gx))
# First case: A higher function value. The minimum is bracketed.
# If the cubic step is closer to stx than the quadratic step, the
# cubic step is taken, otherwise the average of the cubic and
# quadratic steps is taken.
if fp > fx: # case1
self.case = 1
theta = 3. * (fx - fp) / (stp - stx) + gx + gp
s = max(max(abs(theta), abs(gx)), abs(gp))
gamma = s * math.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
if stp < stx:
gamma = -gamma
p = (gamma - gx) + theta
q = ((gamma - gx) + gamma) + gp
r = p / q
stpc = stx + r * (stp - stx)
stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
* (stp - stx)
if (abs(stpc - stx) < abs(stpq - stx)):
stpf = stpc
else:
stpf = stpc + (stpq - stpc) / 2.
self.bracket = True
# Second case: A lower function value and derivatives of opposite
# sign. The minimum is bracketed. If the cubic step is farther from
# stp than the secant step, the cubic step is taken, otherwise the
# secant step is taken.
elif sign < 0: # case2
self.case = 2
theta = 3. * (fx - fp) / (stp - stx) + gx + gp
s = max(max(abs(theta), abs(gx)), abs(gp))
gamma = s * math.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
if stp > stx:
gamma = -gamma
p = (gamma - gp) + theta
q = ((gamma - gp) + gamma) + gx
r = p / q
stpc = stp + r * (stx - stp)
stpq = stp + (gp / (gp - gx)) * (stx - stp)
if (abs(stpc - stp) > abs(stpq - stp)):
stpf = stpc
else:
stpf = stpq
self.bracket = True
# Third case: A lower function value, derivatives of the same sign,
# and the magnitude of the derivative decreases.
elif abs(gp) < abs(gx): # case3
self.case = 3
# The cubic step is computed only if the cubic tends to infinity
# in the direction of the step or if the minimum of the cubic
# is beyond stp. Otherwise the cubic step is defined to be the
# secant step.
theta = 3. * (fx - fp) / (stp - stx) + gx + gp
s = max(max(abs(theta), abs(gx)), abs(gp))
# The case gamma = 0 only arises if the cubic does not tend
# to infinity in the direction of the step.
gamma = s * math.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
if stp > stx:
gamma = -gamma
p = (gamma - gp) + theta
q = (gamma + (gx - gp)) + gamma
r = p / q
if r < 0. and gamma != 0:
stpc = stp + r * (stx - stp)
elif stp > stx:
stpc = stpmax
else:
stpc = stpmin
stpq = stp + (gp / (gp - gx)) * (stx - stp)
if self.bracket:
# A minimizer has been bracketed. If the cubic step is
# closer to stp than the secant step, the cubic step is
# taken, otherwise the secant step is taken.
if abs(stpc - stp) < abs(stpq - stp):
stpf = stpc
else:
stpf = stpq
if stp > stx:
stpf = min(stp + .66 * (sty - stp), stpf)
else:
stpf = max(stp + .66 * (sty - stp), stpf)
else:
# A minimizer has not been bracketed. If the cubic step is
# farther from stp than the secant step, the cubic step is
# taken, otherwise the secant step is taken.
if abs(stpc - stp) > abs(stpq - stp):
stpf = stpc
else:
stpf = stpq
stpf = min(stpmax, stpf)
stpf = max(stpmin, stpf)
# Fourth case: A lower function value, derivatives of the same sign,
# and the magnitude of the derivative does not decrease. If the
# minimum is not bracketed, the step is either minstep or maxstep,
# otherwise the cubic step is taken.
else: # case4
self.case = 4
if self.bracket:
theta = 3. * (fp - fy) / (sty - stp) + gy + gp
s = max(max(abs(theta), abs(gy)), abs(gp))
gamma = s * math.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
if stp > sty:
gamma = -gamma
p = (gamma - gp) + theta
q = ((gamma - gp) + gamma) + gy
r = p / q
stpc = stp + r * (sty - stp)
stpf = stpc
elif stp > stx:
stpf = stpmax
else:
stpf = stpmin
# Update the interval which contains a minimizer.
if fp > fx:
sty = stp
fy = fp
gy = gp
else:
if sign < 0:
sty = stx
fy = fx
gy = gx
stx = stp
fx = fp
gx = gp
# Compute the new step.
stp = self.determine_step(stpf)
return stx, sty, stp, gx, fx, gy, fy
def determine_step(self, stp):
dr = stp - self.old_stp
x = torch.reshape(self.pk.to(self.device), (-1, 3))
steplengths = ((dr * x)**2).sum(1)**0.5
maxsteplength = max(steplengths)
if maxsteplength >= self.maxstep:
dr *= self.maxstep / maxsteplength
stp = self.old_stp + dr
return stp
def save(self, data):
if self.bracket:
self.isave[0] = 1
else:
self.isave[0] = 0
self.isave[1] = data[0]
self.dsave = data[1:]
class LineSearchBatch:
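# LineSearchBatch advances one LineSearch state machine per structure in
# lockstep, so each inner iteration needs only a single batched call to
# func()/fprime() for the whole batch.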
def __init__(self, batch_indices, device='cpu', dtype=torch.float64):
self.device = device
self.dtype = dtype
self.batch_indices = batch_indices.to(self.device)
self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device)
self.batch_size = len(torch.unique(batch_indices))
self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)]
self.steps = [1.] * self.batch_size
self.phi0_values = [None] * self.batch_size
self.derphi0_values = [None] * self.batch_size
def restart_from_earlystop(self, restart_indices, batch_indices_new):
self.batch_indices = batch_indices_new.to(self.device)
self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device)
self.batch_size = len(torch.unique(batch_indices_new))
linesearch_list_new = []
steps_new = []
phi0_values_new = []
derphi0_values_new = []
for i, idx in enumerate(restart_indices):
linesearch_list_new.append(self.linesearch_list[idx])
steps_new.append(self.steps[idx])
phi0_values_new.append(self.phi0_values[idx])
derphi0_values_new.append(self.derphi0_values[idx])
for i in range(len(restart_indices), self.batch_size):
linesearch_list_new.append(LineSearch(device=self.device, dtype=self.dtype))
steps_new.append(1.)
phi0_values_new.append(None)
derphi0_values_new.append(None)
self.linesearch_list = linesearch_list_new
self.steps = steps_new
self.phi0_values = phi0_values_new
self.derphi0_values = derphi0_values_new
def _linesearch_batch(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
stpmax=50., stpmin=1e-8, continue_search=None, max_iter=15):
if continue_search is None:
self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)]
else:
assert len(continue_search) == self.batch_size
for i in range(len(continue_search)):
if not continue_search[i]:
self.linesearch_list[i] = LineSearch(device=self.device, dtype=self.dtype)
if isinstance(xk, torch.Tensor):
xk = xk.to(self.device)
for i in range(len(pk)):
pk[i] = pk[i].to(self.device)
if isinstance(gfk, torch.Tensor):
gfk = gfk.to(self.device)
if isinstance(old_fval, torch.Tensor):
old_fval = old_fval.to(self.device)
if isinstance(old_old_fval, torch.Tensor):
old_old_fval = old_old_fval.to(self.device)
# results for each batch element
alpha_results = []
e_result = []
e0_result = []
no_update_result = []
# Initialize step sizes and line search state for each batch element
completed = [False] * self.batch_size
# Initialize iteration counter
iter_count = 0
# Initialize all line searches using the initialize method
for i in range(self.batch_size):
if continue_search[i]:
continue
ls = self.linesearch_list[i]
mask = (i == self.batch_indices_flatten)
# Use the initialize method to set up line search parameters
alpha1, phi0, derphi0 = ls.initialize(
xk[mask], pk[i], gfk[mask], old_fval[i], old_old_fval,
maxstep, c1, c2, xtrapl, xtrapu, stpmax, stpmin
)
# Store the initialization values
self.steps[i] = alpha1
self.phi0_values[i] = phi0
self.derphi0_values[i] = derphi0
# Main optimization loop
while True:
# 1. step forward
# logging.info(f"step's input: alpha1: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}")
for i in range(self.batch_size):
if completed[i]:
continue
ls = self.linesearch_list[i]
if ls.fc > max_iter:
completed[i] = True
logging.warning(f"LineSearchBatch[{i}] reached max_iter: {max_iter}")
continue
stp = ls.step(self.steps[i], self.phi0_values[i], self.derphi0_values[i],
c1, c2, ls.xtol, ls.isave, ls.dsave)
if ls.task[:2] == 'FG':
self.steps[i] = stp
else:
completed[i] = True
# 2. calculate new function value and gradient
x_new_batch = torch.zeros_like(xk)
for i in range(self.batch_size):
mask = (i == self.batch_indices_flatten)
x_new_batch[mask] = xk[mask] + self.steps[i] * pk[i]
f_batch = func(x_new_batch).to(self.device)
g_batch = myfprime(x_new_batch).to(self.device)
# 3. update function value and gradient
for i in range(self.batch_size):
ls = self.linesearch_list[i]
mask = (i == self.batch_indices_flatten)
if ls.task[:2] == 'FG':
# Update function value and gradient
f_val = f_batch[i:i+1]
g_val = g_batch[mask]
ls.fc += 1
phi0, derphi0 = ls.prologue(f_val, g_val, pk[i], self.steps[i])
# logging.info(f"phi0, derphi0: {phi0}, {derphi0}")
self.phi0_values[i] = phi0
self.derphi0_values[i] = derphi0 # TODO: why we put the derphi0 here instead of set it inside the LineSearch class?
if ls.no_update:
completed[i] = True
else:
completed[i] = True
iter_count += 1
logging.info(f"LineSearchBatch iter: {iter_count}: alpha: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}")
if any(completed):
break
# 4. set a linesearch upper limit
# if iter_count > max_iter:
# for i in range(self.batch_size):
# completed[i] = True
# logging.warning(f"LineSearchBatch reached max_iter: {max_iter}")
# break
# Collect results
for i in range(self.batch_size):
ls = self.linesearch_list[i]
if ls.task[:5] == 'ERROR' or ls.task[:4] == 'WARN':
stp = torch.tensor(1., device=self.device)
else:
stp = self.steps[i] if isinstance(self.steps[i], torch.Tensor) else torch.tensor(self.steps[i], device=self.device)
alpha_results.append(stp)
e_result.append(self.phi0_values[i].item() if self.phi0_values[i] is not None else None)
e0_result.append(old_fval[i].item() if isinstance(old_fval[i], torch.Tensor) else old_fval[i])
no_update_result.append(ls.no_update)
logging.info(f"LineSearchBatch finished in {iter_count} iterations. \
LineSearch Status: {[stat for stat in completed]}")
return alpha_results, e_result, e0_result, no_update_result, completed
"""
Copyright (c) 2025 Chengxi Zhao, Zhaojia Ma, Dingrui Fan
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from ase.io import read
# from ase.optimize import ASE_LBFGS
import torch
from torch.multiprocessing import Process, set_start_method
from batchopt.atoms_to_graphs import AtomsToGraphs
from batchopt.utils import data_list_collater
from batchopt.relaxation.optimizers import (
BFGS,
BFGSFusedLS,
LBFGS,  # selectable as optimizer1/optimizer2 in Worker.run below
BFGSLineSearch,  # selectable as optimizer1/optimizer2 in Worker.run below
)
from batchopt.relaxation import OptimizableBatch, OptimizableUnitCellBatch
import logging
import time
import csv
from multiprocessing import Queue
import os
import psutil
import multiprocessing
import json
import subprocess
try:
from chgnet.model.dynamics import CHGNetCalculator
except ImportError:
logging.warning("Failed to import CHGNet modules")
try:
from sevenn.calculator import SevenNetCalculator, SevenNetD3Calculator
except ImportError:
logging.warning("Failed to import SevenNet modules")
try:
from fairchem.core import pretrained_mlip, FAIRChemCalculator
except ImportError:
logging.warning("Failed to import FAIRChem modules")
try:
from mace.calculators import mace_off
except ImportError:
logging.warning("Failed to import MACE modules")
import threading
from .utils import count_atoms_cif
from collections import deque
class Scheduler:
"""
Scheduler distributes relaxation tasks to workers.
"""
def __init__(
self,
files,
num_workers,
devices,
batch_size,
max_steps,
filter1,
filter2,
optimizer1,
optimizer2,
skip_second_stage,
scalar_pressure,
compile_mode,
profile,
num_threads,
bind_cores,
cueq,
molecule_single,
output_path,
model,
):
self.files = files
self.num_workers = num_workers
self.devices = devices
self.batch_size = batch_size
self.max_steps = max_steps
self.filter1 = filter1
self.filter2 = filter2
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
self.skip_second_stage = skip_second_stage
self.scalar_pressure = scalar_pressure
self.compile_mode = compile_mode
self.profile = profile
self.num_threads = num_threads
self.cueq = cueq
self.molecule_single = molecule_single
self.output_path = (
output_path
if os.path.isabs(output_path)
else os.path.abspath(output_path)
)
self.model = model
try:
set_start_method("spawn")
except RuntimeError:
logging.warning(
"set_start_method('spawn') failed; the start method was already set."
)
if bind_cores is not None:
self.cpu_mask = self._parse_bind_cores(bind_cores)
else:
self.cpu_mask = None
def _parse_bind_cores(self, bind_cores):
# Expect bind_cores to be like "0-15,16-31,..."
ranges = bind_cores.split(",")
if len(ranges) != self.num_workers:
return None
binding = []
for r in ranges:
try:
start_str, end_str = r.split("-")
start = int(start_str)
end = int(end_str)
except ValueError:
logging.error("Custom binding format should be 'start-end'.")
return None
binding.append(set(range(start, end + 1)))
return binding
def _get_physical_logical_core_mapping(self):
"""Get the mapping between logical cores and their physical core IDs."""
try:
# This information is available in Linux systems
mapping = {}
logical_cores = psutil.cpu_count(logical=True)
for i in range(logical_cores):
try:
# Read core_id from /sys/devices/system/cpu/cpu{i}/topology/core_id
with open(
f"/sys/devices/system/cpu/cpu{i}/topology/core_id"
) as f:
core_id = int(f.read().strip())
# Read physical_package_id (socket) for more complete information
with open(
f"/sys/devices/system/cpu/cpu{i}/topology/physical_package_id"
) as f:
package_id = int(f.read().strip())
mapping[i] = (package_id, core_id)
except (FileNotFoundError, ValueError, IOError):
mapping[i] = None
return mapping
except Exception as e:
logging.error(f"Failed to get core mapping: {e}")
return {}
def _get_physical_core_mask(self):
# Get the number of physical and logical cores
physical_cores = psutil.cpu_count(logical=False)
logical_cores = psutil.cpu_count(logical=True)
if physical_cores is None or physical_cores < 1:
# Fallback to multiprocessing if psutil fails
logical_cores = multiprocessing.cpu_count()
physical_cores = logical_cores // 2
if physical_cores < 1:
physical_cores = 1
print(f"Using estimated physical cores: {physical_cores}")
# Get the mapping between logical and physical cores
core_mapping = self._get_physical_logical_core_mapping()
# Create a CPU mask that includes all physical cores (first core of each physical core)
physical_core_mask = set()
if core_mapping:
# Group by physical core ID
cores_by_physical = {}
for logical_id, physical_info in core_mapping.items():
if physical_info is not None:
package_id, core_id = physical_info
key = (package_id, core_id)
if key not in cores_by_physical:
cores_by_physical[key] = []
cores_by_physical[key].append(logical_id)
# Select one logical core from each physical core
for physical_cores_list in cores_by_physical.values():
physical_core_mask.add(
physical_cores_list[0]
) # First logical core of each physical core
else:
# If mapping fails, use a simple assumption (may not be accurate on all systems)
threads_per_core = logical_cores // physical_cores
physical_core_mask = set(range(0, logical_cores, threads_per_core))
return physical_core_mask
def worker_task(
self, files, device, batch_size, result_queue, physical_cores
):
if physical_cores is not None:
try:
# Bind the current process to physical cores
pid = os.getpid()
os.sched_setaffinity(pid, physical_cores)
logging.info(f"bind to physical_core_ids: {physical_cores}")
# Verify the affinity was set correctly
current_affinity = os.sched_getaffinity(pid)
logging.info(
f"Process bound to {len(current_affinity)} cores: {sorted(current_affinity)}"
)
except AttributeError:
logging.error(
"sched_setaffinity not supported on this platform"
)
except Exception as e:
logging.error(f"Failed to bind to physical cores: {e}")
# number of worker processes per device
nproc = self.num_workers // len(self.devices)
worker = Worker(
files,
device,
batch_size,
self.max_steps,
self.filter1,
self.filter2,
self.optimizer1,
self.optimizer2,
self.skip_second_stage,
self.scalar_pressure,
self.compile_mode,
self.profile,
self.cueq,
self.molecule_single,
self.output_path,
self.model,
nproc,
)
# results = worker.run()
results = worker.continuous_run()
result_queue.put(results)
def _terminate_processes(self, processes):
"""Helper method to terminate all processes."""
for i, p in processes:
if p.is_alive():
logging.info(f"Terminating process {p.pid}")
p.terminate()
p.join(timeout=3) # Wait for up to 3 seconds
if p.is_alive():
logging.warning(
f"Process {p.pid} did not terminate, killing it"
)
p.kill()
p.join()
# periodically run "nvidia-smi" to monitor GPU memory (launched as a separate process)
@staticmethod
def _monitor_memory(interval=2, gpu_index=1):
try:
while True:
result = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=memory.used,memory.total",
"--format=csv,nounits,noheader",
]
).decode("utf-8")
lines = result.strip().split("\n")
used, total = map(int, lines[gpu_index].split(","))
logging.info(
f"[nvidia-smi] Memory-Usage on GPU {gpu_index}: {used}MiB / {total}MiB"
)
time.sleep(interval)
except KeyboardInterrupt:
logging.info("Monitor interrupted.")
except Exception as e:
logging.error(f"Unexpected error when monitor memory: {str(e)}")
def run(self):
logging.info(f"Starting Scheduler with {self.num_workers} workers.")
processes = []
result_queue = Queue()
start_time = time.perf_counter()
if self.cpu_mask is not None:
physical_cores_per_worker = self.cpu_mask
logging.info(
f"Use customed cores binding. Physical cores per worker: {physical_cores_per_worker}"
)
else:
# all_physical_cores = self._get_physical_core_mask()
# num_per_worker = len(all_physical_cores) // self.num_workers
# physical_cores_per_worker = [
# list(all_physical_cores)[i:i + num_per_worker] for i in range(0, len(all_physical_cores), num_per_worker)
# ]
# logging.info(f"Physical cores per worker: {physical_cores_per_worker}")
physical_cores_per_worker = [None] * self.num_workers
try:
# Start all worker processes
for i in range(self.num_workers):
files_for_worker = self.files[i :: self.num_workers]
device = self.devices[i % len(self.devices)]
logging.info(
f"Starting worker {i} with {len(files_for_worker)} files on device {device}."
)
p = Process(
target=self.worker_task,
args=(
files_for_worker,
device,
self.batch_size,
result_queue,
physical_cores_per_worker[i],
),
)
p.start()
processes.append((i, p))
# Monitor GPU memory usage to figure out what causes the differences in
# memory footprint among batches in each iteration.
use_memory_monitor = False
if use_memory_monitor:
monitor_proc = Process(
target=Scheduler._monitor_memory, args=()
)
monitor_proc.start()
# Monitor processes and collect results
csv_paths = []
completed_processes = 0
while completed_processes < self.num_workers:
for i, p in processes:
if not p.is_alive() and p.exitcode != 0:
if p.exitcode == -11 or p.exitcode == 1:
# Restart the worker if the exit code is -11 (SIGSEGV) or 1
logging.warning(
f"Worker process {p.pid} exited with code {p.exitcode}. Restarting worker {i}."
)
files_for_worker = self.files[i :: self.num_workers]
device = self.devices[i % len(self.devices)]
new_process = Process(
target=self.worker_task,
args=(
files_for_worker,
device,
self.batch_size,
result_queue,
physical_cores_per_worker[i],
),
)
new_process.start()
processes[i] = (
i,
new_process,
) # Replace the old process with the new one
else:
# Raise an error for other exit codes
raise RuntimeError(
f"Worker process {p.pid} failed with exit code {p.exitcode}"
)
# Try to get result from queue with timeout
try:
result = result_queue.get(timeout=10)
csv_paths.append(result)
completed_processes += 1
except Exception as e:
continue
# terminate monitor
if use_memory_monitor:
monitor_proc.terminate()
monitor_proc.join()
# Process results and create final CSV
merged_results = []
for csv_path in csv_paths:
try:
with open(csv_path, mode="r") as f:
reader = csv.DictReader(f)
merged_results.extend(list(reader))
except Exception as e:
logging.error(f"Error processing {csv_path}: {str(e)}")
except Exception as e:
# Log the error and elapsed time
end_time = time.perf_counter()
elapsed_time = end_time - start_time
logging.error(
f"Error occurred after running for {elapsed_time:.2f} seconds: {str(e)}"
)
# Create error log file
error_log = f"scheduler_error_{int(time.time())}.log"
with open(error_log, "w") as f:
f.write(f"Error occurred after {elapsed_time:.2f} seconds\n")
f.write(f"Error message: {str(e)}\n")
f.write(f"Number of workers: {self.num_workers}\n")
f.write(f"Batch size: {self.batch_size}\n")
# Terminate all processes
self._terminate_processes(processes)
raise # Re-raise the exception after cleanup
finally:
end_time = time.perf_counter()
elapsed_time = end_time - start_time
# Write final results if we have any
if "merged_results" in locals() and merged_results:
csv_file = os.path.join(
self.output_path, "results_scheduler.csv"
)
with open(csv_file, mode="w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=[
"file",
"stage1_steps",
"stage1_time",
"stage1_energy",
"stage1_density",
"stage2_steps",
"stage2_time",
"stage2_energy",
"stage2_density",
"total_steps",
"total_time",
],
)
writer.writeheader()
for row in merged_results:
try:
processed_row = {
"file": row["file"],
"stage1_steps": int(row["stage1_steps"]),
"stage1_time": float(row["stage1_time"]),
"stage1_energy": float(row["stage1_energy"]),
"stage1_density": float(row["stage1_density"]),
"stage2_steps": int(row["stage2_steps"]),
"stage2_time": float(row["stage2_time"]),
"stage2_energy": float(row["stage2_energy"]),
"stage2_density": float(row["stage2_density"]),
"total_steps": int(row["total_steps"]),
"total_time": float(row["total_time"]),
}
writer.writerow(processed_row)
except (KeyError, ValueError) as e:
logging.error(
f"Invalid data format in row {row}: {str(e)}"
)
# Write summary
summary_csv_file = os.path.join(
self.output_path, "summary_scheduler.csv"
)
with open(summary_csv_file, mode="w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=["elapsed_time", "num_workers", "batch_size"],
)
writer.writeheader()
writer.writerow(
{
"elapsed_time": elapsed_time,
"num_workers": self.num_workers,
"batch_size": self.batch_size,
}
)
logging.info(f"Scheduler completed in {elapsed_time:.2f} seconds.")
def run_debug(self):
logging.info("Starting Scheduler in debug mode (sequential execution).")
def worker_task(files, device, batch_size):
nproc = self.num_workers // len(self.devices)
worker = Worker(
files, device, batch_size, self.max_steps, self.filter1,
self.filter2, self.optimizer1, self.optimizer2, self.skip_second_stage,
self.scalar_pressure, self.compile_mode, self.profile, self.cueq,
self.molecule_single, self.output_path, self.model, nproc,
)
worker.run()
for i in range(self.num_workers):
files_for_worker = self.files[i :: self.num_workers]
device = self.devices[i % len(self.devices)]
logging.info(
f"Running worker {i} with {len(files_for_worker)} files on device {device}."
)
worker_task(files_for_worker, device, self.batch_size)
logging.info("All workers have completed their tasks in debug mode.")
class Worker:
"""
Worker is a single process that runs a batch of optimization tasks.
"""
def __init__(
self,
files,
device,
batch_size,
max_steps,
filter1,
filter2,
optimizer1,
optimizer2,
skip_second_stage,
scalar_pressure,
compile_mode,
profile,
cueq,
molecule_single,
output_path,
model,
nproc,
):
self.files = files
self.device = device
self.batch_size = batch_size
self.max_steps = max_steps
self.filter1 = filter1
self.filter2 = filter2
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
self.skip_second_stage = skip_second_stage # Store skip_second_stage
self.scalar_pressure = scalar_pressure
self.compile_mode = compile_mode
self.profile = profile
self.cueq = cueq
self.molecule_single = molecule_single
self.output_path = (
output_path
if os.path.isabs(output_path)
else os.path.abspath(output_path)
)
self.model = model
self.nproc = nproc
# Parse profiler options if provided
self.use_profiler = False
self.profiler_schedule_config = {
"wait": 48,
"warmup": 1,
"active": 1,
"repeat": 1,
}
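# torch.profiler.schedule semantics: skip `wait` steps, profile-and-discard
# for `warmup` steps, record `active` steps, then repeat the cycle `repeat`
# times.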
self.profiler_log_dir = None
if self.profile and self.profile != "False":
self.use_profiler = True
# Create directory for profiler output
self.profiler_log_dir = os.path.join(self.output_path, "log")
os.makedirs(self.profiler_log_dir, exist_ok=True)
if self.profile != "True":
try:
# Try to parse profile as a JSON string with schedule config
profile_config = json.loads(self.profile)
if isinstance(profile_config, dict):
for key in ["wait", "warmup", "active", "repeat"]:
if key in profile_config and isinstance(
profile_config[key], int
):
self.profiler_schedule_config[key] = (
profile_config[key]
)
except json.JSONDecodeError:
logging.warning(
f"Could not parse profile config: {self.profile}, using defaults"
)
# For monitor thread
self.stop_event = threading.Event()
def run(self):
logging.info(
f"Worker started on device {self.device} with {len(self.files)} files."
)
a2g = AtomsToGraphs(r_edges=False, r_pbc=True)
# model = torch.load("/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", map_location=self.device)
# z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
calculator = mace_off(model="small", device=self.device)
results = []
for batch_files in self._batch_files(self.files, self.batch_size):
logging.info(f"Processing batch with {len(batch_files)} files.")
start_time = time.perf_counter()
atoms_list = []
for file in batch_files:
atoms = read(file)
atoms_list.append(atoms)
gbatch = data_list_collater(
[a2g.convert(atoms) for atoms in atoms_list]
)
gbatch = gbatch.to(self.device)
if self.filter1 == "UnitCellFilter":
from batchopt.relaxation import OptimizableUnitCellBatch
obatch = OptimizableUnitCellBatch(
gbatch,
trainer=calculator,
numpy=False,
scalar_pressure=self.scalar_pressure,
)
else:
obatch = OptimizableBatch(
gbatch, trainer=calculator, numpy=False
)
# First optimization stage
if self.optimizer1 == "LBFGS":
batch_optimizer1 = LBFGS(
obatch, damping=1.0, alpha=70.0, maxstep=0.2
)
elif self.optimizer1 == "BFGS":
batch_optimizer1 = BFGS(obatch, alpha=70.0, maxstep=0.2)
elif self.optimizer1 == "BFGSLineSearch":
batch_optimizer1 = BFGSLineSearch(obatch, device=self.device)
elif self.optimizer1 == "BFGSFusedLS":
batch_optimizer1 = BFGSFusedLS(obatch, device=self.device)
else:
raise ValueError(f"Unknown optimizer: {self.optimizer1}")
start_time1 = time.perf_counter()
batch_optimizer1.run(0.01, self.max_steps)
end_time1 = time.perf_counter()
elapsed_time1 = end_time1 - start_time1
# Save intermediate results
atoms_list = obatch.get_atoms_list()
for atoms, file_path in zip(atoms_list, batch_files):
file_name = file_path.split("/")[-1]
output_file = os.path.join(
self.output_path,
"cif_result_press",
file_name.replace(".cif", "_press.cif"),
)
atoms.write(output_file)
# Capture maximum force after first optimization stage
max_force1 = obatch.get_max_forces(apply_constraint=True)
steps1 = batch_optimizer1.nsteps
if self.skip_second_stage:
# If skipping second stage, set metrics to zero
for file, force in zip(batch_files, max_force1):
results.append(
{
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": 0.0,
"stage2_steps": 0,
"total_time": elapsed_time1,
"total_steps": steps1,
"force1": force.item(),
"force2": 0.0,
}
)
continue
# Only proceed with second stage if not skipping
# Reload intermediate structures for second stage
atoms_list = []
for file_path in batch_files:
file_name = file_path.split("/")[-1]
press_file = os.path.join(
self.output_path,
"cif_result_press",
file_name.replace(".cif", "_press.cif"),
)
atoms = read(press_file)
atoms_list.append(atoms)
# Rebuild batch from optimized structures
gbatch = data_list_collater(
[a2g.convert(atoms) for atoms in atoms_list]
)
gbatch = gbatch.to(self.device)
# Second optimization stage
if self.filter2 == "UnitCellFilter":
obatch2 = OptimizableUnitCellBatch(
gbatch, trainer=calculator, numpy=False, scalar_pressure=0.0
)
else:
obatch2 = OptimizableBatch(
gbatch, trainer=calculator, numpy=False
)
if self.optimizer2 == "LBFGS":
batch_optimizer2 = LBFGS(
obatch2, damping=1.0, alpha=70.0, maxstep=0.2
)
elif self.optimizer2 == "BFGS":
batch_optimizer2 = BFGS(obatch2, alpha=70.0, maxstep=0.2)
elif self.optimizer2 == "BFGSLineSearch":
batch_optimizer2 = BFGSLineSearch(obatch2, device=self.device)
elif self.optimizer2 == "BFGSFusedLS":
batch_optimizer2 = BFGSFusedLS(obatch2, device=self.device)
else:
raise ValueError(f"Unknown optimizer: {self.optimizer2}")
start_time2 = time.perf_counter()
batch_optimizer2.run(0.01, self.max_steps)
end_time2 = time.perf_counter()
elapsed_time2 = end_time2 - start_time2
# Save final results
atoms_list = obatch2.get_atoms_list()
for atoms, file_path in zip(atoms_list, batch_files):
file_name = file_path.split("/")[-1]
output_file = os.path.join(
self.output_path,
"cif_result_final",
file_name.replace(".cif", "_opt.cif"),
)
atoms.write(output_file)
# Capture maximum force after second optimization stage
max_force2 = obatch2.get_max_forces(apply_constraint=True)
steps2 = batch_optimizer2.nsteps
for file, f1, f2 in zip(batch_files, max_force1, max_force2):
results.append(
{
"file": file,
"stage1_time": elapsed_time1,
"stage1_steps": steps1,
"stage2_time": elapsed_time2,
"stage2_steps": steps2,
"total_time": elapsed_time1 + elapsed_time2,
"total_steps": steps1 + steps2,
"force1": f1.item(),
"force2": f2.item(),
}
)
return results
def _batch_files(self, files, batch_size):
for i in range(0, len(files), batch_size):
yield files[i : i + batch_size]
@staticmethod
def _torch_memory_monitor(interval=2, device=None, stop_event=None):
try:
# explicit CUDA initialization
torch.cuda._lazy_init()
while not stop_event.is_set():
allocated = torch.cuda.memory_allocated(device=device)
reserved = torch.cuda.memory_reserved(device=device)
logging.info(
f"[torch] Allocated Memory: {allocated / 1024**2:.2f} MiB"
)
logging.info(
f"[torch] Reserved Memory: {reserved / 1024**2:.2f} MiB"
)
time.sleep(interval)
except Exception as e:
logging.error(f"Unexpected error when monitor memory: {str(e)}")
def continuous_run(self):
"""
Execute a continuous run of the batching optimization process.
"""
logging.info("Starting continuous_run with two rounds of optimization.")
        # Optional torch memory monitor
use_torch_memory_monitor = False
if use_torch_memory_monitor:
memory_monitor = threading.Thread(
target=Worker._torch_memory_monitor,
args=(2, self.device, self.stop_event),
)
memory_monitor.start()
# First round of optimization
try:
logging.info("Starting first round of optimization.")
results_round1, new_atoms_files = self.continuous_batching(
atoms_path=self.files,
result_path_prefix=os.path.join(
self.output_path, "cif_result_press/"
),
fmax=0.01,
maxstep=self.max_steps,
use_filter=self.filter1,
optimizer=self.optimizer1,
scalar_pressure=self.scalar_pressure,
dtype=torch.float64,
)
logging.info(
f"Completed first round of optimization. Results: {len(results_round1)}"
)
except KeyboardInterrupt as e:
if use_torch_memory_monitor:
self.stop_event.set()
memory_monitor.join()
logging.error(f"Error during first round of optimization: {e}")
raise
except Exception as e:
logging.error(f"Error during first round of optimization: {e}")
raise
if self.skip_second_stage:
logging.info("Skipping second round of optimization.")
return results_round1
# Second round of optimization without pressure
try:
logging.info("Starting second round of optimization.")
results_round2, _ = self.continuous_batching(
atoms_path=new_atoms_files,
result_path_prefix=os.path.join(
self.output_path, "cif_result_final/"
),
fmax=0.01,
maxstep=self.max_steps,
# maxstep=3000,
use_filter=self.filter2,
optimizer=self.optimizer2,
scalar_pressure=0.0,
dtype=torch.float64,
)
logging.info(
f"Completed second round of optimization. Results: {len(results_round2)}"
)
except KeyboardInterrupt as e:
if use_torch_memory_monitor:
self.stop_event.set()
memory_monitor.join()
logging.error(f"Error during second round of optimization: {e}")
raise
except Exception as e:
logging.error(f"Error during second round of optimization: {e}")
raise
if use_torch_memory_monitor:
self.stop_event.set()
memory_monitor.join()
return self._save_results_to_csv(results_round1, results_round2)
def _save_results_to_csv(self, results_round1, results_round2):
"""Helper method to save results to CSV file and return the path."""
combined_results = []
results_map = {}
# Process first round results
for result in results_round1:
file_name = result["file"]
results_map[file_name] = {
"file": file_name,
"stage1_steps": result["steps"],
"stage1_time": result["runtime"],
"stage1_energy": result["energy"],
"stage1_density": result["density"],
"stage2_steps": 0,
"stage2_time": 0.0,
"stage2_energy": 0.0,
"stage2_density": 0,
"total_steps": result["steps"],
"total_time": result["runtime"],
}
# Process second round results
for result in results_round2:
file_name = result["file"]
if file_name in results_map:
results_map[file_name].update(
{
"stage2_steps": result["steps"],
"stage2_time": result["runtime"],
"stage2_energy": result["energy"],
"stage2_density": result["density"],
"total_steps": results_map[file_name]["stage1_steps"]
+ result["steps"],
"total_time": results_map[file_name]["stage1_time"]
+ result["runtime"],
}
)
else:
results_map[file_name] = {
"file": file_name,
"stage1_steps": 0,
"stage1_time": 0.0,
"stage1_energy": 0.0,
"stage1_density": 0,
"stage2_steps": result["steps"],
"stage2_time": result["runtime"],
"stage2_energy": result["energy"],
"stage2_density": result["density"],
"total_steps": result["steps"],
"total_time": result["runtime"],
}
# Convert map to list
combined_results = list(results_map.values())
logging.info(
f"Combined results from both rounds. Total results: {len(combined_results)}"
)
worker_id = os.getpid()
timestamp = int(time.time())
csv_filename = f"worker_{worker_id}_{timestamp}.csv"
csv_path = os.path.join(
self.output_path, "worker_results", csv_filename
)
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(csv_path, mode="w", newline="") as csvfile:
fieldnames = [
"file",
"stage1_steps",
"stage1_time",
"stage1_energy",
"stage1_density",
"stage2_steps",
"stage2_time",
"stage2_energy",
"stage2_density",
"total_steps",
"total_time",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for result in combined_results:
writer.writerow(result)
return csv_path
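        # Illustrative CSV row (hypothetical numbers), in the fieldnames order above:
        # mol_001,120,35.2,-810.4,1.32,40,12.8,-812.1,1.35,160,48.0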
def _get_density(self, crystal):
        # Total mass: ASE's get_masses returns an array with every atom's mass (in amu)
        total_mass = sum(crystal.get_masses())
        # Volume: ASE's get_volume returns the cell volume in Angstrom^3,
        # and 1 Angstrom^3 = 1e-24 cm^3
        volume = crystal.get_volume()
        # Density = mass / volume; dividing by Avogadro's number converts amu to grams
        density = (
            total_mass / (volume * 10**-24) / (6.022140857 * 10**23)
        )  # in g/cm^3
return density
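        # Sanity check (illustrative, hypothetical cell): a total mass of 58.44 amu (one
        # NaCl formula unit) and a volume of 44.8 Angstrom^3 give
        # 58.44 / (44.8e-24 * 6.022e23) ~= 2.17 g/cm^3, close to NaCl's 2.16 g/cm^3.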
@staticmethod
def select_factor(history: deque):
        # TODO: when the history mixes different graph sizes, the smaller of the mapped `values` should be selected.
boundaries = [0, 50, 100, 200, 400, 800]
values = [0.4, 0.8, 0.9, 0.6, 0.5, 0.4]
factor_result = []
for graph_size in history:
for i in range(len(boundaries) - 1):
if boundaries[i] <= graph_size < boundaries[i + 1]:
factor_result.append(values[i])
break
if len(factor_result) == 0:
return 0.4
else:
return min(factor_result)
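        # Illustrative: select_factor(deque([30, 120, 500])) maps the sizes through the
        # buckets [0, 50) -> 0.4, [100, 200) -> 0.9, [400, 800) -> 0.5 and returns the
        # minimum, 0.4; an empty or out-of-range history falls back to 0.4.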
def continuous_batching(
self,
atoms_path,
result_path_prefix,
fmax,
maxstep,
use_filter,
optimizer,
scalar_pressure,
dtype=torch.float64,
):
"""
Performs continuous batched optimization of atomic structures.
This method implements a continuous batching strategy for optimizing multiple atomic structures,
where converged structures are replaced with new ones to maintain batch efficiency.
Parameters
----------
atoms_path : list
List of file paths to atomic structure files to be optimized
result_path_prefix : str
Prefix for output file paths where optimized structures will be saved
        fmax : float
            Maximum force criterion for convergence, e.g. 0.01
        maxstep : int
            Maximum number of optimization steps per structure, e.g. 3000
        use_filter : str
            Filter to be used for optimization, e.g. "UnitCellFilter"
        optimizer : str
            Optimizer to be used: "LBFGS", "BFGS", "BFGSLineSearch", or "BFGSFusedLS"
        scalar_pressure : float
            Scalar pressure to be applied, e.g. 0.0
        dtype : torch.dtype, optional
            Numeric precision for the calculator, by default torch.float64
        Returns
        -------
        tuple
            (result, optimized_atoms_paths): per-structure metrics and the paths of
            the optimized structures saved to disk
        Notes
        -----
        The method:
        - Processes structures in batches of fixed size or under a per-batch atom budget
        - Uses a machine-learned interatomic potential (MACE, CHGNet, or SevenNet) for energy/force calculations
        - Supports unit-cell relaxation via OptimizableUnitCellBatch
- Dynamically replaces converged structures with new ones in the batch
- Tracks convergence and optimization steps for each structure
"""
# Load saved structures
result = []
optimized_atoms_paths = []
json_dir = result_path_prefix.replace("cif", "json")
remove_list = []
        # TODO: Why do we read all CIFs here?
for pre_cif in atoms_path:
cif_path = os.path.join(result_path_prefix, pre_cif.split("/")[-1])
json_path = os.path.join(
json_dir, pre_cif.split("/")[-1].replace(".cif", ".json")
)
if (
os.path.exists(cif_path)
and os.path.exists(json_path)
and os.path.getsize(cif_path) > 0
and os.path.getsize(json_path) > 0
):
with open(json_path, "r") as f:
result_data = json.load(f)
result.append(result_data)
optimized_atoms_paths.append(cif_path)
remove_list.append(pre_cif)
logging.info(f"File {cif_path} already exists, loaded.")
# else:
# try:
# read(pre_cif)
# except Exception as e:
# logging.info(f"Failed to read {pre_cif}: {e}")
# remove_list.append(pre_cif)
for i in remove_list:
atoms_path.remove(i)
if self.batch_size > 0:
# Initialize variables
room_in_batch = self.batch_size
indices_to_process = 0
cur_batch_path = atoms_path[
indices_to_process : indices_to_process + room_in_batch
]
if len(cur_batch_path) == 0:
logging.info("No structures to process.")
return result, optimized_atoms_paths
room_in_batch -= len(cur_batch_path)
indices_to_process += len(cur_batch_path)
cur_atoms_list = [read(path) for path in cur_batch_path]
a2g = AtomsToGraphs(r_edges=False, r_pbc=True)
gbatch = data_list_collater(
[a2g.convert(read(path)) for path in cur_batch_path]
)
else:
            # Set the maximum number of atoms per batch
history = deque(maxlen=10)
history.append(1000)
max_bnatoms = 24080
safe_factor = self.select_factor(history)
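            # e.g., with max_bnatoms = 24080, safe_factor = 0.4, and nproc = 4
            # (illustrative values), a batch holds at most 24080 * 0.4 // 4 = 2408 atoms.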
indices_to_process = 0
bnatoms = 0
cur_batch_path = []
graphs_list = []
a2g = AtomsToGraphs(r_edges=False, r_pbc=True)
while indices_to_process < len(atoms_path):
graph_natoms = count_atoms_cif(atoms_path[indices_to_process])
if (
bnatoms + graph_natoms
> max_bnatoms * safe_factor // self.nproc
):
break
graph = a2g.convert(read(atoms_path[indices_to_process]))
bnatoms += graph_natoms
cur_batch_path.append(atoms_path[indices_to_process])
graphs_list.append(graph)
indices_to_process += 1
history.append(graph_natoms)
safe_factor = self.select_factor(history)
if len(graphs_list) == 0:
logging.info("No structures to process.")
return result, optimized_atoms_paths
gbatch = data_list_collater(graphs_list)
logging.info(f"current batch size: {len(cur_batch_path)}")
total_natoms = sum([graph.natoms for graph in graphs_list])
logging.info(f"total_natoms: {total_natoms}")
gbatch = gbatch.to(self.device)
batch_optimizer = None
        # Initialize the calculator
if self.model == "mace":
if dtype == torch.float32:
calculator = mace_off(
model="small",
device=self.device,
enable_cueq=self.cueq,
default_dtype="float32",
)
else:
calculator = mace_off(
model="small", device=self.device, enable_cueq=self.cueq
)
elif self.model == "chgnet":
calculator = CHGNetCalculator(
use_device=self.device, enable_cueq=self.cueq
)
elif self.model == "sevennet":
# calculator = SevenNetCalculator(device=self.device, enable_cueq=self.cueq)
calculator = SevenNetD3Calculator(
device=self.device,
enable_cueq=self.cueq,
batch_size=self.batch_size,
)
# calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device)
# calculator = MACECalculator(model_paths="/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", device=self.device, compile_mode=self.compile_mode)
if use_filter == "UnitCellFilter":
obatch = OptimizableUnitCellBatch(
gbatch,
trainer=calculator,
numpy=False,
scalar_pressure=scalar_pressure,
)
else:
obatch = OptimizableBatch(gbatch, trainer=calculator, numpy=False)
orig_cells = obatch.orig_cells.clone()
converged_atoms_count = 0
converge_indices = []
all_indices = []
cur_batch_steps = [0] * len(cur_batch_path)
cur_batch_times = [time.perf_counter()] * len(
cur_batch_path
) # Track start times
while converged_atoms_count < len(atoms_path):
# Update batch
if len(all_indices) > 0:
if self.batch_size > 0:
room_in_batch += len(all_indices)
new_batch_path = atoms_path[
indices_to_process : indices_to_process + room_in_batch
]
logging.info(f"new_batch_path: {new_batch_path}")
room_in_batch -= len(new_batch_path)
indices_to_process += len(new_batch_path)
optimized_atoms_new = []
cur_batch_path_new = []
cur_batch_steps_new = []
cur_batch_times_new = []
orig_cells_new = torch.zeros(
[self.batch_size - room_in_batch, 3, 3],
device=self.device,
)
cell_offset = 0
restart_indices = []
old_batch_indices = obatch.batch_indices
for i in range(len(optimized_atoms)):
if i in all_indices:
continue
else:
restart_indices.append(i)
optimized_atoms_new.append(optimized_atoms[i])
cur_batch_path_new.append(cur_batch_path[i])
cur_batch_steps_new.append(cur_batch_steps[i])
cur_batch_times_new.append(cur_batch_times[i])
orig_cells_new[cell_offset] = orig_cells[i]
cell_offset += 1
for new_path in new_batch_path:
optimized_atoms_new.append(read(new_path))
cur_batch_path_new.append(new_path)
cur_batch_steps_new.append(0)
cur_batch_times_new.append(time.perf_counter())
# Update the batch with new structures
optimized_atoms = optimized_atoms_new
cur_batch_path = cur_batch_path_new
cur_batch_steps = cur_batch_steps_new
cur_batch_times = cur_batch_times_new
else:
bnatoms = 0
optimized_atoms_new = []
cur_batch_path_new = []
cur_batch_steps_new = []
cur_batch_times_new = []
restart_indices = []
old_batch_indices = obatch.batch_indices
for i in range(len(optimized_atoms)):
if i in all_indices:
continue
restart_indices.append(i)
optimized_atoms_new.append(optimized_atoms[i])
cur_batch_path_new.append(cur_batch_path[i])
cur_batch_steps_new.append(cur_batch_steps[i])
cur_batch_times_new.append(cur_batch_times[i])
bnatoms += a2g.convert(read(cur_batch_path[i])).natoms
while indices_to_process < len(atoms_path):
new_path = atoms_path[indices_to_process]
graph_natoms = count_atoms_cif(new_path)
if (
bnatoms + graph_natoms
> max_bnatoms * safe_factor // self.nproc
):
break
bnatoms += graph_natoms
optimized_atoms_new.append(read(new_path))
cur_batch_path_new.append(new_path)
cur_batch_steps_new.append(0)
cur_batch_times_new.append(time.perf_counter())
indices_to_process += 1
history.append(graph_natoms)
safe_factor = self.select_factor(history)
orig_cells_new = torch.zeros(
[len(optimized_atoms_new), 3, 3], device=self.device
)
cell_offset = 0
for i in range(len(optimized_atoms)):
if i in all_indices:
continue
orig_cells_new[cell_offset] = orig_cells[i]
cell_offset += 1
# Update the batch with new structures
optimized_atoms = optimized_atoms_new
cur_batch_path = cur_batch_path_new
cur_batch_steps = cur_batch_steps_new
cur_batch_times = cur_batch_times_new
logging.info(f"current batch size: {len(optimized_atoms)}")
graphs_list = [a2g.convert(atoms) for atoms in optimized_atoms]
total_natoms = sum([graph.natoms for graph in graphs_list])
logging.info(f"total_natoms: {total_natoms}")
logging.info(f"cur_batch_path to processing: {cur_batch_path}")
gbatch = data_list_collater(graphs_list)
gbatch = gbatch.to(self.device)
if self.model == "sevennet":
# calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device)
calculator = SevenNetD3Calculator(
device=self.device,
enable_cueq=self.cueq,
batch_size=self.batch_size,
)
if use_filter == "UnitCellFilter":
obatch = OptimizableUnitCellBatch(
gbatch,
trainer=calculator,
numpy=False,
scalar_pressure=scalar_pressure,
)
else:
obatch = OptimizableBatch(
gbatch, trainer=calculator, numpy=False
)
for i in range(cell_offset):
obatch.orig_cells[i] = orig_cells_new[i]
orig_cells = obatch.orig_cells.clone()
# Optimize the current batch
if optimizer == "LBFGS":
batch_optimizer = LBFGS(
obatch,
damping=1.0,
alpha=70.0,
maxstep=0.2,
early_stop=True,
)
elif optimizer == "BFGS":
if len(all_indices) > 0:
logging.info(f"Restarting with indices: {restart_indices}")
batch_optimizer.optimizable = obatch
else:
batch_optimizer = BFGS(
obatch, alpha=70.0, maxstep=0.2, early_stop=True
)
elif optimizer == "BFGSLineSearch":
batch_optimizer = BFGSLineSearch(
obatch,
device=self.device,
early_stop=True,
use_profiler=self.use_profiler,
profiler_log_dir=self.profiler_log_dir,
profiler_schedule_config=self.profiler_schedule_config,
)
elif optimizer == "BFGSFusedLS":
if len(all_indices) > 0:
logging.info(f"Restarting with indices: {restart_indices}")
batch_optimizer.optimizable = obatch
else:
batch_optimizer = BFGSFusedLS(
obatch,
device=self.device,
early_stop=True,
use_profiler=self.use_profiler,
profiler_log_dir=self.profiler_log_dir,
profiler_schedule_config=self.profiler_schedule_config,
)
else:
raise ValueError(f"Unknown optimizer: {optimizer}")
            # Dynamically compute the remaining step budget, based on the maximum
            # number of steps already taken by any structure in the current batch
            current_max_steps = max(cur_batch_steps) if cur_batch_steps else 0
            remaining_steps = max(
                maxstep - current_max_steps, 1
            )  # guarantee at least one step
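            # e.g., with maxstep = 3000 and a structure already at 2950 steps, the next
            # run gets at most 50 steps; an exhausted budget is clamped to a single step.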
            # Run the optimization and collect the indices of converged structures
if (optimizer == "BFGSFusedLS" or optimizer == "BFGS") and len(
all_indices
) > 0:
converge_indices = batch_optimizer.run(
fmax,
remaining_steps,
is_restart_earlystop=True,
restart_indices=restart_indices,
old_batch_indices=old_batch_indices,
)
else:
converge_indices = batch_optimizer.run(fmax, remaining_steps)
# Print energies of all structures
# logging.info(f"Final energies of all structures: {batch_optimizer.energies}")
energies_list = (
batch_optimizer.optimizable.get_potential_energies().tolist()
)
logging.info(f"Final energies of all structures: {energies_list}")
            # Update the cumulative step count of every structure in the batch
cur_batch_steps = [
steps + batch_optimizer.nsteps for steps in cur_batch_steps
]
            # Find the indices of structures that have hit the maximum step count
over_maxstep_indices = [
i
for i, steps in enumerate(cur_batch_steps)
if steps >= maxstep - 1
]
            # Merge converged and over-limit indices (deduplicated)
all_indices = list(set(converge_indices + over_maxstep_indices))
# Get optimized atoms
optimized_atoms = obatch.get_atoms_list()
converged_atoms_count += len(all_indices)
end_time = time.perf_counter()
            # Handle every structure leaving the batch (converged or over the step limit)
for idx in all_indices:
runtime = end_time - cur_batch_times[idx]
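                # Convert the total energy (eV) to kJ/mol per molecule
                # (1 eV = 96.485 kJ/mol), assuming self.molecule_single is the
                # number of atoms in one molecule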
energy_per_mol = (
energies_list[idx]
/ (
len(optimized_atoms[idx].get_atomic_numbers())
/ self.molecule_single
)
* 96.485
)
density = self._get_density(optimized_atoms[idx])
# Save results
result_data = {
"file": cur_batch_path[idx].split("/")[-1].split(".")[0],
"steps": cur_batch_steps[idx],
"runtime": runtime,
"energy": energy_per_mol,
"density": density,
}
result.append(result_data)
# Save optimized structure
# converged_atoms_path = os.path.join(result_path_prefix, cur_batch_path[idx].split('/')[-1].replace('.cif', '.traj'))
converged_atoms_path = os.path.join(
result_path_prefix, cur_batch_path[idx].split("/")[-1]
)
optimized_atoms[idx].write(converged_atoms_path)
optimized_atoms_paths.append(converged_atoms_path)
                # Write a JSON file storing result_data
os.makedirs(json_dir, exist_ok=True)
# json_path = os.path.join(json_dir, cur_batch_path[idx].split('/')[-1]+'.json')
json_path = os.path.join(
json_dir,
cur_batch_path[idx].split("/")[-1].replace(".cif", ".json"),
)
with open(json_path, "w") as f:
json.dump(result_data, f)
logging.info(f"cur_batch_path: {cur_batch_path}")
logging.info(f"cur_batch_steps: {cur_batch_steps}")
logging.info(f"all_indices: {all_indices}")
logging.info(f"length of optimized_atoms: {len(optimized_atoms)}")
return result, optimized_atoms_paths
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations

import ast
import collections
import copy
import datetime
import errno
import functools
import importlib
import itertools
import json
import logging
import os
import subprocess
import sys
import time
from bisect import bisect
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any
from uuid import uuid4
import numpy as np
import torch
import torch.nn as nn
import torch_geometric
import yaml
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch_geometric.data import Data
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter, segment_coo, segment_csr
from torch_geometric.data.data import BaseData
from torch_geometric.data import Batch
# Count the atoms in a CIF file by scanning its atom-site loop
# (used, e.g., to sort files by atom count in descending order)
def count_atoms_cif(file):
in_atom_site = False
natoms = 0
with open(file, 'r') as f:
while line := f.readline():
if line.lower().startswith("loop_"):
in_atom_site = False
continue
# if line.lower().startswith("_atom_site_"):
if "_atom_site_" in line.lower():
in_atom_site = True
continue
if in_atom_site:
if line.startswith("_"):
in_atom_site = False
continue
                elif line.strip():  # skip blank lines so they are not counted as atom rows
                    natoms += 1
return natoms
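# Minimal sketch (illustrative; `_example_count_atoms_cif` is a hypothetical helper,
# not part of the pipeline): a CIF whose atom-site loop lists three rows counts as
# three atoms.
def _example_count_atoms_cif(path="example.cif"):
    cif = (
        "loop_\n"
        "_atom_site_label\n"
        "_atom_site_fract_x\n"
        "C1 0.0\n"
        "C2 0.5\n"
        "O1 0.25\n"
    )
    with open(path, "w") as f:
        f.write(cif)
    # The two header lines toggle in_atom_site; the three data rows are counted
    assert count_atoms_cif(path) == 3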
# Override the collation method in `torch_geometric.data.InMemoryDataset`
def collate(data_list):
keys = data_list[0].keys
data = data_list[0].__class__()
for key in keys:
data[key] = []
slices = {key: [0] for key in keys}
for item, key in product(data_list, keys):
data[key].append(item[key])
if torch.is_tensor(item[key]):
s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key]))
elif isinstance(item[key], (int, float)):
s = slices[key][-1] + 1
else:
raise ValueError("Unsupported attribute type")
slices[key].append(s)
if hasattr(data_list[0], "__num_nodes__"):
data.__num_nodes__ = []
for item in data_list:
data.__num_nodes__.append(item.num_nodes)
for key in keys:
if torch.is_tensor(data_list[0][key]):
data[key] = torch.cat(
data[key], dim=data.__cat_dim__(key, data_list[0][key])
)
else:
data[key] = torch.tensor(data[key])
slices[key] = torch.tensor(slices[key], dtype=torch.long)
return data, slices
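# Minimal sketch (illustrative; `_example_collate` is a hypothetical helper), assuming
# a torch_geometric version where `Data.keys` is a property: per-key tensors are
# concatenated along their cat dimension and `slices` records the item boundaries.
def _example_collate():
    d1 = Data(x=torch.zeros(2, 3))
    d2 = Data(x=torch.ones(3, 3))
    data, slices = collate([d1, d2])
    # data.x has shape (5, 3); slices["x"].tolist() == [0, 2, 5]
    return data, slices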
def data_list_collater(
data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False
) -> BaseData | dict[str, torch.Tensor]:
batch = Batch.from_data_list(data_list)
if not otf_graph:
try:
n_neighbors = []
for _, data in enumerate(data_list):
n_index = data.edge_index[1, :]
n_neighbors.append(n_index.shape[0])
batch.neighbors = torch.tensor(n_neighbors)
except (NotImplementedError, TypeError):
logging.warning(
"LMDB does not contain edge index information, set otf_graph=True"
)
if to_dict:
batch = dict(batch.items())
return batch
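# Minimal sketch (illustrative; `_example_data_list_collater` is a hypothetical
# helper): two toy graphs are merged into one Batch, and `batch.neighbors` stores
# the per-graph edge counts read from each edge_index.
def _example_data_list_collater():
    edge_index = torch.tensor([[0, 1], [1, 0]])
    d1 = Data(x=torch.randn(2, 4), edge_index=edge_index)
    d2 = Data(x=torch.randn(3, 4), edge_index=edge_index)
    batch = data_list_collater([d1, d2])
    # batch.x has shape (5, 4); batch.neighbors.tolist() == [2, 2]
    return batch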
#!/bin/bash
export CUDA_HOME=/usr/local/cuda
export PATH=$CUDA_HOME/bin:$PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH
export CPATH=$CUDA_HOME/include:$CPATH
#!/bin/bash
echo quit | nvidia-cuda-mps-control
# Return GPUs 0-7 to the default compute mode
for i in 0 1 2 3 4 5 6 7; do
    nvidia-smi -i "$i" -c DEFAULT
done
#!/bin/bash
# Set GPUs 0-7 to exclusive-process compute mode so MPS owns them
for i in 0 1 2 3 4 5 6 7; do
    nvidia-smi -i "$i" -c EXCLUSIVE_PROCESS
done
nvidia-cuda-mps-control -d # Start the daemon.