Commit cc6e6b7d authored by Wang, Leping's avatar Wang, Leping
Browse files

- Add config.sh with all pipeline parameters organized by category

  (molecular, crystal structure, compute, run mode, path)
- Refactor search_gen_proc.sh to source config.sh instead of
  hardcoding parameters, with optional config path argument
- Refactor structure_generate.py to load config.sh via exec(),
  replacing hardcoded values with config-driven parameters
- Remove mace-bench (the relaxation part; it will be replaced by an updated, separate mace-bench project)
parent 61ec3ad9
/* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://lammps.sandia.gov/, Sandia National Laboratories
Steve Plimpton, sjplimp@sandia.gov
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
/* ----------------------------------------------------------------------
Contributing author: Yutack Park (SNU)
------------------------------------------------------------------------- */
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue_inl.h>
#include <ATen/ops/from_blob.h>
#include <c10/core/Scalar.h>
#include <c10/core/TensorOptions.h>
#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>
#include <torch/csrc/jit/api/module.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <cuda_runtime.h>
#include "atom.h"
#include "comm.h"
#include "comm_brick.h"
#include "error.h"
#include "force.h"
#include "memory.h"
#include "neigh_list.h"
#include "neighbor.h"
// #include "nvToolsExt.h"
#include "pair_e3gnn_parallel.h"
#include <cassert>
#ifdef OMPI_MPI_H
#include "mpi-ext.h" //This should be included after mpi.h which is included in pair.h
#endif
using namespace LAMMPS_NS;
#define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64)
#define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat)
// Meyers singleton: the function-local static is initialized exactly once
// (thread-safe since C++11) and destroyed at program exit.
DeviceBuffManager &DeviceBuffManager::getInstance() {
  static DeviceBuffManager instance;
  return instance;
}
// Hand out persistent device-side staging buffers for CUDA-aware MPI.
// Buffers grow monotonically: they are reallocated only when the requested
// size exceeds the current capacity, and reused otherwise.
//
// send_size / recv_size: required element counts (floats).
// buf_send_ptr / buf_recv_ptr: out-parameters receiving the device pointers.
void DeviceBuffManager::get_buffer(int send_size, int recv_size,
                                   float *&buf_send_ptr, float *&buf_recv_ptr) {
  if (send_size > send_buf_size) {
    cudaFree(buf_send_device); // cudaFree(nullptr) is a no-op on first use
    cudaError_t cuda_err =
        cudaMalloc(&buf_send_device, send_size * sizeof(float));
    if (cuda_err != cudaSuccess) {
      // BUGFIX: the allocation result was previously ignored; on failure the
      // stale capacity would have masked a null buffer on later calls.
      std::cerr << "E3GNN: cudaMalloc failed for send buffer: "
                << cudaGetErrorString(cuda_err) << std::endl;
      buf_send_device = nullptr;
      send_buf_size = 0;
    } else {
      send_buf_size = send_size;
    }
  }
  if (recv_size > recv_buf_size) {
    cudaFree(buf_recv_device);
    cudaError_t cuda_err =
        cudaMalloc(&buf_recv_device, recv_size * sizeof(float));
    if (cuda_err != cudaSuccess) {
      std::cerr << "E3GNN: cudaMalloc failed for recv buffer: "
                << cudaGetErrorString(cuda_err) << std::endl;
      buf_recv_device = nullptr;
      recv_buf_size = 0;
    } else {
      recv_buf_size = recv_size;
    }
  }
  buf_send_ptr = buf_send_device;
  buf_recv_ptr = buf_recv_device;
}
// Release the device staging buffers; cudaFree(nullptr) is harmless if the
// buffers were never allocated.
DeviceBuffManager::~DeviceBuffManager() {
  cudaFree(buf_send_device);
  cudaFree(buf_recv_device);
}
// Constructor: select the compute device (CUDA if available, CPU otherwise),
// detect CUDA-aware OpenMPI support, and decide which device the
// communication tensors live on. Debug printing is controlled by the
// SEVENN_PRINT_INFO (rank 0 only) and SEVENN_PRINT_BOTH_INFO (all ranks)
// environment variables; OFF_E3GNN_PARALLEL_CUDA_MPI forces host-staged
// communication.
PairE3GNNParallel::PairE3GNNParallel(LAMMPS *lmp) : Pair(lmp) {
  const char *print_flag = std::getenv("SEVENN_PRINT_INFO");
  const char *print_both_flag = std::getenv("SEVENN_PRINT_BOTH_INFO");
  if (print_flag) {
    world_rank = comm->me;
    std::cout << "process rank: " << world_rank << " initialized" << std::endl;
    print_info = (world_rank == 0) || print_both_flag;
  }

  std::string device_name;
  const bool use_gpu = torch::cuda::is_available();
  comm_forward = 0;
  comm_reverse = 0;

  // OpenMPI detection: CUDA-aware support is only queryable when building
  // against OpenMPI with the MPIX extension available.
#ifdef OMPI_MPI_H
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  use_cuda_mpi = (1 == MPIX_Query_cuda_support());
#else
  use_cuda_mpi = false;
#endif
#else
  use_cuda_mpi = false;
#endif

  if (use_gpu) {
    device = get_cuda_device();
    device_name = "CUDA";
  } else {
    device = torch::kCPU;
    device_name = "CPU";
  }

  // manual override: force staging through host memory
  if (std::getenv("OFF_E3GNN_PARALLEL_CUDA_MPI")) {
    use_cuda_mpi = false;
  }

  // BUGFIX: device_comm was previously assigned only inside the
  // screen/logfile logging branches, so it was never set when LAMMPS ran
  // with neither stream. Decide it once, independent of logging.
  const bool host_staged = use_gpu && !use_cuda_mpi;
  device_comm = host_staged ? torch::kCPU : device;

  if (lmp->screen) {
    if (host_staged) {
      fprintf(lmp->screen,
              "cuda-aware mpi not found, communicate via host device\n");
    }
    fprintf(lmp->screen, "PairE3GNNParallel using device : %s\n",
            device_name.c_str());
    fprintf(lmp->screen, "PairE3GNNParallel cuda-aware mpi: %s\n",
            use_cuda_mpi ? "True" : "False");
  }
  if (lmp->logfile) {
    if (host_staged) {
      fprintf(lmp->logfile,
              "cuda-aware mpi not found, communicate via host device\n");
    }
    fprintf(lmp->logfile, "PairE3GNNParallel using device : %s\n",
            device_name.c_str());
    fprintf(lmp->logfile, "PairE3GNNParallel cuda-aware mpi: %s\n",
            use_cuda_mpi ? "True" : "False");
  }
}
// Pick a CUDA device for this MPI rank (round-robin: rank % #GPUs) and make
// it the current device for the CUDA runtime. Returns the corresponding
// torch::Device.
torch::Device PairE3GNNParallel::get_cuda_device() {
  int rank = comm->me;
  int num_gpus = torch::cuda::device_count();
  // BUGFIX: guard against modulo-by-zero. This should not happen because the
  // caller only enters here when torch::cuda::is_available() is true, but a
  // zero count would otherwise crash with an integer division fault.
  if (num_gpus <= 0) {
    std::cerr << "E3GNN: no CUDA device visible, falling back to device 0"
              << std::endl;
    num_gpus = 1;
  }
  int idx = rank % num_gpus;
  if (print_info)
    std::cout << world_rank << " Available # of GPUs found: " << num_gpus
              << std::endl;
  cudaError_t cuda_err = cudaSetDevice(idx);
  if (cuda_err != cudaSuccess) {
    std::cerr << "E3GNN: Failed to set CUDA device: "
              << cudaGetErrorString(cuda_err) << std::endl;
  }
  return torch::Device(torch::kCUDA, idx);
}
// Destructor: release the per-type arrays created in allocate(). The guard
// mirrors the LAMMPS convention since allocate() only runs from coeff().
PairE3GNNParallel::~PairE3GNNParallel() {
  if (allocated) {
    memory->destroy(setflag);
    memory->destroy(cutsq);
    memory->destroy(map);
  }
}
// Per-atom feature length of the tensor currently being communicated.
int PairE3GNNParallel::get_x_dim() { return x_dim; }
// Whether communication goes directly between device buffers (CUDA-aware MPI).
bool PairE3GNNParallel::use_cuda_mpi_() { return use_cuda_mpi; }
// Whether the pack/unpack index maps for this MD step have been built.
bool PairE3GNNParallel::is_comm_preprocess_done() {
  return comm_preprocess_done;
}
// Main per-timestep entry point. Builds the local atom graph on the CPU from
// the LAMMPS neighbor list, runs the chain of partitioned GNN model segments
// (exchanging ghost-atom features between segments via LAMMPS communication),
// then backpropagates through the segments in reverse to obtain per-edge
// force contributions and the virial.
void PairE3GNNParallel::compute(int eflag, int vflag) {
  /*
     Graph build on cpu
  */
  if (eflag || vflag)
    ev_setup(eflag, vflag);
  else
    evflag = vflag_fdotr = 0;
  if (vflag_atom) {
    error->all(FLERR, "atomic stress is not supported\n");
  }
  if (atom->tag_consecutive() == 0) {
    error->all(FLERR, "Pair e3gnn requires consecutive atom IDs");
  }
  double **x = atom->x;
  double **f = atom->f;
  int *type = atom->type;
  int nlocal = list->inum; // same as nlocal
  int nghost = atom->nghost;
  int ntotal = nlocal + nghost;
  int *ilist = list->ilist;
  int inum = list->inum;
  // the modified CommBrick is required for the custom tensor communication
  CommBrick *comm_brick = dynamic_cast<CommBrick *>(comm);
  if (comm_brick == nullptr) {
    error->all(FLERR, "e3gnn/parallel: comm style should be brick & from "
                      "modified code of comm_brick");
  }
  bigint natoms = atom->natoms;
  // tag ignore PBC
  tagint *tag = atom->tag;
  // store graph_idx from local to known ghost atoms(ghost atoms inside cutoff)
  // NOTE(review): these are stack VLAs sized by the GLOBAL atom count /
  // local+ghost counts; for large systems this risks stack overflow —
  // consider heap allocation. TODO confirm intended size limits.
  int tag_to_graph_idx[natoms + 1]; // tag starts from 1 not 0
  std::fill_n(tag_to_graph_idx, natoms + 1, -1);
  // to access tag_to_graph_idx from comm
  tag_to_graph_idx_ptr = tag_to_graph_idx;
  int graph_indexer = nlocal;
  int graph_index_to_i[ntotal];
  int *numneigh = list->numneigh;      // j loop cond
  int **firstneigh = list->firstneigh; // j list
  // upper bound on edges = total neighbor-list length of local atoms
  const int nedges_upper_bound =
      std::accumulate(numneigh, numneigh + nlocal, 0);
  std::vector<long> node_type;
  std::vector<long> node_type_ghost;
  float edge_vec[nedges_upper_bound][3];
  long edge_idx_src[nedges_upper_bound];
  long edge_idx_dst[nedges_upper_bound];
  int nedges = 0;
  for (int ii = 0; ii < inum; ii++) {
    // populate tag_to_graph_idx of local atoms; local atoms occupy graph
    // indices [0, nlocal)
    const int i = ilist[ii];
    const int itag = tag[i];
    const int itype = type[i];
    tag_to_graph_idx[itag] = ii;
    graph_index_to_i[ii] = i;
    node_type.push_back(map[itype]);
  }
  // loop over neighbors, build graph (edges within the model cutoff only)
  for (int ii = 0; ii < inum; ii++) {
    const int i = ilist[ii];
    const int i_graph_idx = ii;
    const int *jlist = firstneigh[i];
    const int jnum = numneigh[i];
    for (int jj = 0; jj < jnum; jj++) {
      int j = jlist[jj];
      const int jtag = tag[j];
      j &= NEIGHMASK;
      const int jtype = type[j];
      // we have to calculate Rij to check cutoff in lammps side
      const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1],
                               x[j][2] - x[i][2]};
      // squared distance (compared against cutoff_square, no sqrt needed)
      const double Rij =
          delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2];
      int j_graph_idx;
      if (Rij < cutoff_square) {
        // if given j is not local atom and inside cutoff
        if (tag_to_graph_idx[jtag] == -1) {
          // if j is ghost atom inside cutoff but first seen: assign it the
          // next graph index after all local atoms
          tag_to_graph_idx[jtag] = graph_indexer;
          graph_index_to_i[graph_indexer] = j;
          node_type_ghost.push_back(map[jtype]);
          graph_indexer++;
        }
        j_graph_idx = tag_to_graph_idx[jtag];
        edge_idx_src[nedges] = i_graph_idx;
        edge_idx_dst[nedges] = j_graph_idx;
        edge_vec[nedges][0] = delij[0];
        edge_vec[nedges][1] = delij[1];
        edge_vec[nedges][2] = delij[2];
        nedges++;
      }
    } // j loop end
  }   // i loop end
  // member variable: total graph nodes = local + in-cutoff ghosts
  graph_size = graph_indexer;
  const int ghost_node_num = graph_size - nlocal;
  // convert data to Tensor (from_blob views the stack/vector storage above,
  // which stays alive for the remainder of this function)
  auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE);
  auto inp_node_type_ghost =
      torch::from_blob(node_type_ghost.data(), ghost_node_num, INTEGER_TYPE);
  long num_nodes[1] = {long(nlocal)};
  auto inp_num_atoms = torch::from_blob(num_nodes, {1}, INTEGER_TYPE);
  auto edge_idx_src_tensor =
      torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE);
  auto edge_idx_dst_tensor =
      torch::from_blob(edge_idx_dst, {nedges}, INTEGER_TYPE);
  auto inp_edge_index =
      torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor});
  auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE);
  if (print_info) {
    std::cout << world_rank << " Nlocal: " << nlocal << std::endl;
    std::cout << world_rank << " Graph_size: " << graph_size << std::endl;
    std::cout << world_rank << " Ghost_node_num: " << ghost_node_num
              << std::endl;
    std::cout << world_rank << " Nedges: " << nedges << "\n" << std::endl;
  }
  // r_original requires grad True (forces come from d(energy)/d(edge_vec))
  inp_edge_vec.set_requires_grad(true);
  torch::Dict<std::string, torch::Tensor> input_dict;
  input_dict.insert("x", inp_node_type.to(device));
  input_dict.insert("x_ghost", inp_node_type_ghost.to(device));
  input_dict.insert("edge_index", inp_edge_index.to(device));
  input_dict.insert("edge_vec", inp_edge_vec.to(device));
  input_dict.insert("num_atoms", inp_num_atoms.to(device));
  input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU));
  // wrt_tensors[k] holds the tensors we will differentiate against for model
  // segment k (in forward order); consumed in reverse below
  std::list<std::vector<torch::Tensor>> wrt_tensors;
  wrt_tensors.push_back({input_dict.at("edge_vec")});
  // first model segment runs on the raw graph
  auto model_part = model_list.front();
  auto output = model_part.forward({input_dict}).toGenericDict();
  comm_preprocess();
  // extra_graph_idx_map is set from comm_preprocess();
  // last one is for trash values. See pack_forward_init
  const int extra_size =
      ghost_node_num + static_cast<int>(extra_graph_idx_map.size()) + 1;
  torch::Tensor x_local;
  torch::Tensor x_ghost;
  // remaining model segments: communicate ghost features before each one
  for (auto it = model_list.begin(); it != model_list.end(); ++it) {
    if (it == model_list.begin())
      continue;
    model_part = *it;
    x_local = output.at("x").toTensor().detach().to(device);
    x_dim = x_local.size(1); // length of per atom vector(node feature)
    auto ghost_and_extra_x = torch::zeros({ghost_node_num + extra_size, x_dim},
                                          FLOAT_TYPE.device(device));
    // x_comm layout: [local | ghost | extra/trash] rows
    x_comm = torch::cat({x_local, ghost_and_extra_x}, 0).to(device_comm);
    comm_brick->forward_comm(this); // populate x_ghost by communication
    // What we got from forward_comm (node feature of ghosts)
    x_ghost = torch::split_with_sizes(
        x_comm, {nlocal, ghost_node_num, extra_size}, 0)[1];
    x_ghost.set_requires_grad(true);
    // prepare next input (output > next input)
    output.insert_or_assign("x_ghost", x_ghost.to(device));
    // make another edge_vec to discriminate grad calculation with other
    // edge_vecs(maybe redundant?)
    output.insert_or_assign("edge_vec",
                            output.at("edge_vec").toTensor().clone());
    // save tensors for backprop
    wrt_tensors.push_back({output.at("edge_vec").toTensor(),
                           output.at("x").toTensor(),
                           output.at("self_cont_tmp").toTensor(),
                           output.at("x_ghost").toTensor()});
    output = model_part.forward({output}).toGenericDict();
  }
  torch::Tensor energy_tensor =
      output.at("inferred_total_energy").toTensor().squeeze();
  torch::Tensor dE_dr =
      torch::zeros({nedges, 3}, FLOAT_TYPE.device(device)); // create on device
  torch::Tensor x_local_save; // holds grad info of x_local (it loses its grad
                              // when sends to CPU)
  torch::Tensor self_conn_grads;
  std::vector<torch::Tensor> grads;
  std::vector<torch::Tensor> of_tensor;
  // TODO: most values of self_conn_grads were zero because we use only scalars
  // for energy
  // Backward pass: walk the saved segments in reverse, chaining gradients
  // through the communicated ghost features.
  for (auto rit = wrt_tensors.rbegin(); rit != wrt_tensors.rend(); ++rit) {
    // edge_vec, x, x_ghost order
    auto wrt_tensor = *rit;
    if (rit == wrt_tensors.rbegin()) {
      grads = torch::autograd::grad({energy_tensor}, wrt_tensor);
    } else {
      x_local_save.copy_(x_local);
      // of wrt grads_output
      grads = torch::autograd::grad(of_tensor, wrt_tensor,
                                    {x_local_save, self_conn_grads});
    }
    dE_dr = dE_dr + grads.at(0); // accumulate force
    if (std::distance(rit, wrt_tensors.rend()) == 1)
      continue; // if last iteration
    of_tensor.clear();
    of_tensor.push_back(wrt_tensor[1]); // x
    of_tensor.push_back(wrt_tensor[2]); // self_cont_tmp
    x_local_save = grads.at(1);       // for grads_output
    x_local = x_local_save.detach();  // grad_outputs & communication
    x_dim = x_local.size(1);
    self_conn_grads = grads.at(2); // no communication, for grads_output
    x_ghost = grads.at(3).detach(); // yes communication, not for grads_output
    auto extra_x = torch::zeros({extra_size, x_dim}, FLOAT_TYPE.device(device));
    x_comm = torch::cat({x_local, x_ghost, extra_x}, 0).to(device_comm);
    comm_brick->reverse_comm(this); // completes x_local
    // now x_local is complete (dE_dx), become next grads_output(with
    // self_conn_grads)
    x_local = torch::split_with_sizes(
        x_comm, {nlocal, ghost_node_num, extra_size}, 0)[0];
  }
  // postprocessing: optional GPU memory report
  if (print_info) {
    size_t free, tot;
    cudaMemGetInfo(&free, &tot);
    std::cout << world_rank << " MEM use after backward(MB)" << std::endl;
    double Mfree = static_cast<double>(free) / (1024 * 1024);
    double Mtot = static_cast<double>(tot) / (1024 * 1024);
    std::cout << world_rank << " Total: " << Mtot << std::endl;
    std::cout << world_rank << " Free: " << Mfree << std::endl;
    std::cout << world_rank << " Used: " << Mtot - Mfree << std::endl;
    double Mused = Mtot - Mfree;
    std::cout << world_rank << " Used/Nedges: " << Mused / nedges << std::endl;
    std::cout << world_rank << " Used/Nlocal: " << Mused / nlocal << std::endl;
    std::cout << world_rank << " Used/GraphSize: " << Mused / graph_size << "\n"
              << std::endl;
  }
  eng_vdwl += energy_tensor.item<float>(); // accumulate energy
  dE_dr = dE_dr.to(torch::kCPU);
  // scatter per-edge gradients to per-node forces: +dE/dr on the source node,
  // -dE/dr on the destination node
  torch::Tensor force_tensor = torch::zeros({graph_indexer, 3});
  auto _edge_idx_src_tensor =
      edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3});
  auto _edge_idx_dst_tensor =
      edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3});
  force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum");
  force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr),
                               "sum");
  auto forces = force_tensor.accessor<float, 2>();
  for (int graph_idx = 0; graph_idx < graph_indexer; graph_idx++) {
    int i = graph_index_to_i[graph_idx];
    f[i][0] += forces[graph_idx][0];
    f[i][1] += forces[graph_idx][1];
    f[i][2] += forces[graph_idx][2];
  }
  if (vflag) {
    // virial in Voigt order: xx, yy, zz then shear terms
    auto diag = inp_edge_vec * dE_dr;
    auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1);
    auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2);
    auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0);
    std::vector<torch::Tensor> voigt_list = {
        diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)};
    auto voigt = torch::cat(voigt_list, 1);
    torch::Tensor per_atom_stress_tensor = torch::zeros({graph_indexer, 6});
    auto _edge_idx_dst6_tensor =
        edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6});
    per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt,
                                           "sum");
    auto virial_stress_tensor =
        torch::neg(torch::sum(per_atom_stress_tensor, 0));
    auto virial_stress = virial_stress_tensor.accessor<float, 1>();
    // NOTE(review): components 4 and 5 are swapped relative to the tensor
    // order — presumably mapping Voigt order to LAMMPS virial order; confirm.
    virial[0] += virial_stress[0];
    virial[1] += virial_stress[1];
    virial[2] += virial_stress[2];
    virial[3] += virial_stress[3];
    virial[4] += virial_stress[5];
    virial[5] += virial_stress[4];
  }
  if (eflag_atom) {
    torch::Tensor atomic_energy_tensor =
        output.at("atomic_energy").toTensor().cpu().squeeze();
    auto atomic_energy = atomic_energy_tensor.accessor<float, 1>();
    for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) {
      int i = graph_index_to_i[graph_idx];
      eatom[i] += atomic_energy[graph_idx];
    }
  }
  // clean up comm preprocess variables (index maps are valid one step only)
  comm_preprocess_done = false;
  for (int i = 0; i < 6; i++) {
    // array of vector<long>
    comm_index_pack_forward[i].clear();
    comm_index_unpack_forward[i].clear();
    comm_index_unpack_reverse[i].clear();
  }
  extra_graph_idx_map.clear();
}
// Allocate per-type pair arrays; invoked once from coeff().
void PairE3GNNParallel::allocate() {
  allocated = 1;
  const int np1 = atom->ntypes + 1; // LAMMPS type indices start at 1
  memory->create(setflag, np1, np1, "pair:setflag");
  memory->create(cutsq, np1, np1, "pair:cutsq");
  memory->create(map, np1, "pair:map");
}
// Global pair_style settings; this style accepts no arguments.
void PairE3GNNParallel::settings(int narg, char **arg) {
  if (narg != 0)
    error->all(FLERR, "Illegal pair_style command");
}
// Parse pair_coeff and load the deployed parallel model segments.
// Expected input: pair_coeff * * {n_model} {model path(s)} type_name1 ...
// where the model path is either a directory containing
// deployed_parallel_{i}.pt files or n_model explicit file paths.
// Reads model metadata (cutoff, comm buffer size, chemical species map) and
// fills map/setflag/cutsq for LAMMPS.
void PairE3GNNParallel::coeff(int narg, char **arg) {
  if (allocated) {
    error->all(FLERR, "pair_e3gnn coeff called twice");
  }
  allocate();
  if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) {
    error->all(FLERR,
               "e3gnn: first and second input of pair_coeff should be '*'");
  }
  // metadata keys to extract from the TorchScript archive
  std::unordered_map<std::string, std::string> meta_dict = {
      {"chemical_symbols_to_index", ""},
      {"cutoff", ""},
      {"num_species", ""},
      {"model_type", ""},
      {"version", ""},
      {"dtype", ""},
      {"time", ""},
      {"comm_size", ""}};

  // model loading from input
  int n_model = std::stoi(arg[2]);
  int chem_arg_i = 4; // index of the first chemical-species argument
  std::vector<std::string> model_fnames;
  if (std::filesystem::exists(arg[3])) {
    if (std::filesystem::is_directory(arg[3])) {
      // directory: expect deployed_parallel_{0..n_model-1}.pt inside
      auto headf = std::string(arg[3]);
      for (int i = 0; i < n_model; i++) {
        auto stri = std::to_string(i);
        model_fnames.push_back(headf + "/deployed_parallel_" + stri + ".pt");
      }
    } else if (std::filesystem::is_regular_file(arg[3])) {
      // explicit list of n_model files
      for (int i = 3; i < n_model + 3; i++) {
        model_fnames.push_back(std::string(arg[i]));
      }
      chem_arg_i = n_model + 3;
    } else {
      error->all(FLERR, "No such file or directory:" + std::string(arg[3]));
    }
  } else {
    // BUGFIX: a nonexistent path previously fell through silently, leaving
    // model_list empty and crashing later in compute().
    error->all(FLERR, "No such file or directory:" + std::string(arg[3]));
  }
  for (const auto &modelf : model_fnames) {
    if (!std::filesystem::is_regular_file(modelf)) {
      error->all(FLERR, "Expected this is a regular file:" + modelf);
    }
    model_list.push_back(torch::jit::load(modelf, device, meta_dict));
  }

  torch::jit::setGraphExecutorOptimize(false);
  torch::jit::FusionStrategy strategy;
  // strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}};
  strategy = {{torch::jit::FusionBehavior::STATIC, 0}};
  torch::jit::setFusionStrategy(strategy);

  cutoff = std::stod(meta_dict["cutoff"]);
  // maximum possible size of per atom x before last convolution
  // (stoi, not stod: comm_size is an integer count)
  int comm_size = std::stoi(meta_dict["comm_size"]);
  // to initialize buffer size for communication
  comm_forward = comm_size;
  comm_reverse = comm_size;
  cutoff_square = cutoff * cutoff;

  if (meta_dict["model_type"].compare("E3_equivariant_model") != 0) {
    error->all(FLERR, "given model type is not E3_equivariant_model");
  }

  // Tokenize the space-separated species list. Use a mutable copy:
  // running strtok over const_cast<char*>(c_str()) is undefined behavior.
  std::string chem_str = meta_dict["chemical_symbols_to_index"];
  int ntypes = atom->ntypes;
  auto delim = " ";
  std::vector<char> chem_buf(chem_str.begin(), chem_str.end());
  chem_buf.push_back('\0');
  std::vector<std::string> chem_vec;
  for (char *tok = std::strtok(chem_buf.data(), delim); tok != nullptr;
       tok = std::strtok(nullptr, delim)) {
    chem_vec.push_back(std::string(tok));
  }

  // map each LAMMPS type to the model's species index
  bool found_flag = false;
  int n_chem = narg - chem_arg_i;
  for (int i = 0; i < n_chem; i++) {
    found_flag = false;
    for (size_t j = 0; j < chem_vec.size(); j++) {
      if (chem_vec[j].compare(arg[i + chem_arg_i]) == 0) {
        map[i + 1] = j; // store from 1, (not 0)
        found_flag = true;
        if (lmp->logfile) {
          fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n",
                  arg[i + chem_arg_i], i + 1);
        }
        // BUGFIX: break was previously nested inside the logfile branch, so
        // the scan only stopped early when a logfile existed.
        break;
      }
    }
    if (!found_flag) {
      error->all(FLERR, "Unknown chemical specie is given or the number of "
                        "potential files is not consistent");
    }
  }

  for (int i = 1; i <= ntypes; i++) {
    for (int j = 1; j <= ntypes; j++) {
      if ((map[i] >= 0) && (map[j] >= 0)) {
        setflag[i][j] = 1;
        cutsq[i][j] = cutoff * cutoff;
      }
    }
  }
  if (lmp->logfile) {
    fprintf(lmp->logfile, "from sevenn version '%s' ",
            meta_dict["version"].c_str());
    fprintf(lmp->logfile, "%s precision model, deployed when: %s\n",
            meta_dict["dtype"].c_str(), meta_dict["time"].c_str());
  }
}
// Pair-specific initialization: this style needs a full neighbor list and
// newton pair enabled.
void PairE3GNNParallel::init_style() {
  if (force->newton_pair == 0)
    error->all(FLERR, "Pair style e3gnn/parallel requires newton pair on");
  neighbor->add_request(this, NeighConst::REQ_FULL);
}
double PairE3GNNParallel::init_one(int i, int j) { return cutoff; }
// Record the partner ranks for the six swap directions; called by the
// modified CommBrick so duplicate sends can be detected in comm_preprocess().
void PairE3GNNParallel::notify_proc_ids(const int *sendproc,
                                        const int *recvproc) {
  for (int dir = 0; dir < 6; ++dir) {
    this->sendproc[dir] = sendproc[dir];
    this->recvproc[dir] = recvproc[dir];
  }
}
// Build the per-swap pack/unpack index maps used by the real tensor
// communication in compute(). Performs one "fake" forward communication so
// the pack/unpack *_init callbacks record indices, then derives the
// unpack_reverse map (deduplicating indices already sent to the same rank by
// routing repeats to a trash row of x_comm).
void PairE3GNNParallel::comm_preprocess() {
  assert(!comm_preprocess_done);
  CommBrick *comm_brick = dynamic_cast<CommBrick *>(comm);
  // fake lammps communication call to preprocess index
  // gives complete comm_index_pack, unpack_forward, and extra_graph_idx_map
  comm_brick->forward_comm(this);

  std::map<int, std::set<int>> already_met_map; // per destination rank
  for (int comm_phase = 0; comm_phase < 6; comm_phase++) {
    const int n = comm_index_pack_forward[comm_phase].size();
    int sproc = this->sendproc[comm_phase];
    if (already_met_map.count(sproc) == 0) {
      already_met_map.insert({sproc, std::set<int>()});
    }
    // for unpack_reverse, ignore duplicated index by 'already_met'
    std::vector<long> &idx_map_forward = comm_index_pack_forward[comm_phase];
    std::vector<long> &idx_map_reverse = comm_index_unpack_reverse[comm_phase];
    std::set<int> &already_met = already_met_map[sproc];
    // the last index of x_comm is used to trash unnecessary values
    const int trash_index =
        graph_size + static_cast<int>(extra_graph_idx_map.size());
    for (int i = 0; i < n; i++) {
      const int idx = idx_map_forward[i];
      if (idx < graph_size) {
        if (already_met.count(idx) == 1) {
          idx_map_reverse.push_back(trash_index);
        } else {
          idx_map_reverse.push_back(idx);
          already_met.insert(idx);
        }
      } else {
        idx_map_reverse.push_back(idx);
      }
    }
    if (use_cuda_mpi) {
      // Device-resident copies of the index maps for index_select/scatter.
      // .to(device) copies the data, so the from_blob views of host vectors
      // are not retained past this statement.
      comm_index_pack_forward_tensor[comm_phase] =
          torch::from_blob(idx_map_forward.data(), idx_map_forward.size(),
                           INTEGER_TYPE)
              .to(device);
      // BUGFIX: bind by reference. The old code copied the vector into a
      // local, and torch::from_blob does not own its memory, so the blob
      // view aliased a temporary that died at the end of the block.
      auto &upmap = comm_index_unpack_forward[comm_phase];
      comm_index_unpack_forward_tensor[comm_phase] =
          torch::from_blob(upmap.data(), upmap.size(), INTEGER_TYPE).to(device);
      comm_index_unpack_reverse_tensor[comm_phase] =
          torch::from_blob(idx_map_reverse.data(), idx_map_reverse.size(),
                           INTEGER_TYPE)
              .to(device);
    }
  }
  comm_preprocess_done = true;
}
// called from comm_brick if comm_preprocess_done is false
void PairE3GNNParallel::pack_forward_init(int n, int *list_send,
int comm_phase) {
std::vector<long> &idx_map = comm_index_pack_forward[comm_phase];
idx_map.reserve(n);
int i, j;
int nlocal = list->inum;
tagint *tag = atom->tag;
for (i = 0; i < n; i++) {
int list_i = list_send[i];
int graph_idx = tag_to_graph_idx_ptr[tag[list_i]];
if (graph_idx != -1) {
// known atom (local atom + ghost atom inside cutoff)
idx_map.push_back(graph_idx);
} else {
// unknown atom, these are not used in computation in this process
// instead, this process is used to hand over these atoms to other proecss
// hold them in continuous manner for flexible tensor operations later
if (extra_graph_idx_map.find(list_i) != extra_graph_idx_map.end()) {
idx_map.push_back(extra_graph_idx_map[list_i]);
} else {
// unknown atom at pack forward, ghost atom outside cutoff?
extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size();
idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack
}
}
}
}
// called from comm_brick if comm_preprocess_done is false
void PairE3GNNParallel::unpack_forward_init(int n, int first, int comm_phase) {
std::vector<long> &idx_map = comm_index_unpack_forward[comm_phase];
idx_map.reserve(n);
int i, j, last;
last = first + n;
int nlocal = list->inum;
tagint *tag = atom->tag;
for (i = first; i < last; i++) {
int graph_idx = tag_to_graph_idx_ptr[tag[i]];
if (graph_idx != -1) {
idx_map.push_back(graph_idx);
} else {
extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size();
idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack
}
}
}
// Pack node features from x_comm into the send buffer for one swap
// direction. In the CUDA-aware path, buf is a device pointer and the copy
// stays on the GPU; otherwise features are copied element-wise on the host.
// Returns the number of floats packed (x_dim per atom).
int PairE3GNNParallel::pack_forward_comm_gnn(float *buf, int comm_phase) {
  std::vector<long> &idx_map = comm_index_pack_forward[comm_phase];
  const int n = static_cast<int>(idx_map.size());
  if (use_cuda_mpi && n != 0) {
    torch::Tensor &idx_map_tensor = comm_index_pack_forward_tensor[comm_phase];
    // gather the selected rows, then copy device-to-device into buf
    auto selected = x_comm.index_select(0, idx_map_tensor);
    cudaError_t cuda_err =
        cudaMemcpy(buf, selected.data_ptr<float>(), (x_dim * n) * sizeof(float),
                   cudaMemcpyDeviceToDevice);
    if (cuda_err != cudaSuccess) {
      // BUGFIX: the copy result was previously ignored.
      std::cerr << "E3GNN: cudaMemcpy failed in pack_forward_comm_gnn: "
                << cudaGetErrorString(cuda_err) << std::endl;
    }
  } else {
    int i, j, m;
    m = 0;
    for (i = 0; i < n; i++) {
      const int idx = static_cast<int>(idx_map.at(i));
      float *from = x_comm[idx].data_ptr<float>();
      for (j = 0; j < x_dim; j++) {
        buf[m++] = from[j];
      }
    }
  }
  if (print_info) {
    std::cout << world_rank << " comm_phase: " << comm_phase << std::endl;
    std::cout << world_rank << " pack_forward x_dim: " << x_dim << std::endl;
    std::cout << world_rank << " pack_forward n: " << n << std::endl;
    std::cout << world_rank << " pack_forward x_dim*n: " << x_dim * n
              << std::endl;
    double Msend = static_cast<double>(x_dim * n * 4) / (1024 * 1024);
    std::cout << world_rank << " send size(MB): " << Msend << "\n" << std::endl;
  }
  return x_dim * n;
}
// Unpack received node features from buf into the recorded rows of x_comm
// for one swap direction. In the CUDA-aware path, buf is expected to be a
// device pointer (the from_blob view is created with the device option —
// presumably the buffer from DeviceBuffManager; confirm with caller) and the
// scatter runs on the GPU; otherwise features are copied on the host.
void PairE3GNNParallel::unpack_forward_comm_gnn(float *buf, int comm_phase) {
  std::vector<long> &idx_map = comm_index_unpack_forward[comm_phase];
  const int n = static_cast<int>(idx_map.size());
  if (use_cuda_mpi && n != 0) {
    torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase];
    auto buf_tensor =
        torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device));
    // each row index is expanded to x_dim columns so scatter_ writes whole rows
    x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}),
                    buf_tensor);
  } else {
    int i, j, m;
    m = 0;
    for (i = 0; i < n; i++) {
      const int idx = static_cast<int>(idx_map.at(i));
      float *to = x_comm[idx].data_ptr<float>();
      for (j = 0; j < x_dim; j++) {
        to[j] = buf[m++];
      }
    }
  }
}
// Pack gradient rows from x_comm into the send buffer for the reverse
// (ghost-to-owner) communication. Note: packing uses the unpack_forward
// index map, since reverse communication retraces the forward route.
// Returns the number of floats packed.
int PairE3GNNParallel::pack_reverse_comm_gnn(float *buf, int comm_phase) {
  std::vector<long> &idx_map = comm_index_unpack_forward[comm_phase];
  const int n = static_cast<int>(idx_map.size());
  if (use_cuda_mpi && n != 0) {
    torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase];
    auto selected = x_comm.index_select(0, idx_map_tensor);
    cudaError_t cuda_err =
        cudaMemcpy(buf, selected.data_ptr<float>(), (x_dim * n) * sizeof(float),
                   cudaMemcpyDeviceToDevice);
    if (cuda_err != cudaSuccess) {
      // BUGFIX: the copy result was previously ignored.
      std::cerr << "E3GNN: cudaMemcpy failed in pack_reverse_comm_gnn: "
                << cudaGetErrorString(cuda_err) << std::endl;
    }
  } else {
    int i, j, m;
    m = 0;
    for (i = 0; i < n; i++) {
      const int idx = static_cast<int>(idx_map.at(i));
      float *from = x_comm[idx].data_ptr<float>();
      for (j = 0; j < x_dim; j++) {
        buf[m++] = from[j];
      }
    }
  }
  if (print_info) {
    std::cout << world_rank << " comm_phase: " << comm_phase << std::endl;
    std::cout << world_rank << " pack_reverse x_dim: " << x_dim << std::endl;
    std::cout << world_rank << " pack_reverse n: " << n << std::endl;
    std::cout << world_rank << " pack_reverse x_dim*n: " << x_dim * n
              << std::endl;
    // BUGFIX: Msend was computed but never printed (unlike the forward
    // counterpart); report it for symmetric debug output.
    double Msend = static_cast<double>(x_dim * n * 4) / (1024 * 1024);
    std::cout << world_rank << " send size(MB): " << Msend << "\n" << std::endl;
  }
  return x_dim * n;
}
// Unpack received gradient rows and ACCUMULATE them into x_comm (reverse
// communication sums contributions; note the "add" reduce in the scatter and
// the += in the host path). Indices equal to -1 skip the corresponding
// buffer slice in the host path.
void PairE3GNNParallel::unpack_reverse_comm_gnn(float *buf, int comm_phase) {
  std::vector<long> &idx_map = comm_index_unpack_reverse[comm_phase];
  const int n = static_cast<int>(idx_map.size());
  if (use_cuda_mpi && n != 0) {
    torch::Tensor &idx_map_tensor = comm_index_unpack_reverse_tensor[comm_phase];
    // buf is expected to be a device pointer here (from_blob uses the device
    // option) — presumably the DeviceBuffManager buffer; confirm with caller
    auto buf_tensor =
        torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device));
    x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}),
                    buf_tensor, "add");
  } else {
    int i, j, m;
    m = 0;
    for (i = 0; i < n; i++) {
      const int idx = static_cast<int>(idx_map.at(i));
      if (idx == -1) {
        // no destination row: skip this atom's x_dim values
        m += x_dim;
        continue;
      }
      float *to = x_comm[idx].data_ptr<float>();
      for (j = 0; j < x_dim; j++) {
        to[j] += buf[m++];
      }
    }
  }
}
/* -*- c++ -*- ----------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
http://lammps.sandia.gov, Sandia National Laboratories
Steve Plimpton, sjplimp@sandia.gov
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#ifdef PAIR_CLASS
PairStyle(e3gnn/parallel, PairE3GNNParallel)
#else
#ifndef LMP_PAIR_E3GNN_PARALLEL
#define LMP_PAIR_E3GNN_PARALLEL
#include "pair.h"
#include <torch/torch.h>
#include <vector>
namespace LAMMPS_NS {
// LAMMPS pair style running a partitioned E3-equivariant GNN across MPI
// ranks. Model segments are evaluated in sequence, with ghost-atom node
// features exchanged between segments through a modified CommBrick; the
// pack/unpack index maps below are rebuilt every MD step.
class PairE3GNNParallel : public Pair {
private:
  double cutoff;          // model cutoff radius (from model metadata)
  double cutoff_square;   // cutoff^2, compared against squared distances
  std::vector<torch::jit::Module> model_list; // deployed model segments
  torch::Device device = torch::kCPU;         // compute device
  torch::Device device_comm = torch::kCPU;    // device for comm tensors
  torch::Device get_cuda_device();
  bool use_cuda_mpi; // communicate directly between device buffers?

  // for communication
  // Most of these variables for communication is temporary and valid for only
  // one MD step.
  int x_dim; // to determine per atom data size
  int graph_size; // number of graph nodes: local + in-cutoff ghost atoms
  torch::Tensor x_comm; // x_local + x_ghost + x_comm_extra
  void comm_preprocess();
  bool comm_preprocess_done = false;

  // temporary variables holds for each compute step
  std::unordered_map<int, long> extra_graph_idx_map;
  // To use scatter, store long instead of int
  // array of vector (one per swap direction)
  std::vector<long> comm_index_pack_forward[6];
  std::vector<long> comm_index_unpack_forward[6];
  std::vector<long> comm_index_unpack_reverse[6];
  // its size is 6 and initialized at comm_preprocess()
  torch::Tensor comm_index_pack_forward_tensor[6];
  torch::Tensor comm_index_unpack_forward_tensor[6];
  torch::Tensor comm_index_unpack_reverse_tensor[6];
  // to use tag_to_graph_idx inside comm methods
  int *tag_to_graph_idx_ptr = nullptr;
  // partner ranks per swap direction, filled via notify_proc_ids()
  int sendproc[6];
  int recvproc[6];

public:
  PairE3GNNParallel(class LAMMPS *);
  ~PairE3GNNParallel();
  // TODO: keep encapsulation..
  void compute(int, int) override;
  void settings(int, char **) override;
  // read Atom type string from input script & related coeff
  void coeff(int, char **) override;
  void allocate();
  // hooks used by the modified comm_brick (index recording + data transfer)
  void pack_forward_init(int n, int *list, int comm_phase);
  void unpack_forward_init(int n, int first, int comm_phase);
  int pack_forward_comm_gnn(float *buf, int comm_phase);
  void unpack_forward_comm_gnn(float *buf, int comm_phase);
  int pack_reverse_comm_gnn(float *buf, int comm_phase);
  void unpack_reverse_comm_gnn(float *buf, int comm_phase);
  void init_style() override;
  double init_one(int, int) override;
  int get_x_dim();
  bool use_cuda_mpi_();
  bool is_comm_preprocess_done();
  void notify_proc_ids(const int *sendproc, const int *recvproc);
  bool print_info = false; // per-rank debug printing (env-controlled)
  int world_rank;          // set only when SEVENN_PRINT_INFO is defined
};
// Singleton owning persistent device-side staging buffers for CUDA-aware MPI
// communication; buffers grow on demand and are freed at program exit.
class DeviceBuffManager {
private:
  DeviceBuffManager() {}
  // non-copyable (declared, not defined)
  DeviceBuffManager(const DeviceBuffManager &);
  DeviceBuffManager &operator=(const DeviceBuffManager &);

  float *buf_send_device = nullptr; // device send buffer
  float *buf_recv_device = nullptr; // device recv buffer
  int send_buf_size = 0; // current capacities, in floats
  int recv_buf_size = 0;

public:
  static DeviceBuffManager &getInstance();
  void get_buffer(int, int, float *&, float *&);
  ~DeviceBuffManager();
};
} // namespace LAMMPS_NS
#endif
#endif
#!/bin/bash
# Patch a LAMMPS source tree to build SevenNet's e3gnn pair styles
# (and, optionally, the CUDA-based pair_d3 style).
#
# Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}
lammps_root=$1
cxx_standard=$2 # 14, 17
d3_support=$3 # 1, 0
SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")

###########################################
# Check if the given arguments are valid #
###########################################
# Check the number of arguments
if [ "$#" -ne 3 ]; then
  echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}"
  echo " {lammps_root}: Root directory of LAMMPS source"
  echo " {cxx_standard}: C++ standard (14, 17)"
  echo " {d3_support}: Support for pair_d3 (1, 0)"
  exit 1
fi

# Check if the lammps_root directory exists
# (all expansions are quoted so paths containing spaces work)
if [ ! -d "$lammps_root" ]; then
  echo "Error: No such directory: $lammps_root"
  exit 1
fi

# Check if the given directory is the root of LAMMPS source.
# A LAMMPS root contains cmake/ and potentials/; we only reject when BOTH
# are missing (deliberately lenient toward stripped source trees).
if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then
  echo "Error: Given $lammps_root is not a root of LAMMPS source"
  exit 1
fi

# Check if the script is being run from the root of SevenNet
if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then
  echo "Error: Script executed in a wrong directory"
  exit 1
fi

# Check if the patch is already applied
if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then
  echo "----------------------------------------------------------"
  echo "Seems like given LAMMPS is already patched."
  echo "Try again after removing src/pair_e3gnn.cpp to force patch"
  echo "----------------------------------------------------------"
  echo "Example build commands, under LAMMPS root"
  echo " mkdir build; cd build"
  echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
  echo " make -j 4"
  exit 0
fi

# Check if OpenMPI exists and if it is CUDA-aware
if command -v ompi_info &> /dev/null; then
  cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value)
  if [[ -z "$cuda_support" ]]; then
    # ompi_info ran but did not report the CUDA flag at all
    echo "Could not determine whether OpenMPI is CUDA aware, parallel performance may not be optimal"
  elif [[ "$cuda_support" == *"true" ]]; then
    echo "OpenMPI is CUDA aware"
  else
    echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal"
  fi
else
  echo "OpenMPI not found, parallel performance is not optimal"
fi

# Extract LAMMPS version string, e.g. '2 Aug 2023'
lammps_version=$(grep "#define LAMMPS_VERSION" "$lammps_root/src/version.h" | awk '{print $3, $4, $5}' | tr -d '"')
detected_version="$lammps_version"
required_version="2 Aug 2023" # version this patch is validated against
# Warn (but do not abort) when the detected version differs
if [[ "$detected_version" != "$required_version" ]]; then
  echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version"
fi

###########################################
# Backup original LAMMPS source code #
###########################################
# Create a backup directory if it doesn't exist
backup_dir="$lammps_root/_backups"
mkdir -p "$backup_dir"
# Copy comm_* from original LAMMPS source as backup
cp "$lammps_root/src/comm_brick.cpp" "$backup_dir/"
cp "$lammps_root/src/comm_brick.h" "$backup_dir/"
# Copy cmake/CMakeLists.txt from original source as backup
cp "$lammps_root/cmake/CMakeLists.txt" "$backup_dir/CMakeLists.txt"

###########################################
# Patch LAMMPS source code: e3gnn #
###########################################
# 1. Copy pair_e3gnn files to LAMMPS source
cp "$SCRIPT_DIR"/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp "$lammps_root/src/"
cp "$SCRIPT_DIR"/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h "$lammps_root/src/"
# 2. Patch cmake/CMakeLists.txt: set the C++ standard and link LibTorch
sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" "$lammps_root/cmake/CMakeLists.txt"
cat >> "$lammps_root/cmake/CMakeLists.txt" << "EOF"
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}")
EOF

###########################################
# Patch LAMMPS source code: d3 #
###########################################
if [ "$d3_support" -ne 0 ]; then
  # 1. Copy pair_d3 files to LAMMPS source
  cp "$SCRIPT_DIR/pair_d3.cu" "$lammps_root/src/"
  cp "$SCRIPT_DIR/pair_d3.h" "$lammps_root/src/"
  cp "$SCRIPT_DIR/pair_d3_pars.h" "$lammps_root/src/"
  # 2. Patch cmake/CMakeLists.txt: enable CUDA and compile .cu sources
  sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" "$lammps_root/cmake/CMakeLists.txt"
  sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" "$lammps_root/cmake/CMakeLists.txt"
  cat >> "$lammps_root/cmake/CMakeLists.txt" << "EOF"
find_package(CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3")
string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda)
EOF
fi

###########################################
# Print changes and backup file locations #
###########################################
echo "Changes made:"
echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeLists.txt) are in {lammps_root}/_backups"
echo " - Copied contents of pair_e3gnn to $lammps_root/src/"
echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard"
if [ "$d3_support" -ne 0 ]; then
  echo " - Copied contents of pair_d3 to $lammps_root/src/"
  echo " - Patched CMakeLists.txt: include CUDA"
fi
# Provide example cmake command to the user
echo "Example build commands, under LAMMPS root"
echo " mkdir build; cd build"
echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
echo " make -j 4"
exit 0
import glob
import os
import warnings
from typing import Any, Callable, Dict
import torch
import yaml
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
def config_initialize(
    key: str,
    config: Dict,
    default: Any,
    conditions: Dict,
):
    """Resolve the value for ``key`` from a user config with validation.

    Args:
        key: config key to resolve.
        config: user-provided config dict.
        default: fallback value when ``key`` is absent; a dict default
            triggers key-by-key recursion into the user's nested dict.
        conditions: per-key validators; a value may be a type (cast),
            a nested dict (recurse), or a predicate callable.

    Returns:
        The validated (possibly type-cast) user value, or ``default``.

    Raises:
        ValueError: when the user value fails validation or cannot be cast.
    """
    # default value exist & no user input -> return default
    if key not in config:
        return default

    user_input = config[key]
    # No validation method exist => accept user input
    if key not in conditions:
        return user_input
    condition = conditions[key]

    if isinstance(default, dict) and isinstance(condition, dict):
        # Nested config: validate each defaulted sub-key recursively.
        for i_key, val in default.items():
            user_input[i_key] = config_initialize(
                i_key, user_input, val, condition
            )
        return user_input
    elif isinstance(condition, type):
        if isinstance(user_input, condition):
            return user_input
        try:
            return condition(user_input)  # try type casting
        # int(None) etc. raises TypeError, not ValueError; catch both so the
        # caller always sees the descriptive ValueError below.
        except (ValueError, TypeError):
            raise ValueError(
                f"Expect '{user_input}' for '{key}' is {condition}"
            )
    elif isinstance(condition, Callable) and condition(user_input):
        return user_input
    else:
        raise ValueError(
            f"Given input '{user_input}' for '{key}' is not valid"
        )
def init_model_config(config: Dict):
    """Build the model metadata dict from the raw 'model' section.

    Resolves chemical_species (auto / universal / explicit list), emits
    deprecation warnings for removed keys, and fills the remaining keys
    from defaults with per-key validation.
    """
    model_meta = {}

    # chemical_species is mandatory and accepts several forms
    if KEY.CHEMICAL_SPECIES not in config:
        raise ValueError('required key chemical_species not exist')
    input_chem = config[KEY.CHEMICAL_SPECIES]

    if isinstance(input_chem, str) and input_chem.lower() == 'auto':
        # Resolution deferred until the dataset is known
        model_meta[KEY.CHEMICAL_SPECIES] = 'auto'
        model_meta[KEY.NUM_SPECIES] = 'auto'
        model_meta[KEY.TYPE_MAP] = 'auto'
    elif isinstance(input_chem, str) and 'univ' in input_chem.lower():
        model_meta.update(util.chemical_species_preprocess([], universal=True))
    else:
        if isinstance(input_chem, list) and all(
            isinstance(x, str) for x in input_chem
        ):
            pass  # already a clean list of element symbols
        elif isinstance(input_chem, str):
            # Accept '-', ' ' or ',' separated element strings
            normalized = input_chem.replace('-', ',').replace(' ', ',')
            input_chem = [c for c in normalized.split(',') if c]
        else:
            raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange')
        model_meta.update(util.chemical_species_preprocess(input_chem))

    # deprecation warnings for keys from older configs
    if KEY.AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'."
            ' We use the default, the average number of neighbors in the'
            ' dataset, if not provided.',
            UserWarning,
        )
        config.pop(KEY.AVG_NUM_NEIGH)
    if KEY.TRAIN_AVG_NUM_NEIGH in config:
        warnings.warn(
            "key 'train_avg_num_neigh' is deprecated. Please use"
            " 'train_denominator'. We overwrite train_denominator as given"
            ' train_avg_num_neigh',
            UserWarning,
        )
        config[KEY.TRAIN_DENOMINTAOR] = config.pop(KEY.TRAIN_AVG_NUM_NEIGH)
    if KEY.OPTIMIZE_BY_REDUCE in config:
        warnings.warn(
            "key 'optimize_by_reduce' is deprecated. Always true",
            UserWarning,
        )
        config.pop(KEY.OPTIMIZE_BY_REDUCE)

    # remaining keys: defaults with per-key validation
    for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items():
        model_meta[key] = config_initialize(
            key, config, default, _const.MODEL_CONFIG_CONDITION
        )

    unknown_keys = [k for k in config if k not in model_meta]
    if unknown_keys:
        warnings.warn(
            f'Unexpected model keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return model_meta
def init_train_config(config: Dict):
    """Build the training metadata dict from the raw 'train' section.

    Picks the compute device (user choice, else CUDA when available),
    fills defaults with validation, and resolves the continue-checkpoint
    path (file path or pretrained model name).
    """
    train_meta = {}

    # Device: honor explicit user input, otherwise prefer CUDA
    if KEY.DEVICE in config:
        device = torch.device(config[KEY.DEVICE])
    else:
        device = (
            torch.device('cuda')
            if torch.cuda.is_available()
            else torch.device('cpu')
        )
    train_meta[KEY.DEVICE] = str(device)

    # remaining keys: defaults with per-key validation
    for key, default in _const.DEFAULT_TRAINING_CONFIG.items():
        train_meta[key] = config_initialize(
            key, config, default, _const.TRAINING_CONFIG_CONDITION
        )

    if KEY.CONTINUE in config:
        cnt_dct = config[KEY.CONTINUE]
        if KEY.CHECKPOINT not in cnt_dct:
            raise ValueError('no checkpoint is given in continue')
        checkpoint = cnt_dct[KEY.CHECKPOINT]
        # Accept either a path on disk or a known pretrained-model name
        if os.path.isfile(checkpoint):
            checkpoint_file = checkpoint
        else:
            checkpoint_file = util.pretrained_name_to_path(checkpoint)
        train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file})

    unknown_keys = [k for k in config if k not in train_meta]
    if unknown_keys:
        warnings.warn(
            f'Unexpected train keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return train_meta
def init_data_config(config: Dict):
    """Build the data metadata dict from the raw 'data' section.

    Any key matching ``load_*_path`` is treated as a dataset source:
    strings are glob-expanded, lists may mix glob strings and dict
    entries (multi-fidelity file groups). Remaining keys fall back to
    defaults with per-key validation.

    Raises:
        ValueError: on a malformed source entry or when a glob matches
            nothing.
    """
    data_meta = {}

    # Collect dataset-source keys, e.g. 'load_trainset_path'
    load_data_keys = [
        k for k in config if k.startswith('load_') and k.endswith('_path')
    ]
    # (The original code had an unreachable `else: data_meta[key] = False`
    # branch here: every key came from `config`, so membership always held.)
    for load_data_key in load_data_keys:
        inp = config[load_data_key]
        if not isinstance(inp, (str, list)):
            # fixed typo: was 'sturcture_list'
            raise ValueError(f'unexpected input {inp} for structure_list')
        extended = []
        if isinstance(inp, str):
            extended = glob.glob(inp)
        else:
            for i in inp:
                if isinstance(i, str):
                    extended.extend(glob.glob(i))
                elif isinstance(i, dict):
                    extended.append(i)
        if len(extended) == 0:
            raise ValueError(
                f'Cannot find {inp} for {load_data_key}'
                + ' or path is not given'
            )
        data_meta[load_data_key] = extended

    # remaining keys: defaults with per-key validation
    for key, default in _const.DEFAULT_DATA_CONFIG.items():
        data_meta[key] = config_initialize(
            key, config, default, _const.DATA_CONFIG_CONDITION
        )

    unknown_keys = [k for k in config if k not in data_meta]
    if unknown_keys:
        warnings.warn(
            f'Unexpected data keys: {unknown_keys} will be ignored',
            UserWarning,
        )
    return data_meta
def read_config_yaml(filename: str, return_separately: bool = False):
    """Parse an input.yaml into model / train / data metadata.

    Args:
        filename: path to the YAML input file.
        return_separately: when True, return the three dicts as a tuple;
            otherwise return them merged into one dict.
    """
    with open(filename, 'r') as fstream:
        inputs = yaml.safe_load(fstream)

    # Dispatch each top-level section to its parser
    parsers = {
        'model': init_model_config,
        'train': init_train_config,
        'data': init_data_config,
    }
    meta = {'model': {}, 'train': {}, 'data': {}}
    for key, config in inputs.items():
        if key not in parsers:
            raise ValueError(f'Unexpected input {key} given')
        meta[key] = parsers[key](config)

    if return_separately:
        return meta['model'], meta['train'], meta['data']
    merged = meta['model']
    merged.update(meta['train'])
    merged.update(meta['data'])
    return merged
def main():
    """Smoke-test entry point: parse the default './input.yaml'."""
    read_config_yaml('./input.yaml')


if __name__ == '__main__':
    main()
model:
chemical_species: 'univ' # Ready for 119 elements
cutoff: 5.0
channel: 128
is_parity: False
lmax: 2
num_convolution_layer: 5
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e"
weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 4.5
act_gate: {'e': 'silu', 'o': 'tanh'}
act_scalar: {'e': 'silu', 'o': 'tanh'}
conv_denominator: 'avg_num_neigh'
train_shift_scale: False
train_denominator: False
self_connection_type: 'linear'
# Following are used to specify which part of the model would utilize fidelity-dependent parameters for multi-fidelity training.
# For detailed architecture, please refer to https://arxiv.org/abs/2409.07947
# Parts using fidelity-dependent weights are indicated as `Modified linear` layers in Figure 1.
use_modal_node_embedding: False # If true, use modified linear layer in atom-type embedding layer.
use_modal_self_inter_intro: True # If true, use modified linear layers in self-interaction block before the convolution in the interaction blocks.
use_modal_self_inter_outro: True # If true, use modified linear layers in self-interaction block after the convolution in the interaction blocks.
use_modal_output_block: True # If true, use modified linear layer in the output block.
train:
train_shuffle: True
random_seed: 777
is_train_stress : True
epoch: 200
loss: 'Huber'
loss_param:
delta: 0.01
optimizer: 'adam'
optim_param:
lr: 0.01
scheduler: 'linearlr'
scheduler_param:
start_factor: 1.0
total_iters: 200
end_factor: 0.0001
force_loss_weight : 1.00
stress_loss_weight: 0.01
error_record:
- ['Energy', 'MAE']
- ['Force', 'MAE']
- ['Stress', 'MAE']
- ['Energy', 'Loss']
- ['Force', 'Loss']
- ['Stress', 'Loss']
- ['TotalLoss', 'None']
per_epoch: 10
use_modality: True
use_weight: True
data:
batch_size: 64
shift: 'elemwise_reference_energies'
scale: 1.73
use_modal_wise_shift: True # If true, use different atomic energy shift for each database
use_modal_wise_scale: False # If true, use different atomic energy scale for each database
load_trainset_path:
- data_modality: pbe # Name of database
file_list:
- file: "**path to PBE database**" # ASE readable or .pt file (graph.pt)
data_weight:
energy: 1.0
force: 1.0 # This weight would be additionally multiplied to `force_loss_weight` for this database
stress: 1.0 # This weight would be additionally multiplied to `stress_loss_weight` for this database
- data_modality: r2scan
file_list:
- file: "**path to r2SCAN database**"
data_weight:
energy: 7.0
force: 7.0
stress: 7.0
load_pbe_validset_path: # any name starts with 'load' and ends with 'set_path'
- data_modality: pbe # modality must be given for mm valid set
file_list:
- file: "**path to PBE test set**"
load_scan_validset_path:
- data_modality: r2scan
file_list:
- file: "**path to r2SCAN test set**"
# Example input.yaml for training SevenNet.
# '*' signifies default. You can check log.sevenn for defaults.
model:
chemical_species: 'Auto' # Elements model should know. [ 'Univ' | 'Auto' | manual_user_input ]
cutoff: 5.0 # Cutoff radius in Angstroms. If two atoms are within the cutoff, they are connected.
channel: 32 # The multiplicity(channel) of node features.
lmax: 2 # Maximum order of irreducible representations (rotation order).
num_convolution_layer: 3 # The number of message passing layers.
#irreps_manual: # Manually set irreps of the model in each layer
#- "128x0e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e"
weight_nn_hidden_neurons: [64, 64] # Hidden neurons in convolution weight neural network
radial_basis: # Function and its parameters to encode radial distance
radial_basis_name: 'bessel' # Only 'bessel' is currently supported
bessel_basis_num: 8
cutoff_function: # Envelop function, multiplied to radial_basis functions to init edge features
cutoff_function_name: 'poly_cut' # {'poly_cut' and 'poly_cut_p_value'} or {'XPLOR' and 'cutoff_on'}
poly_cut_p_value: 6
act_gate: {'e': 'silu', 'o': 'tanh'} # Equivalent to 'nonlinearity_gates' in nequip
act_scalar: {'e': 'silu', 'o': 'tanh'} # Equivalent to 'nonlinearity_scalars' in nequip
is_parity: False              # Parity: True restricts to the E(3) group, False to the SE(3) group
self_connection_type: 'nequip' # Default is 'nequip'. 'linear' is used for SevenNet-0. I recommend 'linear' for 'Univ' chemical_species
conv_denominator: "avg_num_neigh" # Valid options are "avg_num_neigh*", "sqrt_avg_num_neigh", or float
train_denominator: False # Enable training for denominator in convolution layer
train_shift_scale: False # Enable training for shift & scale in output layer
train:
random_seed: 1
is_train_stress: True # Includes stress in the loss function
epoch: 200 # Ends training after this number of epochs
#loss: 'Huber' # Default is 'mse' (mean squared error)
#loss_param:
#delta: 0.01
# Each optimizer and scheduler have different available parameters.
# You can refer to sevenn/train/optim.py for supporting optimizer & schedulers
optimizer: 'adam' # Options available are 'sgd', 'adagrad', 'adam', 'adamw', 'radam'
optim_param:
lr: 0.005
scheduler: 'exponentiallr' # 'steplr', 'multisteplr', 'exponentiallr', 'cosineannealinglr', 'reducelronplateau', 'linearlr'
scheduler_param:
gamma: 0.99
force_loss_weight: 0.1 # Coefficient for force loss
stress_loss_weight: 1e-06 # Coefficient for stress loss (to kbar unit)
per_epoch: 10 # Generate checkpoints every this epoch
# ['target y', 'metric']
# Target y: TotalEnergy, Energy, Force, Stress, Stress_GPa, TotalLoss
# Metric : RMSE, MAE, or Loss
error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['TotalLoss', 'None']
# Continue training model from given checkpoint, or pre-trained model checkpoint for fine-tuning
#continue:
#checkpoint: 'checkpoint_best.pth' # Checkpoint of pre-trained model or a model want to continue training.
#reset_optimizer: False # Set True for fine-tuning
#reset_scheduler: False # Set True for fine-tuning
data:
batch_size: 4 # Per GPU batch size.
shift: 'per_atom_energy_mean' # One of 'per_atom_energy_mean*', 'elemwise_reference_energies', float
scale: 'force_rms' # One of 'force_rms*', 'per_atom_energy_std', float
# SevenNet automatically matches data format from its filename.
# For those not `structure_list` or `.pt` files, assumes it is ASE readable
# In this case, below arguments are directly passed to `ase.io.read`
data_format_args:
index: ':' # see `https://wiki.fysik.dtu.dk/ase/ase/io/io.html` for more valid arguments
# validset is needed if you want '_best.pth' during training. If not, both validset and testset is optional.
load_trainset_path: ['./train_*.extxyz'] # Example of using ase as data_format, support multiple files and expansion(*)
load_validset_path: ['./valid.extxyz']
load_testset_path: ['./sevenn_data/mydata.pt'] # Graph can be preprocessed using `sevenn_graph_build` and accessible like this
# Example input.yaml for fine-tuning sevennet-0
# '*' signifies default. You can check log.sevenn for defaults.
model: # model keys should be consistent except for train_* keys
chemical_species: 'Auto'
cutoff: 5.0
channel: 128
is_parity: False
lmax: 2
num_convolution_layer: 5
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e"
weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 4.5
self_connection_type: 'linear'
train_shift_scale: False # customizable (True | False)
train_denominator: False # customizable (True | False)
train: # Customizable
random_seed: 1
is_train_stress: True
epoch: 100
loss: 'Huber' # keeping original loss function give better ft result
loss_param:
delta: 0.01
optimizer: 'adam'
optim_param:
lr: 0.004
scheduler: 'exponentiallr'
scheduler_param:
gamma: 0.99
force_loss_weight: 1.0
stress_loss_weight: 0.01
per_epoch: 10 # Generate checkpoints every this epoch
# ['target y', 'metric']
# Target y: TotalEnergy, Energy, Force, Stress, Stress_GPa, TotalLoss
# Metric : RMSE, MAE, or Loss
error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['TotalLoss', 'None']
continue:
reset_optimizer: True
reset_scheduler: True
reset_epoch: True
checkpoint: 'SevenNet-0_11July2024'
data: # Customizable
batch_size: 4
data_divide_ratio: 0.1
# SevenNet automatically matches data format from its filename.
# For those not `structure_list` or `.pt` files, assumes it is ASE readable
# In this case, below arguments are directly passed to `ase.io.read`
data_format_args:
index: ':' # see `https://wiki.fysik.dtu.dk/ase/ase/io/io.html` for more valid arguments
# validset is needed if you want '_best.pth' during training. If not, both validset and testset is optional.
load_trainset_path: ['./train_*.extxyz'] # Example of using ase as data_format, support multiple files and expansion(*)
load_validset_path: ['./valid.extxyz']
load_testset_path: ['./sevenn_data/mydata.pt'] # Graph can be preprocessed using `sevenn_graph_build` and accessible like this
# Application of 7net-0 on liquid electrolyte system via fine-tuning
# Paper: https://arxiv.org/abs/2501.05211
model:
# parameters of SevenNet-0, should not be changed
chemical_species: 'auto'
cutoff: 5.0
channel: 128
is_parity: False
lmax: 2
num_convolution_layer: 5
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e"
weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 4.5
act_gate: {'e': 'silu', 'o': 'tanh'}
act_scalar: {'e': 'silu', 'o': 'tanh'}
self_connection_type: 'linear'
# useful for fine-tuning
train_shift_scale: True
train_avg_num_neigh: True
train:
random_seed: 1
is_train_stress: True
epoch: 100 # we went through 100 epochs and chose checkpoint at 50 epoch where the error have reached plateau.
loss: 'Huber'
loss_param:
delta: 0.01
optimizer: 'adam'
optim_param:
lr: 0.0001
scheduler: 'linearlr'
scheduler_param:
start_factor: 1.0
total_iters: 600
end_factor: 0.000001
force_loss_weight: 1.00
stress_loss_weight: 1.00 # 7net-0 quantitatively lacked accuracy on pressure histograms compared to DFT, so we increased stress loss weight
error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['Energy', 'MAE']
- ['Force', 'MAE']
- ['Stress', 'MAE']
- ['Energy', 'Loss']
- ['Force', 'Loss']
- ['Stress', 'Loss']
- ['TotalLoss', 'None']
  per_epoch: 10  # Generate a checkpoint every this number of epochs
continue:
use_statistic_values_of_checkpoint: True
checkpoint: '7net-0' # fine-tuning from 7net-0
reset_optimizer: True
reset_scheduler: True
data:
batch_size: 1 # our fine-tuning dataset had ~360 atoms per structure, so we used batch size of 1 to avoid GPU OOM error.
shift: 'elemwise_reference_energies'
scale: 1.858
data_format: 'ase'
data_divide_ratio: 0.05
load_dataset_path: ["./data/total.extxyz"]
model:
chemical_species: 'univ' # Ready for 119 elements
cutoff: 6.0
channel: 128
is_parity: False
lmax: 3
num_convolution_layer: 3
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e+16x3e"
- "128x0e+64x1e+32x2e+16x3e"
- "128x0e"
weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 5.5
act_gate: {'e': 'silu', 'o': 'tanh'}
act_scalar: {'e': 'silu', 'o': 'tanh'}
conv_denominator: 'avg_num_neigh'
train_shift_scale: True
train_denominator: False
self_connection_type: 'linear'
# Following are used to specify which part of the model would utilize fidelity-dependent parameters for multi-fidelity training.
# For detailed architecture, please refer to https://arxiv.org/abs/2409.07947
# Parts using fidelity-dependent weights are indicated as `Modified linear` layers in Figure 1.
use_modal_node_embedding: False # If true, use modified linear layer in atom-type embedding layer.
use_modal_self_inter_intro: True # If true, use modified linear layers in self-interaction block before the convolution in the interaction blocks.
use_modal_self_inter_outro: True # If true, use modified linear layers in self-interaction block after the convolution in the interaction blocks.
use_modal_output_block: True # If true, use modified linear layer in the output block.
train:
train_shuffle: True
random_seed: 777
is_train_stress : True
epoch: 200
loss: 'Huber'
loss_param:
delta: 0.01
optimizer: 'adam'
optim_param:
lr: 0.01
scheduler: 'linearlr'
scheduler_param:
start_factor: 1.0
total_iters: 200
end_factor: 0.0001
force_loss_weight : 1.00
stress_loss_weight: 0.01
error_record:
- ['Energy', 'MAE']
- ['Force', 'MAE']
- ['Stress', 'MAE']
- ['Energy', 'Loss']
- ['Force', 'Loss']
- ['Stress', 'Loss']
- ['TotalLoss', 'None']
per_epoch: 10
use_modality: True
use_weight: True
data:
batch_size: 16
shift: 'elemwise_reference_energies'
scale: 'force_rms'
use_modal_wise_shift: True # If true, use different atomic energy shift for each database
use_modal_wise_scale: False # If true, use different atomic energy scale for each database
load_trainset_path:
- data_modality: pbe # Name of database
file_list:
- file: "path to pbe dataset" # ASE readable or .pt file (graph.pt)
data_weight:
energy: 1.0
force: 0.1 # This weight would be additionally multiplied to `force_loss_weight` for this database
stress: 1.0 # This weight would be additionally multiplied to `stress_loss_weight` for this database
- data_modality: scan
file_list:
- file: "path to scan dataset"
data_weight:
energy: 1.0
force: 10.0
stress: 1.0
load_pbe_validset_path: # any name starts with 'load' and ends with 'set_path'
- data_modality: pbe # modality must be given for mm valid set
file_list:
- file: "path to pbe validset"
load_scan_validset_path:
- data_modality: scan
file_list:
- file: "path to scan validset"
# SevenNet-0, should be run with `sevenn -m train_v1` as it uses old routine
model:
chemical_species: 'auto'
cutoff: 5.0
channel: 128
is_parity: False
lmax: 2
num_convolution_layer: 5
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e"
weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 4.5
act_gate: {'e': 'silu', 'o': 'tanh'}
act_scalar: {'e': 'silu', 'o': 'tanh'}
conv_denominator: 'avg_num_neigh'
train_shift_scale: False
train_denominator: False
self_connection_type: 'linear'
train:
train_shuffle: False
random_seed: 1
is_train_stress : True
epoch: 600
loss: 'Huber'
loss_param:
delta: 0.01
optimizer: 'adam'
optim_param:
lr: 0.01
scheduler: 'linearlr'
scheduler_param:
start_factor: 1.0
total_iters: 600
end_factor: 0.0001
force_loss_weight : 1.00
stress_loss_weight: 0.01
error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['Energy', 'MAE']
- ['Force', 'MAE']
- ['Stress', 'MAE']
- ['Energy', 'Loss']
- ['Force', 'Loss']
- ['Stress', 'Loss']
- ['TotalLoss', 'None']
per_epoch: 10
# continue:
# checkpoint: './checkpoint_last.pth'
# reset_optimizer: False
# reset_scheduler: False
data:
batch_size: 128 # per GPU batch size, as the model trained with 32 GPUs, the effective batch size equals 4096.
scale: 'per_atom_energy_std'
shift: 'elemwise_reference_energies'
data_format: 'ase'
save_by_train_valid: False
load_dataset_path: ["path_to_MPtrj_total.sevenn_data"]
load_validset_path: ["validaset.sevenn_data"]
model:
chemical_species: auto
cutoff: 5.0
irreps_manual:
- 128x0e
- 128x0e+64x1e+32x2e+32x3e
- 128x0e+64x1e+32x2e+32x3e
- 128x0e+64x1e+32x2e+32x3e
- 128x0e+64x1e+32x2e+32x3e
- 128x0e
channel: 128
lmax: 3
num_convolution_layer: 5
is_parity: false
radial_basis:
radial_basis_name: bessel
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: poly_cut
poly_cut_p_value: 6
act_radial: silu
weight_nn_hidden_neurons:
- 64
- 64
act_scalar:
e: silu
o: tanh
act_gate:
e: silu
o: tanh
train_denominator: false
train_shift_scale: false
use_bias_in_linear: false
readout_as_fcn: false
self_connection_type: linear
interaction_type: nequip
train:
random_seed: 1
epoch: 600
loss: Huber
loss_param:
delta: 0.01
optimizer: adam
optim_param:
lr: 0.01
scheduler: linearlr
scheduler_param:
start_factor: 1.0
total_iters: 600
end_factor: 0.0001
force_loss_weight: 1.0
stress_loss_weight: 0.01
per_epoch: 10
is_train_stress: true
train_shuffle: true
error_record:
- - Energy
- MAE
- - Energy
- RMSE
- - Force
- MAE
- - Force
- RMSE
- - Stress
- MAE
- - Stress
- RMSE
- - Energy
- Loss
- - Force
- Loss
- - Stress
- Loss
- - TotalLoss
- None
best_metric: TotalLoss
data:
data_format: ase
data_format_args: {}
batch_size: 1024 # global batch size, should be divided by the number of GPUs
load_trainset_path: '**path_to_trainset**'
load_validset_path: '**path_to_validset**'
shift: 'elemwise_reference_energies'
scale: 'force_rms'
"""
Debt
keep old pre-trained checkpoints unchanged.
"""
import copy
import torch
import sevenn._keys as KEY
def version_tuple(v1):
    """Convert a dotted version string (e.g. '0.9.3') to a tuple of ints."""
    return tuple(int(part) for part in v1.split('.'))
def patch_old_config(config):
    """Upgrade, in place, a checkpoint config written by an older release.

    For pre-0.10 configs, drops/renames keys that no longer exist; for all
    versions, backfills keys introduced later with safe defaults.

    Raises:
        ValueError: when the config has no version, or uses the removed
            ``optimize_by_reduce: False`` mode.
    """
    version = config.get('version', None)
    if not version:
        raise ValueError('No version found in config')
    # Unpack exactly three components (same strictness as before)
    major, minor, _ = version.split('.')[:3]
    major, minor = int(major), int(minor)

    if major == 0 and minor <= 9:
        # XPLOR cutoff never used poly_cut_p_value; old configs carried it
        cutoff_fn = config[KEY.CUTOFF_FUNCTION]
        if cutoff_fn[KEY.CUTOFF_FUNCTION_NAME] == 'XPLOR':
            cutoff_fn.pop('poly_cut_p_value', None)
        # 'train_avg_num_neigh' was renamed to the denominator key
        if KEY.TRAIN_DENOMINTAOR not in config:
            config[KEY.TRAIN_DENOMINTAOR] = config.pop(
                'train_avg_num_neigh', False
            )
        if config.pop('optimize_by_reduce', None) is False:
            raise ValueError(
                'This checkpoint(optimize_by_reduce: False) is no longer supported'
            )

    # Keys added after some checkpoints were written: fill defaults
    config.setdefault(KEY.CONV_DENOMINATOR, 0.0)
    config.setdefault(KEY._NORMALIZE_SPH, False)
    return config
def map_old_model(old_model_state_dict):
    """
    For compatibility with old namings (before 'correct' branch merged 2404XX)
    Map old model's module names to new model's module names
    """
    # Fixed renames plus per-layer renames for up to 10 layers
    renames = {
        'EdgeEmbedding': 'edge_embedding',
        'reducing nn input to hidden': 'reduce_input_to_hidden',
        'reducing nn hidden to energy': 'reduce_hidden_to_energy',
        'rescale atomic energy': 'rescale_atomic_energy',
    }
    for layer in range(10):
        renames[f'{layer} self connection intro'] = (
            f'{layer}_self_connection_intro'
        )
        renames[f'{layer} convolution'] = f'{layer}_convolution'
        renames[f'{layer} self interaction 2'] = (
            f'{layer}_self_interaction_2'
        )
        renames[f'{layer} equivariant gate'] = f'{layer}_equivariant_gate'

    remapped = {}
    for full_key, value in old_model_state_dict.items():
        head, _, tail = full_key.partition('.')
        # 'denumerator' was a misspelling fixed in newer releases
        tail = tail.replace('denumerator', 'denominator')
        if head in renames:
            remapped[f'{renames[head]}.{tail}'] = value
        else:
            # Unknown top-level module: keep the key exactly as-is
            remapped[full_key] = value
    return remapped
def sort_old_convolution(model_now, state_dict):
    # Local import: e3nn is only needed for this compatibility path.
    from e3nn.o3 import wigner_3j
    """
    Reason1: we have to sort instructions of convolution to be compatible with
    cuEquivariance. (therefore, sort weight)
    Reason2: some of old convolution module's w3j coeff has flipped sign. This also
    has to be fixed to be compatible with cuEquivarinace.
    """
    def patch(stct):
        # NOTE(review): relies on `conv`, `conv_key`, and `weight_nn` as
        # closure variables assigned by the loop at the bottom of this
        # function, so `patch` must only be called after they are set.
        inst_old = copy.copy(conv._instructions_before_sort)
        # keep only the (in1, in2, out) index triplets of each instruction
        inst_old = [(inst[0], inst[1], inst[2]) for inst in inst_old]
        del conv._instructions_before_sort
        conv_args = conv.convolution_kwargs
        irreps_in1 = conv_args['irreps_in1']
        irreps_in2 = conv_args['irreps_in2']
        irreps_out = conv_args.get('irreps_out', conv_args.get('filter_irreps_out'))
        # sort instructions by output irrep index (target ordering)
        inst_sorted = sorted(inst_old, key=lambda x: x[2])
        inst_sorted = [
            # in1, in2, out, weights
            (inst[0], inst[1], inst[2], irreps_in1[inst[0]].mul)
            for inst in inst_sorted
        ]
        # last linear layer of the weight nn holds the per-instruction
        # weight columns, concatenated in the OLD instruction order
        n = len(weight_nn.hs) - 2
        ww_key = f'{conv_key}.weight_nn.layer{n}.weight'
        ww = stct[ww_key]
        ww_sorted = [None] * len(inst_old)
        _prev_idx = 0
        # walk old order, find each instruction's position in the sorted
        # order, and move its weight column block there
        for ist_src in inst_old:
            for j, ist_dst in enumerate(inst_sorted):
                if not all(ist_src[ii] == ist_dst[ii] for ii in range(3)):
                    continue
                numel = ist_dst[3] # weight num
                ww_src = ww[:, _prev_idx : _prev_idx + numel]
                l1, l2, l3 = (
                    irreps_in1[ist_src[0]].ir.l,
                    irreps_in2[ist_src[1]].ir.l,
                    irreps_out[ist_src[2]].ir.l,
                )
                if l1 > 0 and l2 > 0 and l3 > 0:
                    # compare the stored Wigner-3j buffer with a freshly
                    # computed one; if the sign is flipped, flip both the
                    # buffer and the corresponding weight block so the
                    # product stays unchanged
                    w3j_key = f'_w3j_{l1}_{l2}_{l3}'
                    conv_w3j_key = (
                        f'{conv_key}.convolution._compiled_main_left_right.{w3j_key}'
                    )
                    w3j_old = stct[conv_w3j_key]
                    w3j_now = wigner_3j(l1, l2, l3)
                    if not torch.allclose(w3j_old.to(w3j_now.device), w3j_now):
                        # only an exact sign flip is expected here
                        assert torch.allclose(
                            w3j_old.to(w3j_now.device), -1 * w3j_now
                        )
                        ww_src = -1 * ww_src
                        stct[conv_w3j_key] *= -1 # stct updated
                _prev_idx += numel
                ww_sorted[j] = ww_src
        ww_sorted = torch.cat(ww_sorted, dim=1)  # type: ignore
        stct[ww_key] = ww_sorted.clone() # stct updated
    # group state-dict entries by their convolution module name
    # NOTE(review): assumes every top-level module name contains '_'
    # (e.g. '0_convolution'); a key without one would raise IndexError —
    # confirm state dicts reaching here are already remapped.
    conv_dicts = {}
    for k, v in state_dict.items():
        key_name = k.split('.')[0]
        if key_name.split('_')[1] == 'convolution':
            if key_name not in conv_dicts:
                conv_dicts[key_name] = {}
            conv_dicts[key_name].update({k: v})
    # start from a shallow copy, then overwrite patched convolution entries
    new_state_dict = {}
    new_state_dict.update(state_dict)
    for conv_key, conv_state_dict in conv_dicts.items():
        conv = model_now._modules[conv_key]
        weight_nn = conv.weight_nn
        patch(conv_state_dict)
        new_state_dict.update(conv_state_dict)
    return new_state_dict
def patch_state_dict_if_old(state_dict, config_cp, now_model):
    """
    Patch a checkpoint state dict written by an older sevenn version so it
    loads into the current model layout.
    """
    version = config_cp.get('version', None)
    if not version:
        raise ValueError('No version found in config')
    parts = version.split('.')
    vsuffix = ''
    if len(parts) == 4:
        # e.g. '0.11.0.dev0' -> base version + dev suffix
        vsuffix = parts[-1]
        parts = parts[:3]
    vs = version_tuple('.'.join(parts))
    if vs < version_tuple('0.10.0'):
        # module naming scheme changed at 0.10.0
        state_dict = map_old_model(state_dict)
    # TODO: change version criteria before release!!!
    # it causes problem if model is sorted but this function is called
    # ... more robust way? idk
    needs_conv_sort = vs < version_tuple('0.11.0') or (
        vs == version_tuple('0.11.0') and vsuffix == 'dev0'
    )
    if needs_conv_sort:
        state_dict = sort_old_convolution(now_model, state_dict)
    return state_dict
import math
from typing import List
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, Linear
import sevenn._keys as KEY
from sevenn.model_build import build_E3_equivariant_model
# Maps each "use modal ..." config flag to the (sub)string of the module
# name whose linear layer receives the modal one-hot input. Used below to
# locate which linear layers carry modal parameters.
modal_module_dict = {
    KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x',
    KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1',
    KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2',
    KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden',
}
def _get_scalar_index(irreps: Irreps):
    """
    Return indices of scalar (0e) entries in `irreps`.

    Scalar (l == 0, even parity) slots are the ones used for modality,
    so these are the positions touched by modal-parameter surgery.
    """
    # each irreps entry unpacks as (multiplicity, (l, parity))
    return [
        pos
        for pos, (_, (l, p)) in enumerate(irreps)  # noqa
        if l == 0 and p == 1
    ]
def _reshape_weight_of_linear(
    irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor
) -> List[torch.Tensor]:
    """
    Split a flat linear weight tensor into per-instruction weight views.

    Loads `weight` into a fresh e3nn `Linear(irreps_in, irreps_out)` and
    returns its `weight_views()` as a list. Callers rely on e3nn's view
    ordering (e.g. the modal scalar block is taken as the last view).
    """
    linear = Linear(irreps_in, irreps_out)
    linear.weight = nn.Parameter(weight)
    return list(linear.weight_views())
def _erase_linear_modal_params(
    model_state_dct: dict,
    erase_modal_indices: List[int],
    key: str,
    irreps_in: Irreps,
    irreps_out: Irreps,
):
    """
    Remove modal scalar inputs from the linear layer stored at
    `key + '.linear.weight'` and return the flattened, rescaled weight.

    Scalar (0e) weights are multiplied by sqrt(new_input_dim/orig_input_dim)
    to keep e3nn's `element` normalization consistent after the input shrinks.
    The last weight view is skipped; it is treated elsewhere as the modal
    block (see `_get_modal_weight_as_bias`).
    """
    orig_input_dim = irreps_in.count('0e')
    new_input_dim = orig_input_dim - len(erase_modal_indices)
    orig_weight = model_state_dct[key + '.linear.weight']
    scalar_idx = _get_scalar_index(irreps_in)
    linear_weight_list = _reshape_weight_of_linear(
        irreps_in, irreps_out, orig_weight
    )
    new_weight_list = []
    # [:-1]: drop the last weight view (the modal-input weights)
    for idx, l_p_weight in enumerate(linear_weight_list[:-1]):
        new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
        if idx in scalar_idx:
            # renormalize scalar weights for the reduced input dimension
            new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim)
        new_weight_list.append(new_weight)
    """
    Following works for normalization = `path`, which is not used in SEVENNet
    for l_p_weight in linear_weight_list[:-1]:
        new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze())
    """
    flattened_weight = torch.cat(new_weight_list)
    return flattened_weight
def _get_modal_weight_as_bias(
    model_state_dct: dict,
    key: str,
    ref_index: int,
    irreps_in: Irreps,
    irreps_out: Irreps,
):
    """
    Fold the modal weight row `ref_index` into the linear layer's bias.

    Returns the original bias (zeros if the layer had none) plus the
    selected modal weight row scaled by 1/sqrt(input_dim), matching the
    e3nn normalization of a scalar input fixed to one.
    """
    assert ref_index != -1
    n_scalar_in = irreps_in.count('0e')
    n_scalar_out = irreps_out.count('0e')
    weight = model_state_dct[key + '.linear.weight']
    bias = model_state_dct[key + '.linear.bias']
    if len(bias) == 0:
        # layer was built without bias: start from zeros
        bias = torch.zeros(n_scalar_out, dtype=weight.dtype)
    # last weight view is the modal block; pick the reference modal's row
    modal_block = _reshape_weight_of_linear(irreps_in, irreps_out, weight)[-1]
    return bias + modal_block[ref_index] / math.sqrt(n_scalar_in)
def _append_modal_weight(
    model_state_dct: dict,  # state dict to be targeted
    key: str,  # linear weight module name
    irreps_in: Irreps,  # irreps_in before modality append
    irreps_out: Irreps,
    append_number: int,
):
    """
    Return the flattened linear weight of `key` with room appended for
    `append_number` new modal scalar inputs (zero-initialized).

    Scalar (0e) weights are rescaled by sqrt(new_input_dim/input_dim) to
    keep e3nn's normalization consistent after growing the input.
    """
    # This works for normalization = `element`, default in SEVENNet.
    # (normalization = `path` is currently deprecated in SEVENNet.)
    input_dim = irreps_in.count('0e')
    output_dim = irreps_out.count('0e')
    new_input_dim = input_dim + append_number
    orig_weight = model_state_dct[key + '.linear.weight']
    scalar_idx = _get_scalar_index(irreps_in)
    linear_weight_list = _reshape_weight_of_linear(
        irreps_in, irreps_out, orig_weight
    )
    # TODO: combine following as function with _erase_linear_modal_params
    new_weight_list = []
    for idx, l_p_weight in enumerate(linear_weight_list):
        # flatten each per-irrep weight view to 1-D
        new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
        if idx in scalar_idx:
            # renormalize scalar weights for the enlarged input dimension
            new_weight = new_weight * math.sqrt(new_input_dim / input_dim)
        new_weight_list.append(new_weight)
    # (a second flatten pass existed here before; it was a no-op since the
    # entries above are already 1-D, so it has been removed)
    flattened_weight = torch.cat(new_weight_list)
    append_weight = torch.cat([
        flattened_weight,
        # zeros: appended modal weights start from the common model
        torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype),
    ])
    return append_weight
def get_single_modal_model_dct(
    model_state_dct: dict,
    config: dict,
    ref_modal: str,
    from_processing_cp: bool = False,
    is_deploy: bool = False,
):
    """
    Convert multimodal model state dictionary to single modal model.
    Modal is selected by `ref_modal`

    `model_state_dct`: model state dictionary from multimodal checkpoint file
    `config`: dictionary containing configuration of the checkpoint model
    `ref_modal`: modal that are going to be converted ('common' for none)
    `from_processing_cp`: if True, use modal_map of the checkpoint file
    `is_deploy`: if True, model is build with single-modal shift and scale
    """
    if (
        not from_processing_cp and not config[KEY.USE_MODALITY]
    ):  # model is already single modal
        return model_state_dct
    # biases are needed to absorb the modal weight contribution
    config[KEY.USE_BIAS_IN_LINEAR] = True
    config['_deploy'] = is_deploy
    model = build_E3_equivariant_model(config)
    del config['_deploy']
    key_add = '_cp' if from_processing_cp else ''
    modal_type_dict = config[KEY.MODAL_MAP + key_add]
    erase_modal_indices = range(len(modal_type_dict.keys()))  # starts with 0
    ref_modal_index = None  # stays None when ref_modal == 'common'
    if ref_modal != 'common':
        try:
            ref_modal_index = modal_type_dict[ref_modal]
        except KeyError:  # was a bare except; only KeyError can occur here
            raise KeyError(
                f'{ref_modal} not in modal type. Use one of'
                f' {modal_type_dict.keys()}.'
            )
    for module_key in model._modules.keys():
        for (
            use_modal_module_key,
            modal_module_name,
        ) in modal_module_dict.items():
            irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out'))
            # TODO: directly using "irreps_in" might not be compatible
            # when changing `nn/linear.py`
            output_dim = irreps_out.count('0e')
            if (
                config[use_modal_module_key]
                and modal_module_name in module_key
            ):  # this module is used for giving modality
                irreps_in = Irreps(
                    model.get_irreps_in(module_key, 'irreps_in')
                )
                # 'common': modal contribution is zero; otherwise fold the
                # selected modal's weight row into the bias
                new_bias = (
                    torch.zeros(output_dim)
                    if ref_modal == 'common'
                    else _get_modal_weight_as_bias(
                        model_state_dct,
                        module_key,
                        ref_modal_index,
                        irreps_in,  # type: ignore
                        irreps_out,  # type: ignore
                    )
                )
                erased_modal_weight = _erase_linear_modal_params(
                    model_state_dct,
                    erase_modal_indices,
                    module_key,
                    irreps_in,  # type: ignore
                    irreps_out,  # type: ignore
                )
                model_state_dct[module_key + '.linear.weight'] = (
                    erased_modal_weight
                )
                model_state_dct[module_key + '.linear.bias'] = new_bias
            elif modal_module_name in module_key:
                # bias introduced by USE_BIAS_IN_LINEAR above: zero-init
                model_state_dct[module_key + '.linear.bias'] = torch.zeros(
                    output_dim,
                    dtype=model_state_dct[module_key + '.linear.weight'].dtype,
                )
    final_block_key = 'reduce_hidden_to_energy'
    model_state_dct[final_block_key + '.linear.bias'] = torch.tensor(
        [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype
    )
    # BUG FIX: the second operand previously re-checked USE_MODAL_WISE_SHIFT
    # instead of USE_MODAL_WISE_SCALE, so scale-only models were skipped.
    if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE]:
        rescaler_names = []
        if config[KEY.USE_MODAL_WISE_SHIFT]:
            rescaler_names.append('shift')
        if config[KEY.USE_MODAL_WISE_SCALE]:
            rescaler_names.append('scale')
        config[KEY.USE_MODAL_WISE_SHIFT] = False
        config[KEY.USE_MODAL_WISE_SCALE] = False
        for rescaler_name in rescaler_names:
            rescaler_key = 'rescale_atomic_energy.' + rescaler_name
            # NOTE(review): ref_modal_index is None when ref_modal=='common',
            # which makes this indexing fail (original code raised NameError
            # in the same situation). Confirm whether 'common' is a valid
            # choice together with modal-wise shift/scale.
            rescaler = model_state_dct[rescaler_key][ref_modal_index]
            model_state_dct.update({rescaler_key: rescaler})
            config.update({rescaler_name: rescaler})
    config[KEY.USE_MODALITY] = False
    return model_state_dct
def append_modality_to_model_dct(
    model_state_dct: dict,
    config: dict,
    orig_num_modal: int,
    append_modal_length: int,
):
    """
    Append modal-wise parameters to the original linear layers.
    This enables expanding modal to single/multi modal model checkpoint.

    `model_state_dct`: model state dictionary from multimodal checkpoint file
    `config`: dictionary containing configuration of the checkpoint model
        + modality appended
    `orig_num_modal`: Number of modality used in original checkpoint
    `append_modal_length`: Number of modality to be appended in new checkpoint.
    """
    # temporarily rebuild the model with the checkpoint's modal count so the
    # irreps match the stored weights; NUM_MODALITIES is restored at the end
    config_num_modal = config[KEY.NUM_MODALITIES]
    config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True})
    model = build_E3_equivariant_model(config)
    for module_key in model._modules.keys():
        for (
            use_modal_module_key,
            modal_module_name,
        ) in modal_module_dict.items():
            if (
                config[use_modal_module_key]
                and modal_module_name in module_key
            ):  # this module is used for giving modality
                irreps_in = model.get_irreps_in(
                    module_key, 'irreps_in'
                )
                # TODO: directly using "irreps_in" might not be compatible
                # when changing `nn/linear.py`
                irreps_out = model.get_irreps_in(module_key, 'irreps_out')
                irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out)
                append_weight = _append_modal_weight(
                    model_state_dct,
                    module_key,
                    irreps_in,  # type: ignore
                    irreps_out,  # type: ignore
                    append_modal_length,
                )
                model_state_dct[module_key + '.linear.weight'] = append_weight
    config[KEY.NUM_MODALITIES] = config_num_modal
    return model_state_dct
import os
from datetime import datetime
from typing import Optional
import e3nn.util.jit
import torch
import torch.nn
from ase.data import chemical_symbols
import sevenn._keys as KEY
from sevenn import __version__
from sevenn.model_build import build_E3_equivariant_model
from sevenn.util import load_checkpoint
def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None):
    """
    Deploy a serial TorchScript model for LAMMPS (pair_e3gnn).

    This method is messy to avoid changes in pair_e3gnn.cpp, while
    refactoring python part.
    If changes the behavior, and accordingly pair_e3gnn.cpp,
    we have to recompile LAMMPS (which I always want to procrastinate)

    `checkpoint`: checkpoint to load (passed to load_checkpoint)
    `fname`: output file name; '.pt' is appended if missing
    `modal`: modal to deploy; required if the model is multimodal
    """
    from sevenn.nn.edge_embedding import EdgePreprocess
    from sevenn.nn.force_output import ForceStressOutput

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config

    # fold edge preprocessing into the model and compute force/stress
    # through the autograd output module
    model.prepand_module('edge_preprocess', EdgePreprocess(True))
    grad_module = ForceStressOutput()
    model.replace_module('force_output', grad_module)
    model.key_grad = grad_module.get_grad_key()

    if hasattr(model, 'eval_type_map'):
        setattr(model, 'eval_type_map', False)

    if modal:
        model.prepare_modal_deploy(modal)
    elif model.modal_map is not None and len(model.modal_map) >= 1:
        raise ValueError(
            f'Modal is not given. It has: {list(model.modal_map.keys())}'
        )

    model.set_is_batch_data(False)
    model.eval()

    model = e3nn.util.jit.script(model)
    model = torch.jit.freeze(model)

    # make some config need for md
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]
    # BUG FIX: previously `chem_list.strip()` discarded its return value,
    # leaving a trailing space; build the list with join instead.
    chem_list = ' '.join(chemical_symbols[Z] for Z in type_map)
    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    if not fname.endswith('.pt'):
        fname += '.pt'
    torch.jit.save(model, fname, _extra_files=md_configs)
# TODO: build model only once
def deploy_parallel(
    checkpoint, fname='deployed_parallel', modal: Optional[str] = None
):
    """
    Deploy TorchScript model parts for parallel LAMMPS MD.

    `checkpoint`: checkpoint to load (passed to load_checkpoint)
    `fname`: output directory; one `deployed_parallel_{i}.pt` per part.
        Raises if the directory already exists.
    `modal`: modal to deploy; required if the model is multimodal
    """
    # Additional layer for ghost atom (and copy parameters from original)
    GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1']

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config
    # the parallel deploy path is built without cuEquivariance kernels
    config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False}
    model_state_dct = model.state_dict()
    model_list = build_E3_equivariant_model(config, parallel=True)

    # duplicate ghost-atom layers' parameters under a 'ghost_' prefix
    dct_temp = {}
    copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS}
    for ghost_layer_key in GHOST_LAYERS_KEYS:
        for key, val in model_state_dct.items():
            if not key.startswith(ghost_layer_key):
                continue
            dct_temp.update({f'ghost_{key}': val})
            copy_counter[ghost_layer_key] += 1
    # Ensure reference weights are copied from state dict
    assert all(x > 0 for x in copy_counter.values())
    model_state_dct.update(dct_temp)

    for model_part in model_list:
        missing, _ = model_part.load_state_dict(model_state_dct, strict=False)
        if hasattr(model_part, 'eval_type_map'):
            setattr(model_part, 'eval_type_map', False)
        # Ensure all values are inserted
        assert len(missing) == 0, missing

    if modal:
        model_list[0].prepare_modal_deploy(modal)
    elif model_list[0].modal_map is not None:
        raise ValueError(
            f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}'
        )

    # prepare some extra information for MD
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]
    # BUG FIX: previously `chem_list.strip()` discarded its return value,
    # leaving a trailing space; build the list with join instead.
    chem_list = ' '.join(chemical_symbols[Z] for Z in type_map)
    comm_size = max(
        [
            seg._modules[f'{t}_convolution']._comm_size  # type: ignore
            for t, seg in enumerate(model_list)
        ]
    )
    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update({'comm_size': str(comm_size)})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    os.makedirs(fname)  # raises if the output directory already exists
    for idx, model in enumerate(model_list):
        fname_full = f'{fname}/deployed_parallel_{idx}.pt'
        model.set_is_batch_data(False)
        model.eval()
        model = e3nn.util.jit.script(model)
        model = torch.jit.freeze(model)
        torch.jit.save(model, fname_full, _extra_files=md_configs)
import os
from typing import List, Optional
from sevenn.logger import Logger
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import unique_filepath
def build_sevennet_graph_dataset(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    filename: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    """
    Build a SevenNetGraphDataset from source files and log its statistics.

    `source`: list of input files to read structures from
    `cutoff`: neighbor cutoff radius used for graph building
    `num_cores`: number of processes used while building graphs
    `out`: dataset root directory passed to SevenNetGraphDataset
    `filename`: name of the processed dataset file
    `metadata`: extra key-value pairs written to the log
    `fmt_kwargs`: forwarded to SevenNetGraphDataset (reader options)
    """
    from sevenn.train.graph_dataset import SevenNetGraphDataset
    log = Logger()
    if metadata is None:
        metadata = {}
    log.timer_start('graph_build')
    db = SevenNetGraphDataset(
        cutoff=cutoff,
        root=out,
        files=source,
        processed_name=filename,
        process_num_cores=num_cores,
        **fmt_kwargs,
    )
    log.timer_end('graph_build', 'graph build time')
    log.writeline(f'Graph saved: {db.processed_paths[0]}')
    log.bar()
    for k, v in metadata.items():
        log.format_k_v(k, v, write=True)
    log.bar()
    # summary statistics of the freshly built dataset
    log.writeline('Distribution:')
    log.statistic_write(db.statistics)
    log.format_k_v('# atoms (node)', db.natoms, write=True)
    log.format_k_v('# structures (graph)', len(db), write=True)
def dataset_finalize(dataset, metadata, out):
    """
    Deprecated
    """
    # enrich metadata with dataset-derived values and attach it
    meta = dict(metadata)
    meta['natoms'] = dataset.get_natoms()
    meta['species'] = dataset.get_species()
    dataset.meta = meta
    # resolve the output path to a unique '.sevenn_data' file
    if os.path.isdir(out):
        out = os.path.join(out, 'graph_built.sevenn_data')
    elif not out.endswith('.sevenn_data'):
        out += '.sevenn_data'
    out = unique_filepath(out)
    log = Logger()
    log.writeline('The metadata of the dataset is...')
    for meta_key, meta_val in meta.items():
        log.format_k_v(meta_key, meta_val, write=True)
    dataset.save(out)
    log.writeline(f'dataset is saved to {out}')
    return dataset
def build_script(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    """
    Deprecated

    Build a legacy AtomGraphDataset from `source` files and save it via
    dataset_finalize.
    """
    from sevenn.train.dataload import file_to_dataset, match_reader
    if metadata is None:
        metadata = {}
    log = Logger()
    dataset = AtomGraphDataset({}, cutoff)
    common_args = {
        'cutoff': cutoff,
        'cores': num_cores,
        'label': 'graph_build',
    }
    log.timer_start('graph_build')
    for path in source:
        if os.path.isdir(path):
            # directories are skipped; only plain files are read
            continue
        log.writeline(f'Read: {path}')
        basename = os.path.basename(path)
        # pick the reader by file-name convention
        if 'structure_list' in basename:
            fmt = 'structure_list'
        else:
            fmt = 'ase'
        reader, rmeta = match_reader(fmt, **fmt_kwargs)
        metadata.update(**rmeta)
        dataset.augment(
            file_to_dataset(
                file=path,
                reader=reader,
                **common_args,
            )
        )
    log.timer_end('graph_build', 'graph build time')
    dataset_finalize(dataset, metadata, out)
import csv
import os
from typing import Iterable, List, Optional, Union
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.graph_dataset import SevenNetGraphDataset
from sevenn.train.modal_dataset import SevenNetMultiModalDataset
def write_inference_csv(output_list, out):
    """
    Write inference results as csv files under directory `out`:
    info.csv (per-structure metadata, best effort), per_graph.csv
    (per-structure values) and per_atom.csv (per-atom values).

    Note: entries of `output_list` are converted to numpy dicts in-place.
    """
    for i, output in enumerate(output_list):
        output = output.fit_dimension()
        # stress unit conversion: eV/A^3 -> kbar (1 eV/A^3 = 1602.1766208 kbar)
        output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208
        output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208
        output_list[i] = output.to_numpy_dict()
    per_graph_keys = [
        KEY.NUM_ATOMS,
        KEY.USER_LABEL,
        KEY.ENERGY,
        KEY.PRED_TOTAL_ENERGY,
        KEY.STRESS,
        KEY.PRED_STRESS,
    ]
    per_atom_keys = [
        KEY.ATOMIC_NUMBERS,
        KEY.ATOMIC_ENERGY,
        KEY.POS,
        KEY.FORCE,
        KEY.PRED_FORCE,
    ]
    def unfold_dct_val(dct, keys, suffix_list=None):
        # Flatten array-valued entries into scalar columns,
        # e.g. force -> force_x, force_y, force_z; missing keys become '-'.
        res = {}
        if suffix_list is None:
            suffix_list = range(100)
        for k in keys:
            if k not in dct:
                res[k] = '-'
            elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0:
                res.update(
                    {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])}
                )
            else:
                res[k] = dct[k]
        return res
    def per_atom_dct_list(dct, keys):
        # Build one flat (x/y/z-suffixed) dict per atom from per-atom arrays.
        sfx_list = ['x', 'y', 'z']
        res = []
        natoms = dct[KEY.NUM_ATOMS]
        extracted = {k: dct[k] for k in keys}
        for i in range(natoms):
            raw = {}
            raw.update({k: v[i] for k, v in extracted.items()})
            per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list)
            res.append(per_atom_dct)
        return res
    try:
        with open(f'{out}/info.csv', 'w', newline='') as f:
            header = output_list[0][KEY.INFO].keys()
            writer = csv.DictWriter(f, fieldnames=header)
            writer.writeheader()
            for output in output_list:
                writer.writerow(output[KEY.INFO])
    except (KeyError, TypeError, AttributeError, csv.Error) as e:
        # info.csv is best effort; per_graph/per_atom are still written below
        print(e)
        print('failed to write meta data, info.csv is not written')
    with open(f'{out}/per_graph.csv', 'w', newline='') as f:
        sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx']  # for stress
        writer = None
        for output in output_list:
            cell_dct = {KEY.CELL: output[KEY.CELL]}
            cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c'])
            data = {
                **unfold_dct_val(output, per_graph_keys, sfx_list),
                **cell_dct,
            }
            if writer is None:
                # header comes from the first structure's keys
                writer = csv.DictWriter(f, fieldnames=data.keys())
                writer.writeheader()
            writer.writerow(data)
    with open(f'{out}/per_atom.csv', 'w', newline='') as f:
        writer = None
        for i, output in enumerate(output_list):
            list_of_dct = per_atom_dct_list(output, per_atom_keys)
            for j, dct in enumerate(list_of_dct):
                # stct_id: structure index, atom_id: atom index within it
                idx_dct = {'stct_id': i, 'atom_id': j}
                data = {**idx_dct, **dct}
                if writer is None:
                    writer = csv.DictWriter(f, fieldnames=data.keys())
                    writer.writeheader()
                writer.writerow(data)
def _patch_data_info(
    graph_list: Iterable[AtomGraphData], full_file_list: List[str]
) -> None:
    """
    Attach the absolute source-file path to each graph's info dict, then
    pad every info dict with empty strings so all graphs share the same
    set of keys (needed for batching).
    """
    all_keys = set()
    for graph, path in zip(graph_list, full_file_list):
        if KEY.INFO not in graph:
            graph[KEY.INFO] = {}
        graph[KEY.INFO].update({'file': os.path.abspath(path)})
        all_keys.update(graph[KEY.INFO].keys())
    # second pass: fill in keys missing on individual graphs
    for graph in graph_list:
        info = graph[KEY.INFO]
        for missing_key in all_keys:
            if missing_key not in info:
                info[missing_key] = ''
def inference(
    checkpoint: str,
    targets: Union[str, List[str]],
    output_dir: str,
    num_workers: int = 1,
    device: str = 'cpu',
    batch_size: int = 4,
    save_graph: bool = False,
    allow_unlabeled: bool = False,
    modal: Optional[str] = None,
    **data_kwargs,
) -> None:
    """
    Inference model on the target dataset, writes
    per_graph, per_atom inference results in csv format
    to the output_dir
    If a given target doesn't have EFS key, it puts dummy
    values.

    Args:
        checkpoint: model checkpoint path,
        targets: path, or list of paths to evaluate. Supports
            ASE readable, sevenn_data/*.pt, .sevenn_data, and
            structure_list
        output_dir: directory to write results
        num_workers: number of workers to build graph
        device: device to evaluate, defaults to 'cpu'
        batch_size: batch size for inference
        save_graph: if True, save preprocessed graph to output dir
        allow_unlabeled: if True, accept structures without EFS labels
        modal: modal name to evaluate with (multimodal models only)
        data_kwargs: keyword arguments used when reading targets,
            for example, given index='-1', only the last snapshot
            will be evaluated if it was ASE readable.

    While this function can handle different types of targets
    at once, it will not work smoothly with data_kwargs
    """
    model, _ = util.model_from_checkpoint(checkpoint)
    cutoff = model.cutoff
    # validate requested modal against the model's modal map up front
    if modal:
        if model.modal_map is None:
            raise ValueError('Modality given, but model has no modal_map')
        if modal not in model.modal_map:
            _modals = list(model.modal_map.keys())
            raise ValueError(f'Unknown modal {modal} (not in {_modals})')
    if isinstance(targets, str):
        targets = [targets]
    full_file_list = []
    if save_graph:
        # build (and persist) a processed graph dataset under output_dir
        dataset = SevenNetGraphDataset(
            cutoff=cutoff,
            root=output_dir,
            files=targets,
            process_num_cores=num_workers,
            processed_name='saved_graph.pt',
            **data_kwargs,
        )
        full_file_list = dataset.full_file_list  # TODO: not used currently
    else:
        # build graphs in memory only
        dataset = []
        for file in targets:
            tmplist = SevenNetGraphDataset.file_to_graph_list(
                file,
                cutoff=cutoff,
                num_cores=num_workers,
                allow_unlabeled=allow_unlabeled,
                **data_kwargs,
            )
            dataset.extend(tmplist)
            full_file_list.extend([os.path.abspath(file)] * len(tmplist))
    if (
        full_file_list is not None
        and len(full_file_list) == len(dataset)
        and not isinstance(dataset, SevenNetGraphDataset)
    ):
        # attach source-file info for csv output (in-memory graphs only)
        _patch_data_info(dataset, full_file_list)  # type: ignore
    if modal:
        dataset = SevenNetMultiModalDataset({modal: dataset})  # type: ignore
    loader = DataLoader(dataset, batch_size, shuffle=False)  # type: ignore
    model.to(device)
    model.set_is_batch_data(True)
    model.eval()
    rec = util.get_error_recorder()
    output_list = []
    for batch in tqdm(loader):
        batch = batch.to(device)
        output = model(batch).detach().cpu()
        rec.update(output)
        output_list.extend(util.to_atom_graph_list(output))
    errors = rec.epoch_forward()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # write aggregated error metrics, then per-graph/per-atom csv files
    with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f:
        for key, val in errors.items():
            f.write(f'{key}: {val}\n')
    write_inference_csv(output_list, output_dir)
import os
import warnings
import torch
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.logger import Logger
from sevenn.scripts.convert_model_modality import (
append_modality_to_model_dct,
get_single_modal_model_dct,
)
def processing_continue_v2(config):  # simpler
    """
    Replacement of processing_continue,
    Skips model compatibility

    Returns ([model_sd, optimizer_sd, scheduler_sd], start_epoch).
    """
    log = Logger()
    continue_dct = config[KEY.CONTINUE]
    log.write('\nContinue found, loading checkpoint\n')
    checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT])
    model_cp = checkpoint.build_model()
    config_cp = checkpoint.config
    model_state_dict_cp = model_cp.state_dict()
    # optimizer/scheduler states are dropped when the user asked for a reset
    optimizer_state_dict_cp = (
        checkpoint.optimizer_state_dict
        if not continue_dct[KEY.RESET_OPTIMIZER]
        else None
    )
    scheduler_state_dict_cp = (
        checkpoint.scheduler_state_dict
        if not continue_dct[KEY.RESET_SCHEDULER]
        else None
    )
    # use_statistic_value_of_checkpoint always True
    # Overwrite config from model state dict, so graph_dataset.from_config
    # will not put statistic values to shift, scale, and conv_denominator
    config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist()
    config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist()
    conv_denom = []
    for i in range(config_cp[KEY.NUM_CONVOLUTION]):
        conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item())
    config[KEY.CONV_DENOMINATOR] = conv_denom
    log.writeline(
        f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are '
        + 'overwritten by model_state_dict of checkpoint'
    )
    # chemistry-related keys must follow the checkpoint model
    chem_keys = [
        KEY.TYPE_MAP,
        KEY.NUM_SPECIES,
        KEY.CHEMICAL_SPECIES,
        KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
    ]
    config.update({k: config_cp[k] for k in chem_keys})
    log.writeline(
        'chemical_species are overwritten by checkpoint. '
        + f'This model knows {config[KEY.NUM_SPECIES]} species'
    )
    if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY):
        raise ValueError('use_modality is not same. Check sevenn_cp')
    modal_map = config_cp.get(KEY.MODAL_MAP, None)  # dict | None
    if modal_map and len(modal_map) > 0:
        modalities = list(modal_map.keys())
        log.writeline(f'Multimodal model found: {modalities}')
        log.writeline('use_modality: True')
        config[KEY.USE_MODALITY] = True
    from_epoch = checkpoint.epoch or 0
    log.writeline(f'Checkpoint previous epoch was: {from_epoch}')
    # start from 1 if epoch reset requested, otherwise continue counting
    epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1
    log.writeline(f'epoch start from {epoch}')
    log.writeline('checkpoint loading successful')
    state_dicts = [
        model_state_dict_cp,
        optimizer_state_dict_cp,
        scheduler_state_dict_cp,
    ]
    return state_dicts, epoch
def check_config_compatible(config, config_cp):
    """
    Check that the current training config is compatible with the
    checkpoint config for continued training.

    Raises:
        ValueError: if an architecture-defining option differs, or if
            trainable options changed while optimizer/scheduler state
            is being reused.
    """
    # TODO: check more
    # architecture-defining options that must match the checkpoint exactly
    SHOULD_BE_SAME = [
        KEY.NODE_FEATURE_MULTIPLICITY,
        KEY.LMAX,
        KEY.IS_PARITY,
        KEY.RADIAL_BASIS,
        KEY.CUTOFF_FUNCTION,
        KEY.CUTOFF,
        KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS,
        KEY.NUM_CONVOLUTION,
        KEY.USE_BIAS_IN_LINEAR,
        KEY.SELF_CONNECTION_TYPE,
    ]
    for sbs in SHOULD_BE_SAME:
        if config[sbs] == config_cp[sbs]:
            continue
        if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE':
            # give an actionable hint for this known-unsupported case
            warnings.warn(
                'We do not support this version of checkpoints to continue '
                "Please use self_connection_type='linear' in input.yaml "
                'and train from scratch',
                UserWarning,
            )
        raise ValueError(
            f'Value of {sbs} should be same. {config[sbs]} != {config_cp[sbs]}'
        )

    try:
        cntdct = config[KEY.CONTINUE]
    except KeyError:
        # no continue section: nothing further to validate
        return

    TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE]
    # reusing optimizer or scheduler state requires unchanged trainable flags
    reuses_state = (
        not cntdct[KEY.RESET_SCHEDULER] or not cntdct[KEY.RESET_OPTIMIZER]
    )
    trainables_unchanged = all(
        config[k] == config_cp[k] for k in TRAINABLE_CONFIGS
    )
    if reuses_state and not trainables_unchanged:
        raise ValueError(
            'reset optimizer and scheduler if you want to change '
            + 'trainable configs'
        )
    # TODO: add condition for changed optim/scheduler but not reset
def processing_continue(config):
    """
    Load a checkpoint for continued training (legacy path).

    Returns ([model_sd, optimizer_sd, scheduler_sd], start_epoch, init_csv).
    """
    log = Logger()
    continue_dct = config[KEY.CONTINUE]
    log.write('\nContinue found, loading checkpoint\n')
    checkpoint = torch.load(
        continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False
    )
    config_cp = checkpoint['config']
    # NOTE(review): the assignment above is immediately overwritten below
    model_cp, config_cp = util.model_from_checkpoint(checkpoint)
    model_state_dict_cp = model_cp.state_dict()
    # it will raise error if not compatible
    check_config_compatible(config, config_cp)
    log.write('Checkpoint config is compatible\n')
    # for backward compat.
    config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]})
    from_epoch = checkpoint['epoch']
    # optimizer/scheduler states are dropped when the user asked for a reset
    optimizer_state_dict_cp = (
        checkpoint['optimizer_state_dict']
        if not continue_dct[KEY.RESET_OPTIMIZER]
        else None
    )
    scheduler_state_dict_cp = (
        checkpoint['scheduler_state_dict']
        if not continue_dct[KEY.RESET_SCHEDULER]
        else None
    )
    # These could be changed based on given continue_input.yaml
    # ex) adapt to statistics of fine-tuning dataset
    shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy()
    del model_state_dict_cp['rescale_atomic_energy.shift']
    scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy()
    del model_state_dict_cp['rescale_atomic_energy.scale']
    conv_denominators = []
    for i in range(config_cp[KEY.NUM_CONVOLUTION]):
        conv_denominators.append(
            (model_state_dict_cp[f'{i}_convolution.denominator']).item()
        )
        del model_state_dict_cp[f'{i}_convolution.denominator']
    # Further handled by processing_dataset.py
    config.update({
        KEY.SHIFT + '_cp': shift_cp,
        KEY.SCALE + '_cp': scale_cp,
        KEY.CONV_DENOMINATOR + '_cp': conv_denominators,
    })
    # chemistry-related keys must follow the checkpoint model
    chem_keys = [
        KEY.TYPE_MAP,
        KEY.NUM_SPECIES,
        KEY.CHEMICAL_SPECIES,
        KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
    ]
    config.update({k: config_cp[k] for k in chem_keys})
    if (
        KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY]
    ):  # checkpoint model is multimodal
        config.update({
            KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP],
            KEY.USE_MODALITY + '_cp': True,
            KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]),
        })
    else:
        config.update({
            KEY.MODAL_MAP + '_cp': {},
            KEY.USE_MODALITY + '_cp': False,
            KEY.NUM_MODALITIES + '_cp': 0,
        })
    log.write(f'checkpoint previous epoch was: {from_epoch}\n')
    # decide start epoch
    reset_epoch = continue_dct[KEY.RESET_EPOCH]
    if reset_epoch:
        start_epoch = 1
        log.write('epoch reset to 1\n')
    else:
        start_epoch = from_epoch + 1
        log.write(f'epoch start from {start_epoch}\n')
    # decide csv file to continue
    init_csv = True
    csv_fname = config_cp[KEY.CSV_LOG]
    if os.path.isfile(csv_fname):
        # I hope python compare dict well
        if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]:
            log.writeline('Same metric, csv file will be appended')
            init_csv = False
    else:
        log.writeline(f'{csv_fname} file not found, new csv file will be created')
    log.writeline('checkpoint loading was successful')
    state_dicts = [
        model_state_dict_cp,
        optimizer_state_dict_cp,
        scheduler_state_dict_cp,
    ]
    return state_dicts, start_epoch, init_csv
def convert_modality_of_checkpoint_state_dct(config, state_dicts):
    """
    Adjust the checkpoint's model state dict to the modality setting of
    the current run, expanding or collapsing modal parameters as needed.
    """
    # TODO: this requires updating model state dict after seeing dataset
    model_sd, optim_sd, sched_sd = state_dicts
    if config[KEY.USE_MODALITY]:
        # current model is multimodal: append room for the new modalities
        n_modal_cp = len(config[KEY.MODAL_MAP + '_cp'])
        n_to_append = config[KEY.NUM_MODALITIES] - n_modal_cp
        model_sd = append_modality_to_model_dct(
            model_sd, config, n_modal_cp, n_to_append
        )
    elif config[KEY.USE_MODALITY + '_cp']:
        # current model is single modal but checkpoint is multimodal:
        # collapse the state dict to single modal, default = "common"
        model_sd = get_single_modal_model_dct(
            model_sd,
            config,
            config[KEY.DEFAULT_MODAL],
            from_processing_cp=True,
        )
    return (model_sd, optim_sd, sched_sd)
import os
import torch
import torch.distributed as dist
import sevenn._const as CONST
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.train.dataload import file_to_dataset, match_reader
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import chemical_species_preprocess, onehot_to_chem
def dataset_load(file: str, config):
    """
    Wrapper around dataload.file_to_dataset that also supports
    graph-prebuilt .sevenn_data files.
    """
    log = Logger()
    log.write(f'Loading {file}\n')
    log.timer_start('loading dataset')

    is_prebuilt = file.endswith('.sevenn_data')
    if is_prebuilt:
        # Graphs were built ahead of time; just deserialize.
        dataset = torch.load(file, map_location='cpu', weights_only=False)
    else:
        # Raw structure file: select a reader for the configured format
        # and build graphs on the fly.
        fmt = config[KEY.DATA_FORMAT]
        fmt_args = config[KEY.DATA_FORMAT_ARGS]
        reader, _ = match_reader(fmt, **fmt_args)
        dataset = file_to_dataset(
            file,
            config[KEY.CUTOFF],
            config[KEY.PREPROCESS_NUM_CORES],
            reader=reader,
            use_modality=config[KEY.USE_MODALITY],
            use_weight=config[KEY.USE_WEIGHT],
        )

    log.format_k_v('loaded dataset size is', dataset.len(), write=True)
    log.timer_end('loading dataset', 'data set loading time')
    return dataset
def calculate_shift_or_scale_from_key(
    train_set: 'AtomGraphDataset', key_given, n_chem
):
    """Compute a shift or scale statistic of the training set by name.

    Args:
        train_set: dataset exposing the statistic getters used below.
        key_given: one of 'per_atom_energy_mean',
            'elemwise_reference_energies', 'force_rms',
            'per_atom_energy_std', 'elemwise_force_rms'.
        n_chem: number of chemical species (length of element-wise results).

    Returns:
        (value, _expand, use_species_wise_shift_scale) where _expand is
        True when the scalar value should later be broadcast per element,
        and use_species_wise_shift_scale marks element-wise statistics.

    Raises:
        ValueError: if key_given is not one of the supported keys.
    """
    _expand = True
    use_species_wise_shift_scale = False
    if key_given == 'per_atom_energy_mean':
        shift_or_scale = train_set.get_per_atom_energy_mean()
    elif key_given == 'elemwise_reference_energies':
        shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem)
        _expand = False
        use_species_wise_shift_scale = True
    elif key_given == 'force_rms':
        shift_or_scale = train_set.get_force_rms()
    elif key_given == 'per_atom_energy_std':
        shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][
            'std'
        ]
    elif key_given == 'elemwise_force_rms':
        shift_or_scale = train_set.get_species_wise_force_rms(n_chem)
        _expand = False
        use_species_wise_shift_scale = True
    else:
        # Previously an unknown key fell through every branch and raised
        # UnboundLocalError on return; fail early with a clear message.
        raise ValueError(f'Unknown shift/scale key: {key_given}')
    return shift_or_scale, _expand, use_species_wise_shift_scale
def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given):
    """
    Resolve the final shift, scale and conv_denominator for the model.

    Priority (first comes later to overwrite):
    1. Float given in yaml
    2. Use statistic values of checkpoint == True
    3. Plain options (provided as string)

    Side effect: writes KEY.USE_SPECIES_WISE_SHIFT_SCALE into config.
    """
    log = Logger()
    shift, scale, conv_denominator = None, None, None
    type_map = config[KEY.TYPE_MAP]
    n_chem = len(type_map)
    chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
    log.writeline('\nCalculating statistic values from dataset')

    shift_given = config[KEY.SHIFT]
    scale_given = config[KEY.SCALE]
    _expand_shift = True
    _expand_scale = True
    use_species_wise_shift = False
    use_species_wise_scale = False
    use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT]
    use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE]

    # Priority 3: string option -> compute the statistic from the dataset.
    if shift_given in CONST.IMPLEMENTED_SHIFT:
        shift, _expand_shift, use_species_wise_shift = (
            calculate_shift_or_scale_from_key(train_set, shift_given, n_chem)
        )
    if scale_given in CONST.IMPLEMENTED_SCALE:
        scale, _expand_scale, use_species_wise_scale = (
            calculate_shift_or_scale_from_key(train_set, scale_given, n_chem)
        )

    # Modal-wise statistics: recompute per modality into (n_modal, n_chem)
    # tensors, one row per modal index.
    if use_modal_wise_shift or use_modal_wise_scale:
        atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality()
        modal_map = config[KEY.MODAL_MAP]
        n_modal = len(modal_map)
        cutoff = config[KEY.CUTOFF]
        if use_modal_wise_shift:
            shift = torch.zeros((n_modal, n_chem))
        if use_modal_wise_scale:
            scale = torch.zeros((n_modal, n_chem))
        for modal_key, data_list in atomdata_dict_sort_by_modal.items():
            modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True)
            if use_modal_wise_shift:
                if shift_given == 'elemwise_reference_energies':
                    modal_shift, _expand_shift, use_species_wise_shift = (
                        calculate_shift_or_scale_from_key(
                            modal_set, shift_given, n_chem
                        )
                    )
                    shift[modal_map[modal_key]] = torch.tensor(
                        modal_shift
                    )  # this is np.array
                elif shift_given in CONST.IMPLEMENTED_SHIFT:
                    raise NotImplementedError(
                        'Currently, modal-wise shift implemented for'
                        'species-dependent case only.'
                    )
            if use_modal_wise_scale:
                if scale_given == 'elemwise_force_rms':
                    modal_scale, _expand_scale, use_species_wise_scale = (
                        calculate_shift_or_scale_from_key(
                            modal_set, scale_given, n_chem
                        )
                    )
                    scale[modal_map[modal_key]] = modal_scale
                elif scale_given in CONST.IMPLEMENTED_SCALE:
                    raise NotImplementedError(
                        'Currently, modal-wise scale implemented for'
                        'species-dependent case only.'
                    )

    avg_num_neigh = train_set.get_avg_num_neigh()
    log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True)
    if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh':
        conv_denominator = avg_num_neigh
    elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh':
        conv_denominator = avg_num_neigh ** (0.5)

    # Priority 2: overwrite with the statistics stored in the checkpoint.
    if (
        checkpoint_given
        and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT]
    ):
        log.writeline(
            'Overwrite shift, scale, conv_denominator from model checkpoint'
        )
        # TODO: This needs refactoring
        conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp']
        if not (use_modal_wise_shift or use_modal_wise_scale):
            # Values extracted from checkpoint in processing_continue.py
            # NOTE(review): assumes shift/scale here are iterable (list or
            # tensor) and that element-wise shift implies element-wise
            # scale as well — confirm against processing_continue.py.
            if len(list(shift)) > 1:
                use_species_wise_shift = True
                use_species_wise_scale = True
                _expand_shift = _expand_scale = False
            else:
                shift = shift.item()
                scale = scale.item()
        else:
            # Case of modal wise shift scale
            shift_cp = config[KEY.SHIFT + '_cp']
            scale_cp = config[KEY.SCALE + '_cp']
            if not use_modal_wise_shift:
                shift = shift_cp
            if not use_modal_wise_scale:
                scale = scale_cp
            modal_map = config[KEY.MODAL_MAP]
            modal_map_cp = config[KEY.MODAL_MAP + '_cp']
            # Extracting shift, scale for modal in checkpoint model.
            if config[KEY.USE_MODALITY + '_cp']:  # cp model is multimodal
                # Copy each checkpoint modal's row into the row the same
                # modal occupies in the current (possibly larger) map.
                for modal_key_cp, modal_idx_cp in modal_map_cp.items():
                    modal_idx = modal_map[modal_key_cp]
                    if use_modal_wise_shift:
                        shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp])
                    if use_modal_wise_scale:
                        scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp])
            else:  # cp model is single modal
                try:
                    modal_idx = modal_map[config[KEY.DEFAULT_MODAL]]
                except:
                    raise KeyError(
                        f'{config[KEY.DEFAULT_MODAL]} should be one of'
                        f' {modal_map.keys()}'
                    )
                if use_modal_wise_shift:
                    shift[modal_idx] = torch.tensor(shift_cp)
                if use_modal_wise_scale:
                    scale[modal_idx] = torch.tensor(scale_cp)
            if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]:
                # Also overwrite values of new modal to reference value
                # For multimodal, set reference modal with KEY.DEFAULT_MODAL
                shift_ref = shift_cp
                scale_ref = scale_cp
                if config[KEY.USE_MODALITY + '_cp']:
                    try:
                        modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]]
                    except:
                        raise KeyError(
                            f'{config[KEY.DEFAULT_MODAL]} should be one of'
                            f' {modal_map_cp.keys()}'
                        )
                    shift_ref = shift_cp[modal_idx_cp]
                    scale_ref = scale_cp[modal_idx_cp]
                # Rows for modalities absent from the checkpoint get the
                # reference value.
                for modal_key, modal_idx in modal_map.items():
                    if modal_key not in modal_map_cp.keys():
                        if use_modal_wise_shift:
                            shift[modal_idx] = shift_ref
                        if use_modal_wise_scale:
                            scale[modal_idx] = scale_ref

    # Priority 1: overwrite shift scale anyway if defined in yaml.
    if type(shift_given) in [list, float]:
        log.writeline('Overwrite shift to value(s) given in yaml')
        _expand_shift = isinstance(shift_given, float)
        shift = shift_given
    if type(scale_given) in [list, float]:
        log.writeline('Overwrite scale to value(s) given in yaml')
        _expand_scale = isinstance(scale_given, float)
        scale = scale_given
    if isinstance(config[KEY.CONV_DENOMINATOR], float):
        log.writeline('Overwrite conv_denominator to value given in yaml')
        conv_denominator = config[KEY.CONV_DENOMINATOR]
    # The model expects one denominator per convolution layer.
    if isinstance(conv_denominator, float):
        conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION]

    use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale
    if use_species_wise_shift_scale:
        chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
        # Broadcast remaining scalars so shift/scale are element-wise.
        if _expand_shift:
            if use_modal_wise_shift:
                shift = torch.full((n_modal, n_chem), shift)
            else:
                shift = [shift] * n_chem
        if _expand_scale:
            if use_modal_wise_scale:
                scale = torch.full((n_modal, n_chem), scale)
            else:
                scale = [scale] * n_chem
        Logger().write('Use element-wise shift, scale\n')
        if use_modal_wise_shift or use_modal_wise_scale:
            for modal_key, modal_idx in modal_map.items():
                Logger().writeline(f'For modal = {modal_key}')
                print_shift = shift[modal_idx] if use_modal_wise_shift else shift
                print_scale = scale[modal_idx] if use_modal_wise_scale else scale
                for cstr, sh, sc in zip(chem_strs, print_shift, print_scale):
                    Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
        else:
            for cstr, sh, sc in zip(chem_strs, shift, scale):
                Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
    else:
        log.write('Use global shift, scale\n')
        log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True)

    assert isinstance(conv_denominator, list) and all(
        isinstance(deno, float) for deno in conv_denominator
    )
    log.format_k_v(
        '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True
    )

    config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale
    return shift, scale, conv_denominator
# TODO: This is too long
def processing_dataset(config, working_dir):
    """Load, label, split and normalize datasets for training.

    Args:
        config: mutable run configuration; updated in place with chemical
            species, modality info and shift/scale/conv_denominator.
        working_dir: directory used as prefix for any saved datasets.

    Returns:
        (train_list, valid_list, test_list): the three splits as lists.
    """
    log = Logger()
    prefix = f'{os.path.abspath(working_dir)}/'
    is_stress = config[KEY.IS_TRAIN_STRESS]
    checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False
    cutoff = config[KEY.CUTOFF]

    log.write('\nInitializing dataset...\n')
    dataset = AtomGraphDataset({}, cutoff)
    load_dataset = config[KEY.LOAD_DATASET]
    if type(load_dataset) is str:
        load_dataset = [load_dataset]
    for file in load_dataset:
        dataset.augment(dataset_load(file, config))

    dataset.group_by_key()  # apply labels inside original datapoint
    dataset.unify_dtypes()  # unify dtypes of all data points

    # TODO: I think manual chemical species input is redundant
    chem_in_db = dataset.get_species()
    if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given:
        log.writeline('Auto detect chemical species from dataset')
        config.update(chemical_species_preprocess(chem_in_db))
    elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given:
        pass  # copied from checkpoint in processing_continue.py
    elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given:
        pass  # processed in parse_input.py
    else:  # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given
        log.writeline('Ignore chemical species in yaml, use checkpoint')
        # already processed in processing_continue.py

    # basic dataset compatibility check with previous model
    if checkpoint_given:
        chem_from_cp = config[KEY.CHEMICAL_SPECIES]
        if not all(chem in chem_from_cp for chem in chem_in_db):
            raise ValueError('Chemical species in checkpoint is not compatible')

    # check what modalities are used in dataset
    if config[KEY.USE_MODALITY]:
        modalities = dataset.get_modalities()
        num_modalities = len(modalities)
        if num_modalities < 2:
            Logger().writeline('Only one modal is given, ignore modality')
            # bug fix: was config.uptate(...) -> AttributeError at runtime
            config.update({KEY.USE_MODALITY: False})
        else:
            # Merge dataset modalities into the (possibly pre-existing)
            # checkpoint modal map, keeping checkpoint indices stable.
            modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {}
            modal_map = modal_map_cp.copy()
            current_idx = len(modal_map_cp)
            for modal_key in modalities:
                if modal_key not in modal_map.keys():
                    modal_map[modal_key] = current_idx
                    current_idx += 1
            if config[KEY.IS_DDP]:
                # Synchronize modal_map so every rank uses the same indices
                torch.cuda.set_device(config[KEY.LOCAL_RANK])
                modal_map_bcast = [modal_map]
                dist.broadcast_object_list(modal_map_bcast, src=0)
                modal_map = modal_map_bcast[0]
            config.update(
                {
                    KEY.NUM_MODALITIES: len(modal_map),
                    KEY.MODAL_MAP: modal_map,
                    KEY.MODAL_LIST: list(modal_map.keys()),
                }
            )
            dataset.write_modal_attr(
                modal_map,
                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
            )

    # --------------- save dataset regardless of train/valid--------------#
    save_dataset = config[KEY.SAVE_DATASET]
    save_by_label = config[KEY.SAVE_BY_LABEL]
    if save_dataset:
        if save_dataset.endswith('.sevenn_data') is False:
            save_dataset += '.sevenn_data'
        if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False:
            save_dataset = prefix + save_dataset  # save_data set is plain file name
        dataset.save(save_dataset)
        log.format_k_v('Dataset saved to', save_dataset, write=True)
        # log.write(f"Loaded full dataset saved to : {save_dataset}\n")
    if save_by_label:
        dataset.save(prefix, by_label=True)
        log.format_k_v('Dataset saved by label', prefix, write=True)
    # --------------------------------------------------------------------#

    # TODO: testset is not used
    ignore_test = not config.get(KEY.USE_TESTSET, False)
    if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]:
        train_set = dataset
        test_set = AtomGraphDataset([], config[KEY.CUTOFF])
        log.write('Loading validset from load_validset\n')
        valid_set = AtomGraphDataset({}, cutoff)
        # robustness fix: accept a single path as well as a list,
        # mirroring the load_dataset normalization above (a bare string
        # previously iterated character-by-character)
        load_validset = config[KEY.LOAD_VALIDSET]
        if type(load_validset) is str:
            load_validset = [load_validset]
        for file in load_validset:
            valid_set.augment(dataset_load(file, config))
        valid_set.group_by_key()
        valid_set.unify_dtypes()

        # condition: validset labels should be subset of trainset labels
        valid_labels = valid_set.user_labels
        train_labels = train_set.user_labels
        if set(valid_labels).issubset(set(train_labels)) is False:
            valid_set = AtomGraphDataset(valid_set.to_list(), cutoff)
            valid_set.rewrite_labels_to_data()
            train_set = AtomGraphDataset(train_set.to_list(), cutoff)
            train_set.rewrite_labels_to_data()
            Logger().write('WARNING! validset labels is not subset of trainset\n')
            Logger().write('We overwrite all the train, valid labels to default.\n')
            Logger().write('Please create validset by sevenn_graph_build with -l\n')

        Logger().write('the validset loaded, load_dataset is now train_set\n')
        Logger().write('the ratio will be ignored\n')

        # condition: validset modalities should be subset of trainset modalities
        if config[KEY.USE_MODALITY]:
            config_modality = config[KEY.MODAL_LIST]
            valid_modality = valid_set.get_modalities()
            if set(valid_modality).issubset(set(config_modality)) is False:
                raise ValueError('validset modality is not subset of trainset')
            valid_set.write_modal_attr(
                config[KEY.MODAL_MAP],
                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
            )
    else:
        train_set, valid_set, test_set = dataset.divide_dataset(
            config[KEY.RATIO], ignore_test=ignore_test
        )
        log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n')

    log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
    log.format_k_v('\nloaded validset size is', valid_set.len(), write=True)
    log.write('Dataset initialization was successful\n')

    log.write('\nNumber of atoms in the train_set:\n')
    log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))

    log.bar()
    log.write('Per atom energy(eV/atom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
    log.bar()
    log.write('Force(eV/Angstrom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.FORCE))
    log.bar()
    log.write('Stress(eV/Angstrom^3) distribution:\n')
    try:
        log.statistic_write(train_set.get_statistics(KEY.STRESS))
    except KeyError:
        # dataset carries no stress labels; disable stress training
        log.write('\n Stress is not included in the train_set\n')
        if is_stress:
            is_stress = False
            log.write('Turn off stress training\n')
    log.bar()

    # saved data must have atomic numbers as X not one hot idx
    if config[KEY.SAVE_BY_TRAIN_VALID]:
        train_set.save(prefix + 'train')
        valid_set.save(prefix + 'valid')
        log.format_k_v('Dataset saved by train, valid', prefix, write=True)

    # inconsistent .info dict give error when collate
    _, _ = train_set.separate_info()
    _, _ = valid_set.separate_info()

    if train_set.x_is_one_hot_idx is False:
        train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
    if valid_set.x_is_one_hot_idx is False:
        valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])

    log.format_k_v('training_set size', train_set.len(), write=True)
    log.format_k_v('validation_set size', valid_set.len(), write=True)

    shift, scale, conv_denominator = handle_shift_scale(
        config, train_set, checkpoint_given
    )
    config.update(
        {
            KEY.SHIFT: shift,
            KEY.SCALE: scale,
            KEY.CONV_DENOMINATOR: conv_denominator,
        }
    )

    data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())
    return data_lists
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment