Commit cc6e6b7d authored by Wang, Leping's avatar Wang, Leping
Browse files

- Add config.sh with all pipeline parameters organized by category

  (molecular, crystal structure, compute, run mode, path)
- Refactor search_gen_proc.sh to source config.sh instead of
  hardcoding parameters, with optional config path argument
- Refactor structure_generate.py to load config.sh via exec(),
  replacing hardcoded values with config-driven parameters
- Remove mace-bench (the relaxation part; it will be replaced by an updated, separate mace-bench project)
parent 61ec3ad9
import os
from copy import deepcopy
from typing import Optional
import torch
from torch.utils.data.distributed import DistributedSampler
import sevenn._keys as KEY
from sevenn.error_recorder import ErrorRecorder
from sevenn.logger import Logger
from sevenn.train.trainer import Trainer
def processing_epoch_v2(
    config: dict,
    trainer: Trainer,
    loaders: dict,  # dict[str, Dataset]
    start_epoch: int = 1,
    train_loader_key: str = 'trainset',
    error_recorder: Optional[ErrorRecorder] = None,
    total_epoch: Optional[int] = None,
    per_epoch: Optional[int] = None,
    best_metric_loader_key: str = 'validset',
    best_metric: Optional[str] = None,
    write_csv: bool = True,
    working_dir: Optional[str] = None,
):
    """Run the training loop over all loaders, logging and checkpointing.

    Args:
        config: parsed SevenNet configuration dict (KEY.* entries)
        trainer: Trainer holding model, optimizer and loss functions
        loaders: mapping of dataset name -> DataLoader; the entry matching
            ``train_loader_key`` runs in training mode, all others eval-only
        start_epoch: first epoch index (1-based)
        train_loader_key: key in ``loaders`` used for training
        error_recorder: recorder template, deep-copied per loader;
            defaults to one built from ``config``
        total_epoch: last epoch (inclusive); defaults to config[KEY.EPOCH]
        per_epoch: periodic checkpoint interval, in epochs
        best_metric_loader_key: loader whose metric selects the best checkpoint
        best_metric: metric name watched for the best checkpoint
        write_csv: write the learning-curve csv (rank 0 only)
        working_dir: output directory; defaults to the current directory

    Returns:
        the same ``trainer`` after training
    """
    from sevenn.util import unique_filepath

    log = Logger()
    # Only rank 0 writes the csv to avoid clobbering in distributed runs.
    write_csv = write_csv and log.rank == 0
    working_dir = working_dir or os.getcwd()
    prefix = f'{os.path.abspath(working_dir)}/'
    total_epoch = total_epoch or config[KEY.EPOCH]
    per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
    best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
    recorder = error_recorder or ErrorRecorder.from_config(
        config, trainer.loss_functions
    )
    # Each loader gets an independent recorder copy so metrics do not mix.
    recorders = {k: deepcopy(recorder) for k in loaders}
    best_val = float('inf')
    best_key = None
    if best_metric_loader_key in recorders:
        best_key = recorders[best_metric_loader_key].get_key_str(best_metric)
    if best_key is None:
        # Without a resolvable metric key, best-checkpoint saving is disabled.
        log.writeline(
            f'Failed to get error recorder key: {best_metric} or '
            + f'{best_metric_loader_key} is missing. There will be no best '
            + 'checkpoint.'
        )
    csv_path = unique_filepath(f'{prefix}/lc.csv')
    if write_csv:
        # Header: epoch, lr, then one column per (loader, metric) pair.
        head = ['epoch', 'lr']
        for k, rec in recorders.items():
            head.extend(list(rec.get_dct(prefix=k)))
        with open(csv_path, 'w') as f:
            f.write(','.join(head) + '\n')
    if start_epoch == 1:
        path = f'{prefix}/checkpoint_0.pth'  # save first epoch
        trainer.write_checkpoint(path, config=config, epoch=0)
    for epoch in range(start_epoch, total_epoch + 1):  # one indexing
        log.timer_start('epoch')
        lr = trainer.get_lr()
        log.bar()
        log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n')
        log.bar()
        csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'}
        errors = {}
        for k, loader in loaders.items():
            is_train = k == train_loader_key
            # Reshuffle the distributed sampler every epoch (training only).
            if (
                trainer.distributed
                and isinstance(loader.sampler, DistributedSampler)
                and is_train
                and config.get('train_shuffle', True)
            ):
                loader.sampler.set_epoch(epoch)
            rec = recorders[k]
            trainer.run_one_epoch(loader, is_train, rec)
            csv_dct.update(rec.get_dct(prefix=k))
            errors[k] = rec.epoch_forward()
        log.write_full_table(list(errors.values()), list(errors))
        # NOTE(review): the scheduler steps on the previous best value, not on
        # this epoch's metric -- confirm this is intended.
        trainer.scheduler_step(best_val)
        if write_csv:
            with open(csv_path, 'a') as f:
                f.write(','.join(list(csv_dct.values())) + '\n')
        if best_key and errors[best_metric_loader_key][best_key] < best_val:
            path = f'{prefix}/checkpoint_best.pth'
            trainer.write_checkpoint(path, config=config, epoch=epoch)
            best_val = errors[best_metric_loader_key][best_key]
            log.writeline('Best checkpoint written')
        if epoch % per_epoch == 0:
            # Periodic checkpoint, independent of the best-metric checkpoint.
            path = f'{prefix}/checkpoint_{epoch}.pth'
            trainer.write_checkpoint(path, config=config, epoch=epoch)
        log.timer_end('epoch', message=f'Epoch {epoch} elapsed')
    return trainer
def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir):
    """Legacy (<= v0.9.5) training loop over a (train, valid) loader pair.

    Args:
        trainer: Trainer holding model and optimizer
        config: parsed configuration dict (KEY.* entries)
        loaders: two-element sequence (train_loader, valid_loader)
        start_epoch: first epoch index (1-based)
        init_csv: if True, (re)write the csv log header
        working_dir: directory for the csv log and checkpoints
    """
    log = Logger()
    prefix = f'{os.path.abspath(working_dir)}/'
    train_loader, valid_loader = loaders
    is_distributed = config[KEY.IS_DDP]
    rank = config[KEY.RANK]
    total_epoch = config[KEY.EPOCH]
    per_epoch = config[KEY.PER_EPOCH]
    train_recorder = ErrorRecorder.from_config(config)
    valid_recorder = ErrorRecorder.from_config(config)
    best_metric = config[KEY.BEST_METRIC]
    csv_fname = f'{prefix}{config[KEY.CSV_LOG]}'
    current_best = float('inf')
    if init_csv:
        csv_header = ['Epoch', 'Learning_rate']
        # Assume train valid have the same metrics
        for metric in train_recorder.get_metric_dict().keys():
            csv_header.append(f'Train_{metric}')
            csv_header.append(f'Valid_{metric}')
        log.init_csv(csv_fname, csv_header)

    def write_checkpoint(epoch, is_best=False):
        # Only rank 0 saves checkpoints under DDP.
        if is_distributed and rank != 0:
            return
        suffix = '_best' if is_best else f'_{epoch}'
        checkpoint = trainer.get_checkpoint_dict()
        checkpoint.update({'config': config, 'epoch': epoch})
        torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth')

    fin_epoch = total_epoch + start_epoch
    for epoch in range(start_epoch, fin_epoch):
        lr = trainer.get_lr()
        log.timer_start('epoch')
        log.bar()
        log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n')
        log.bar()
        trainer.run_one_epoch(
            train_loader, is_train=True, error_recorder=train_recorder
        )
        train_err = train_recorder.epoch_forward()
        trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder)
        valid_err = valid_recorder.epoch_forward()
        # Interleave train/valid columns to match the csv header order.
        csv_values = [epoch, lr]
        for metric in train_err:
            csv_values.append(train_err[metric])
            csv_values.append(valid_err[metric])
        log.append_csv(csv_fname, csv_values)
        log.write_full_table([train_err, valid_err], ['Train', 'Valid'])
        # Pick the watched validation metric for scheduler/best-checkpoint.
        val = None
        for metric in valid_err:
            # loose string comparison,
            # e.g. "Energy" in "TotalEnergy" or "Energy_Loss"
            if best_metric in metric:
                val = valid_err[metric]
                break
        assert val is not None, f'Metric {best_metric} not found in {valid_err}'
        trainer.scheduler_step(val)
        log.timer_end('epoch', message=f'Epoch {epoch} elapsed')
        if val < current_best:
            current_best = val
            write_checkpoint(epoch, is_best=True)
            log.writeline('Best checkpoint written')
        if epoch % per_epoch == 0:
            write_checkpoint(epoch)
from typing import List, Optional
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.model_build import build_E3_equivariant_model
from sevenn.scripts.processing_continue import (
convert_modality_of_checkpoint_state_dct,
)
from sevenn.train.trainer import Trainer
def loader_from_config(config, dataset, is_train=False):
    """Build a DataLoader for ``dataset``, distributed-aware.

    Shuffling applies only to the train loader (and only when enabled in
    config). Under DDP a DistributedSampler is attached instead of the
    ``shuffle`` flag, since DataLoader forbids using both.
    """
    do_shuffle = is_train and config[KEY.TRAIN_SHUFFLE]
    loader_kwargs = {
        'dataset': dataset,
        'batch_size': config[KEY.BATCH_SIZE],
        'shuffle': do_shuffle,
    }
    if config.get(KEY.NUM_WORKERS, 0) > 0:
        loader_kwargs['num_workers'] = config[KEY.NUM_WORKERS]
    if config[KEY.IS_DDP]:
        # Synchronize ranks before constructing the sampler.
        dist.barrier()
        ddp_sampler = DistributedSampler(
            dataset, dist.get_world_size(), dist.get_rank(), shuffle=do_shuffle
        )
        loader_kwargs['sampler'] = ddp_sampler
        # sampler is mutually exclusive with shuffle
        del loader_kwargs['shuffle']
    return DataLoader(**loader_kwargs)
def train_v2(config, working_dir: str):
    """
    Main program flow, since v0.9.6

    Builds datasets and loaders from config, constructs model and Trainer,
    optionally restores checkpoint state, then runs processing_epoch_v2.

    Args:
        config: parsed configuration dict (KEY.* entries); updated in place
        working_dir: directory for datasets, logs, and checkpoints
    """
    import sevenn.train.atoms_dataset as atoms_dataset
    import sevenn.train.graph_dataset as graph_dataset
    import sevenn.train.modal_dataset as modal_dataset

    from .processing_continue import processing_continue_v2
    from .processing_epoch import processing_epoch_v2

    log = Logger()
    log.timer_start('total')
    # Backward compatibility: migrate the old load_dataset_path key.
    if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config:
        log.writeline('***************************************************')
        log.writeline('For train_v2, please use load_trainset_path instead')
        log.writeline('I will assign load_trainset as load_dataset')
        log.writeline('***************************************************')
        config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET)

    # config updated
    start_epoch = 1
    state_dicts: Optional[List[dict]] = None
    if config[KEY.CONTINUE][KEY.CHECKPOINT]:
        state_dicts, start_epoch = processing_continue_v2(config)

    # Select the dataset implementation based on config.
    if config.get(KEY.USE_MODALITY, False):
        datasets = modal_dataset.from_config(config, working_dir)
    elif config[KEY.DATASET_TYPE] == 'graph':
        datasets = graph_dataset.from_config(config, working_dir)
    elif config[KEY.DATASET_TYPE] == 'atoms':
        datasets = atoms_dataset.from_config(config, working_dir)
    else:
        raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}')

    loaders = {
        k: loader_from_config(config, v, is_train=(k == 'trainset'))
        for k, v in datasets.items()
    }
    log.write('\nModel building...\n')
    model = build_E3_equivariant_model(config)
    log.print_model_info(model, config)
    trainer = Trainer.from_config(model, config)
    if state_dicts:
        # strict=False: tolerate missing/extra keys when continuing.
        trainer.load_state_dicts(*state_dicts, strict=False)
    processing_epoch_v2(
        config, trainer, loaders, start_epoch, working_dir=working_dir
    )
    log.timer_end('total', message='Total wall time')
def train(config, working_dir: str):
    """
    Main program flow, until v0.9.5

    Builds train/valid datasets and loaders, constructs model and Trainer,
    optionally restores checkpoint state, then runs processing_epoch.

    Args:
        config: parsed configuration dict (KEY.* entries); updated in place
        working_dir: directory for datasets, logs, and checkpoints
    """
    from .processing_continue import processing_continue
    from .processing_dataset import processing_dataset
    from .processing_epoch import processing_epoch

    log = Logger()
    log.timer_start('total')

    # config updated
    state_dicts: Optional[List[dict]] = None
    if config[KEY.CONTINUE][KEY.CHECKPOINT]:
        state_dicts, start_epoch, init_csv = processing_continue(config)
    else:
        start_epoch, init_csv = 1, True

    # config updated
    # Renamed locals (was `train`, `valid`) to avoid shadowing this function.
    train_set, valid_set, _ = processing_dataset(config, working_dir)
    datasets = {'dataset': train_set, 'validset': valid_set}
    loaders = {
        k: loader_from_config(config, v, is_train=(k == 'dataset'))
        for k, v in datasets.items()
    }
    # processing_epoch expects a plain [train_loader, valid_loader] pair.
    loaders = list(loaders.values())

    log.write('\nModel building...\n')
    model = build_E3_equivariant_model(config)
    log.write('Model building was successful\n')

    trainer = Trainer.from_config(model, config)
    if state_dicts:
        state_dicts = convert_modality_of_checkpoint_state_dct(
            config, state_dicts
        )
        trainer.load_state_dicts(*state_dicts, strict=False)

    log.print_model_info(model, config)
    # Bug fix: the "ready to training" banner and bar were emitted twice,
    # once via Logger() and once via log (the same logger instance).
    log.write('Trainer initialized, ready to training\n')
    log.bar()

    processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir)
    log.timer_end('total', message='Total wall time')
import warnings

from .logger import *  # noqa: F403

# Deprecated alias module: re-exports sevenn.logger under the old name and
# warns at import time so callers migrate to the new module path.
warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger',
              DeprecationWarning, stacklevel=2)
import warnings

from .calculator import *  # noqa: F403

# Deprecated alias module: re-exports sevenn.calculator under the old name
# and warns at import time so callers migrate to the new module path.
warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator',
              DeprecationWarning, stacklevel=2)
import os
import random
import warnings
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch.utils.data
from ase.atoms import Atoms
from ase.data import chemical_symbols
from ase.io import write
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
# Warning text emitted when avg_num_neigh is approximated from a random
# subsample (see SevenNetAtomsDataset.avg_num_neigh). Kept verbatim.
_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide correct avg_num_neigh
as it does not build graph. We will compute only random 10000 structures graph to
approximate this value. If you want more precise avg_num_neigh,
use SevenNetGraphDataset. If it is not viable due to memory limit, you
need online algorithm to do this , which is not yet implemented in the SevenNet"""
class SevenNetAtomsDataset(torch.utils.data.Dataset):
    """In-memory dataset keeping ASE Atoms; graphs are built lazily on access.

    Args:
        cutoff: edge cutoff of given AtomGraphData
        files: list of filenames or dict describing how to parse the file
            ASE readable (with proper extension), structure_list, .sevenn_data,
            dict containing file_list (see dict_reader of train/dataload.py)
        info_dict_copy_keys: patch these keys from KEY.INFO to graph when accessing.
            default is KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may accessed
            while training.
        **process_kwargs: keyword arguments that will be passed into ase.io.read
    """

    def __init__(
        self,
        cutoff: float,
        files: Union[str, List[str]],
        atoms_filter: Optional[Callable] = None,
        atoms_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        use_data_weight: bool = False,
        **process_kwargs,
    ):
        self.cutoff = cutoff
        if isinstance(files, str):
            files = [files]  # user convenience
        files = [os.path.abspath(file) for file in files]
        self._files = files
        # NOTE(review): atoms_filter is stored but never applied in this
        # class -- confirm whether filtering was intended here.
        self.atoms_filter = atoms_filter
        self.atoms_transform = atoms_transform
        self.transform = transform
        self.use_data_weight = use_data_weight
        self._scanned = False  # guards run_stat() against re-running
        self._avg_num_neigh_approx = None  # lazy cache for avg_num_neigh
        self.statistics = {}
        # Eagerly load all atoms into memory; graphs are NOT prebuilt.
        atoms_list = []
        for file in files:
            atoms_list.extend(
                SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs)
            )
        self._atoms_list = atoms_list
        super().__init__()

    @staticmethod
    def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]:
        """Dispatch on input kind: dict spec, structure_list, or ASE-readable."""
        if isinstance(file, dict):
            atoms_list = dataload.dict_reader(file)
        elif 'structure_list' in file:
            atoms_dct = dataload.structure_list_reader(file)
            atoms_list = []
            for lst in atoms_dct.values():
                atoms_list.extend(lst)
        else:
            atoms_list = dataload.ase_reader(file, **kwargs)
        return atoms_list

    def save(self, path):
        # Save atoms list as extxyz
        write(path, self._atoms_list, format='extxyz')

    def _graph_build(self, atoms):
        # Labels come from atoms.info/arrays ('y_*'), not from a calculator.
        return dataload.atoms_to_graph(
            atoms, self.cutoff, transfer_info=False, y_from_calc=False
        )

    def __len__(self):
        return len(self._atoms_list)

    def __getitem__(self, index):
        # Graph is rebuilt on every access (no caching).
        atoms = self._atoms_list[index]
        if self.atoms_transform is not None:
            atoms = self.atoms_transform(atoms)
        graph = self._graph_build(atoms)
        if self.transform is not None:
            graph = self.transform(graph)
        if self.use_data_weight:
            # Per-channel loss weights; default 1.0 each when absent.
            weight = graph[KEY.INFO].pop(
                KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
            )
            graph[KEY.DATA_WEIGHT] = weight
        return AtomGraphData.from_numpy_dict(graph)

    @property
    def species(self):
        # Chemical symbols present in the dataset ('total' key excluded).
        self.run_stat()
        return [z for z in self.statistics['_natoms'].keys() if z != 'total']

    @property
    def natoms(self):
        # Per-species atom counts plus a 'total' entry.
        self.run_stat()
        return self.statistics['_natoms']

    @property
    def per_atom_energy_mean(self):
        self.run_stat()
        return self.statistics[KEY.PER_ATOM_ENERGY]['mean']

    @property
    def elemwise_reference_energies(self):
        # Per-element reference energies via ridge regression of total energy
        # on composition; all-zero species columns are dropped for the fit
        # and scattered back into a NUM_UNIV_ELEMENT-long vector.
        from sklearn.linear_model import Ridge

        c = self.statistics['_composition']
        y = self.statistics[KEY.ENERGY]['_array']
        zero_indices = np.all(c == 0, axis=0)
        c_reduced = c[:, ~zero_indices]
        # will not 100% reproduce, as it is sorted by Z
        # train/dataset.py was sorted by alphabets of chemical species
        coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
        full_coeff = np.zeros(NUM_UNIV_ELEMENT)
        full_coeff[~zero_indices] = coef_reduced
        return full_coeff.tolist()  # ex: full_coeff[1] = H_reference_energy

    @property
    def force_rms(self):
        # RMS recovered from summary stats: E[F^2] = mean^2 + std^2.
        self.run_stat()
        mean = self.statistics[KEY.FORCE]['mean']
        std = self.statistics[KEY.FORCE]['std']
        return float((mean**2 + std**2) ** (0.5))

    @property
    def per_atom_energy_std(self):
        self.run_stat()
        # NOTE(review): uses the literal key 'per_atom_energy' while the mean
        # property uses KEY.PER_ATOM_ENERGY -- confirm the two agree.
        return self.statistics['per_atom_energy']['std']

    @property
    def avg_num_neigh(self, n_sample=10000):
        # NOTE(review): properties cannot receive arguments, so n_sample is
        # effectively fixed at 10000 for every caller.
        if self._avg_num_neigh_approx is None:
            if len(self) > n_sample:
                warnings.warn(_warn_avg_num_neigh)
            n_sample = min(len(self), n_sample)
            indices = random.sample(range(len(self)), n_sample)
            n_neigh = []
            for i in indices:
                graph = self[i]
                # Count outgoing edges per source node of this graph.
                _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True)
                n_neigh.append(nn)
            n_neigh = np.concatenate(n_neigh)
            self._avg_num_neigh_approx = np.mean(n_neigh)
        return self._avg_num_neigh_approx

    @property
    def sqrt_avg_num_neigh(self):
        self.run_stat()
        return self.avg_num_neigh**0.5

    def run_stat(self):
        """
        Loop over dataset and init any statistics might need
        Unlike SevenNetGraphDataset, neighbors count is not computed as
        it requires to build graph
        """
        if self._scanned is True:
            return  # statistics already computed
        y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
        natoms_counter = Counter()
        composition = np.zeros((len(self), NUM_UNIV_ELEMENT))
        stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys}
        for i, atoms in tqdm(
            enumerate(self._atoms_list), desc='run_stat', total=len(self)
        ):
            z = atoms.get_atomic_numbers()
            natoms_counter.update(z.tolist())
            composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT)
            # Collect raw label arrays; reduced to summary stats below.
            for y, dct in stats.items():
                if y == KEY.ENERGY:
                    dct['_array'].append(atoms.info['y_energy'])
                elif y == KEY.PER_ATOM_ENERGY:
                    dct['_array'].append(atoms.info['y_energy'] / len(atoms))
                elif y == KEY.FORCE:
                    dct['_array'].append(atoms.arrays['y_force'].reshape(-1))
                elif y == KEY.STRESS:
                    dct['_array'].append(atoms.info['y_stress'].reshape(-1))
        for y, dct in stats.items():
            if y == KEY.FORCE:
                # Force arrays have per-structure varying length: concatenate.
                array = np.concatenate(dct['_array'])
            else:
                array = np.array(dct['_array']).reshape(-1)
            dct.update(
                {
                    'mean': float(np.mean(array)),
                    'std': float(np.std(array)),
                    'median': float(np.quantile(array, q=0.5)),
                    'max': float(np.max(array)),
                    'min': float(np.min(array)),
                    '_array': array,
                }
            )
        natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
        natoms['total'] = sum(list(natoms.values()))
        self.statistics.update(
            {
                '_composition': composition,
                '_natoms': natoms,
                **stats,
            }
        )
        self._scanned = True
# script, return dict of SevenNetAtomsDataset
def from_config(
    config: Dict[str, Any],
    working_dir: str = os.getcwd(),
    dataset_keys: Optional[List[str]] = None,
):
    """Build SevenNetAtomsDataset instances from config; update config.

    Scans config for 'load_*_path' keys (unless dataset_keys is given),
    builds one dataset per key, optionally computes statistics, and fills
    'auto' entries (chemical species, shift/scale/conv denominator) from
    the train set statistics.

    Args:
        config: parsed configuration dict; updated in place
        working_dir: unused in this implementation; kept for interface
            parity with other from_config functions. NOTE: the default is
            evaluated once at import time, which is harmless only because
            the value is never read here.
        dataset_keys: explicit list of config keys to load; auto-detected
            from 'load_*_path' keys when None

    Returns:
        dict mapping dataset name (e.g. 'trainset') to SevenNetAtomsDataset

    Raises:
        ValueError: when the train set key is missing from dataset_keys
        NotImplementedError: when shift/scale/conv_denominator names an
            unknown statistic attribute
    """
    from sevenn.logger import Logger

    log = Logger()
    if dataset_keys is None:
        dataset_keys = []
        for k in config:
            if k.startswith('load_') and k.endswith('_path'):
                dataset_keys.append(k)

    if KEY.LOAD_TRAINSET not in dataset_keys:
        raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')

    # initialize arguments for loading dataset
    dataset_args = {
        'cutoff': config[KEY.CUTOFF],
        'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config[KEY.DATA_FORMAT_ARGS],
    }

    datasets = {}
    for dk in dataset_keys:
        if not (paths := config[dk]):
            continue
        if isinstance(paths, str):
            paths = [paths]
        # 'load_trainset_path' -> 'trainset'
        name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
        dataset_args.update({'files': paths})
        datasets[name] = SevenNetAtomsDataset(**dataset_args)

    if not config[KEY.COMPUTE_STATISTICS]:
        log.writeline(
            (
                'Computing statistics is skipped, note that if any of other'
                'configurations requires statistics (shift, scale, avg_num_neigh,'
                'chemical_species as auto), SevenNet eventually raise an error!'
            )
        )
        return datasets

    train_set = datasets['trainset']
    chem_species = set(train_set.species)
    # print statistics of each dataset
    for name, dataset in datasets.items():
        dataset.run_stat()
        log.bar()
        log.writeline(f'{name} distribution:')
        log.statistic_write(dataset.statistics)
        log.format_k_v('# atoms (node)', dataset.natoms, write=True)
        log.format_k_v('# structures (graph)', len(dataset), write=True)
        chem_species.update(dataset.species)
    log.bar()

    # initialize known species from dataset if 'auto'
    # sorted to alphabetical order (which is same as before)
    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
        log.writeline('Known species are obtained from the dataset')
        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))

    # retrieve shift, scale, conv_denominators from user input (keyword)
    init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
    for k in init_from_stats:
        # Renamed from `input` to avoid shadowing the builtin.
        user_value = config[k]  # statistic key or numbers
        # If it is not 'str', 1: It is 'continue' training
        # 2: User manually inserted numbers
        if isinstance(user_value, str) and hasattr(train_set, user_value):
            var = getattr(train_set, user_value)
            config.update({k: var})
            log.writeline(f'{k} is obtained from statistics')
        elif isinstance(user_value, str) and not hasattr(train_set, user_value):
            raise NotImplementedError(user_value)

    return datasets
from typing import Any, List, Optional, Sequence
from ase.atoms import Atoms
from torch_geometric.loader.dataloader import Collater
from sevenn.atom_graph_data import AtomGraphData
from .dataload import atoms_to_graph
class AtomsToGraphCollater(Collater):
    """Collater that converts ASE Atoms into AtomGraphData before batching.

    Each Atoms in the incoming batch is turned into a graph dict with
    ``atoms_to_graph`` (using the stored cutoff and flags), wrapped as
    AtomGraphData, and then batched by the parent Collater.
    """

    def __init__(
        self,
        dataset: Sequence[Atoms],
        cutoff: float,
        transfer_info: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        y_from_calc: bool = True,
    ):
        # quite original collator's type mismatch with []
        super().__init__([], follow_batch, exclude_keys)
        self.dataset = dataset
        self.cutoff = cutoff
        self.transfer_info = transfer_info
        self.y_from_calc = y_from_calc

    def __call__(self, batch: List[Any]) -> Any:
        # Convert every Atoms to an AtomGraphData, then delegate batching.
        graphs = [
            AtomGraphData.from_numpy_dict(
                atoms_to_graph(
                    atoms,
                    self.cutoff,
                    transfer_info=self.transfer_info,
                    y_from_calc=self.y_from_calc,
                )
            )
            for atoms in batch
        ]
        return super().__call__(graphs)
import copy
import os.path
from functools import partial
from itertools import chain, islice
from typing import Callable, Dict, List, Optional
import ase
import ase.io
import numpy as np
import torch.multiprocessing as mp
from ase.io.vasp_parsers.vasp_outcar_parsers import (
Cell,
DefaultParsersContainer,
Energy,
OutcarChunkParser,
PositionsAndForces,
Stress,
outcarchunks,
)
from ase.neighborlist import primitive_neighbor_list
from ase.utils import string2index
from braceexpand import braceexpand
from tqdm import tqdm
import sevenn._keys as KEY
from sevenn._const import LossType
from sevenn.atom_graph_data import AtomGraphData
from .dataset import AtomGraphDataset
def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
    """Build neighbor-list edges with matscipy's neighbour_list.

    Returns (edge_src, edge_dst, edge_vec, shifts); indices cast to int64.
    NOTE(review): for non-periodic directions this function overwrites rows
    of the caller's `cell` array in place (matscipy needs a finite cell);
    callers keeping a reference to `cell` will observe the enlarged values
    -- confirm this is intended.
    """
    pbc_x = pbc[0]
    pbc_y = pbc[1]
    pbc_z = pbc[2]
    identity = np.identity(3, dtype=float)
    # Upper bound on coordinates; +1 guards against an all-zero position.
    max_positions = np.max(np.absolute(pos)) + 1
    # Extend cell in non-periodic directions
    # For models with more than 5 layers,
    # the multiplicative constant needs to be increased.
    if not pbc_x:
        cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
    if not pbc_y:
        cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
    if not pbc_z:
        cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
    # it does not have self-interaction
    edge_src, edge_dst, edge_vec, shifts = neighbour_list(
        quantities='ijDS',
        pbc=pbc,
        cell=cell,
        positions=pos,
        cutoff=cutoff,
    )
    # dtype issue: matscipy returns int32 indices on some platforms
    edge_src = edge_src.astype(np.int64)
    edge_dst = edge_dst.astype(np.int64)
    return edge_src, edge_dst, edge_vec, shifts
def _graph_build_ase(cutoff: float, pbc, cell, pos):
    """Build neighbor-list edges with ASE's primitive_neighbor_list.

    The list is requested with self-interaction enabled, then trivial
    self-edges (same atom, zero displacement vector) are filtered out.
    Returns (edge_src, edge_dst, edge_vec, shifts).
    """
    src, dst, vec, cell_shifts = primitive_neighbor_list(
        'ijDS', pbc, cell, pos, cutoff, self_interaction=True
    )
    # A trivial edge pairs an atom with itself in the same periodic image.
    trivial = np.all(vec == 0, axis=1) & (src == dst)
    keep = ~trivial
    cell_shifts = np.array(cell_shifts[keep])
    vec = vec[keep]
    src = src[keep]
    dst = dst[keep]
    return src, dst, vec, cell_shifts
# Select the neighbor-list backend: default to ASE, switch to matscipy
# when it is importable. matscipy thus stays an optional dependency.
_graph_build_f = _graph_build_ase
try:
    from matscipy.neighbours import neighbour_list

    _graph_build_f = _graph_build_matscipy
except ImportError:
    pass
def _correct_scalar(v):
if isinstance(v, np.ndarray):
v = v.squeeze()
assert v.ndim == 0, f'given {v} is not a scalar'
return v
elif isinstance(v, (int, float, np.integer, np.floating)):
return np.array(v)
else:
assert False, f'{type(v)} is not expected'
def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
    """Build a graph dict from Atoms without reference (y_*) labels.

    Geometry/topology only; used where no energy/force/stress labels
    exist. Returns a numpy dict consumable by
    AtomGraphData.from_numpy_dict.
    """
    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()
    edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
    edge_idx = np.array([edge_src, edge_dst])
    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)
    # Zero volume (e.g. a molecule without a cell) is replaced by eps to
    # keep downstream divisions safe.
    vol = _correct_scalar(atoms.cell.volume)
    if vol == 0:
        vol = np.array(np.finfo(float).eps)
    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx,
        KEY.EDGE_VEC: edge_vec,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
    }
    data[KEY.INFO] = {}
    return data
def atoms_to_graph(
    atoms: ase.Atoms,
    cutoff: float,
    transfer_info: bool = True,
    y_from_calc: bool = False,
    allow_unlabeled: bool = False,
):
    """
    From ase atoms, return AtomGraphData as graph based on cutoff radius
    Except for energy, force and stress labels must be numpy array type
    as other cases are not tested.
    Returns 'np.nan' with consistent shape for unlabeled data
    (ex. stress of non-pbc system)

    Args:
        atoms (Atoms): ase atoms
        cutoff (float): cutoff radius
        transfer_info (bool): if True, transfer ".info" from atoms to graph,
            defaults to True
        y_from_calc: if True, get ref values from calculator, defaults to False
        allow_unlabeled: if True, do not raise on NaN energy/force labels

    Returns:
        numpy dict that can be used to initialize AtomGraphData
        by AtomGraphData(**atoms_to_graph(atoms, cutoff))
        , for scalar, its shape is (), and types are np.ndarray
        Requires grad is handled by 'dataset' not here.
    """
    if not y_from_calc:
        # Labels were staged beforehand by _set_atoms_y under 'y_*' keys.
        y_energy = atoms.info['y_energy']
        y_force = atoms.arrays['y_force']
        y_stress = atoms.info.get('y_stress', np.full((6,), np.nan))
        if y_stress.shape == (3, 3):
            # Flatten a full 3x3 tensor to Voigt order (xx,yy,zz,xy,yz,zx).
            y_stress = np.array(
                [
                    y_stress[0][0],
                    y_stress[1][1],
                    y_stress[2][2],
                    y_stress[0][1],
                    y_stress[1][2],
                    y_stress[2][0],
                ]
            )
        else:
            y_stress = y_stress.squeeze()
    else:
        from_calc = _y_from_calc(atoms)
        y_energy = from_calc['energy']
        y_force = from_calc['force']
        y_stress = from_calc['stress']
    assert y_stress.shape == (6,), 'If you see this, please raise a issue'

    if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()):
        raise ValueError('Unlabeled E or F found, set allow_unlabeled True')

    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()
    edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
    edge_idx = np.array([edge_src, edge_dst])
    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)
    # Zero volume is replaced by eps to keep downstream divisions safe.
    vol = _correct_scalar(atoms.cell.volume)
    if vol == 0:
        vol = np.array(np.finfo(float).eps)
    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx,
        KEY.EDGE_VEC: edge_vec,
        KEY.ENERGY: _correct_scalar(y_energy),
        KEY.FORCE: y_force,
        KEY.STRESS: y_stress.reshape(1, 6),  # to make batch have (n_node, 6)
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
        KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)),
    }
    if transfer_info and atoms.info is not None:
        info = copy.deepcopy(atoms.info)
        # save only metadata
        info.pop('y_energy', None)
        info.pop('y_force', None)
        info.pop('y_stress', None)
        data[KEY.INFO] = info
    else:
        data[KEY.INFO] = {}
    return data
def graph_build(
    atoms_list: List,
    cutoff: float,
    num_cores: int = 1,
    transfer_info: bool = True,
    y_from_calc: bool = False,
    allow_unlabeled: bool = False,
) -> List[AtomGraphData]:
    """
    parallel version of graph_build
    build graph from atoms_list and return list of AtomGraphData

    Args:
        atoms_list (List): list of ASE atoms
        cutoff (float): cutoff radius of graph
        num_cores (int): number of cores to use
        transfer_info (bool): if True, copy info from atoms to graph,
            defaults to True
        y_from_calc (bool): Get reference y labels from calculator,
            defaults to False
        allow_unlabeled (bool): if True, allow NaN energy/force labels,
            defaults to False

    Returns:
        List[AtomGraphData]: list of AtomGraphData
    """
    serial = num_cores == 1
    inputs = [
        (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled)
        for atoms in atoms_list
    ]
    if not serial:
        # Fan out over a process pool; starmap preserves input order.
        pool = mp.Pool(num_cores)
        graph_list = pool.starmap(
            atoms_to_graph,
            tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'),
        )
        pool.close()
        pool.join()
    else:
        graph_list = [
            atoms_to_graph(*input_)
            for input_ in tqdm(inputs, desc='graph_build (1)')
        ]
    # Wrap raw numpy dicts into AtomGraphData objects.
    graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list]
    return graph_list
def _y_from_calc(atoms: ase.Atoms):
    """Pull reference labels from the Atoms' attached calculator.

    Returns a dict with 'energy', 'force', 'stress'; quantities the
    calculator cannot provide stay NaN with a consistent shape. Energy
    prefers the force-consistent (free) energy. Stress is negated and its
    Voigt components reordered from ASE's (xx,yy,zz,yz,xz,xy) to
    (xx,yy,zz,xy,yz,zx).
    """
    ret = {
        'energy': np.nan,
        'force': np.full((len(atoms), 3), np.nan),
        'stress': np.full((6,), np.nan),
    }
    if atoms.calc is None:
        return ret
    try:
        ret['energy'] = atoms.get_potential_energy(force_consistent=True)
    except NotImplementedError:
        # Calculator has no free energy; fall back to the plain one.
        ret['energy'] = atoms.get_potential_energy()
    try:
        ret['force'] = atoms.get_forces(apply_constraint=False)
    except NotImplementedError:
        pass
    try:
        y_stress = -1 * atoms.get_stress()  # it ensures correct shape
        ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
    except RuntimeError:
        pass
    return ret
def _set_atoms_y(
    atoms_list: List[ase.Atoms],
    energy_key: Optional[str] = None,
    force_key: Optional[str] = None,
    stress_key: Optional[str] = None,
) -> List[ase.Atoms]:
    """
    Define how SevenNet reads ASE.atoms object for its y label
    If energy_key, force_key, or stress_key is given, the corresponding
    label is obtained from .info dict of Atoms object. These values should
    have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress,
    respectively. (stress in Voigt notation)

    Args:
        atoms_list (list[ase.Atoms]): target atoms to set y_labels
        energy_key (str, optional): key to get energy. Defaults to None.
        force_key (str, optional): key to get force. Defaults to None.
        stress_key (str, optional): key to get stress. Defaults to None.

    Returns:
        list[ase.Atoms]: list of ase.Atoms

    Raises:
        RuntimeError: if ase atoms are somewhat imperfect

    Use free_energy: atoms.get_potential_energy(force_consistent=True)
    If it is not available, use atoms.get_potential_energy()
    If stress is available, initialize stress tensor
    Ignore constraints like selective dynamics
    """
    for atoms in atoms_list:
        from_calc = _y_from_calc(atoms)
        # Explicit keys win over calculator-derived values; note .pop
        # removes the original key from atoms.info/arrays.
        if energy_key is not None:
            atoms.info['y_energy'] = atoms.info.pop(energy_key)
        else:
            atoms.info['y_energy'] = from_calc['energy']
        if force_key is not None:
            atoms.arrays['y_force'] = atoms.arrays.pop(force_key)
        else:
            atoms.arrays['y_force'] = from_calc['force']
        if stress_key is not None:
            # Negate and reorder Voigt components, matching _y_from_calc.
            y_stress = -1 * atoms.info.pop(stress_key)
            atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
        else:
            atoms.info['y_stress'] = from_calc['stress']
    return atoms_list
def ase_reader(
    filename: str,
    energy_key: Optional[str] = None,
    force_key: Optional[str] = None,
    stress_key: Optional[str] = None,
    index: str = ':',
    **kwargs,
) -> List[ase.Atoms]:
    """
    Wrapper of ase.io.read

    Reads structures and stages their y_* labels via _set_atoms_y.
    """
    read_result = ase.io.read(filename, index=index, **kwargs)
    # ase.io.read returns a single Atoms for a scalar index; normalize.
    atoms_list = read_result if isinstance(read_result, list) else [read_result]
    return _set_atoms_y(atoms_list, energy_key, force_key, stress_key)
# Reader
def structure_list_reader(filename: str, format_outputs: Optional[str] = None):
    """
    Read from structure_list using braceexpand and ASE

    Args:
        filename: filename of structure_list
        format_outputs: when None or 'vasp-out', files are parsed as
            OUTCARs via ASE's chunk parsers (tracking ionic_step);
            otherwise files are read with ase.io.read

    Returns:
        dictionary of lists of ASE structures.
        key is title of training data (user-define)
    """
    parsers = DefaultParsersContainer(
        PositionsAndForces, Stress, Energy, Cell
    ).make_parsers()
    ocp = OutcarChunkParser(parsers=parsers)

    def parse_label(line):
        # '[Title]' -> 'Title'; returns False for non-title lines.
        line = line.strip()
        if line.startswith('[') is False:
            return False
        elif line.endswith(']') is False:
            raise ValueError('wrong structure_list title format')
        return line[1:-1]

    def parse_fileline(line):
        # 'path [index]' -> (path, index); index defaults to ':' (all steps).
        line = line.strip().split()
        if len(line) == 1:
            line.append(':')
        elif len(line) != 2:
            raise ValueError('wrong structure_list format')
        return line[0], line[1]

    # Collect (file_expression, index_expression) pairs per title.
    raw_str_dict = {}
    label = 'Default'
    # Use a context manager so the file is closed even on parse errors.
    with open(filename, 'r') as structure_list_file:
        for line in structure_list_file:
            if line.strip() == '':
                continue
            tmp_label = parse_label(line)
            if tmp_label:
                label = tmp_label
                raw_str_dict[label] = []
                continue
            elif label in raw_str_dict:
                files_expr, index_expr = parse_fileline(line)
                raw_str_dict[label].append((files_expr, index_expr))
            else:
                raise ValueError('wrong structure_list format')

    structures_dict = {}
    info_dct = {'data_from': 'user_OUTCAR'}
    for title, file_lines in raw_str_dict.items():
        stct_lists = []
        for files_expr, index_expr in file_lines:
            index = string2index(index_expr)
            for expanded_filename in list(braceexpand(files_expr)):
                # Bug fix: an accidentally duplicated islice/enumerate block
                # was removed here. NOTE(review): the branching condition was
                # lost in a bad merge; OUTCAR chunk parsing is kept as the
                # default path (matching info_dct) -- confirm against upstream.
                if format_outputs in (None, 'vasp-out'):
                    f_stream = open(expanded_filename, 'r')
                    # generator of all outcar ionic steps
                    gen_all = outcarchunks(f_stream, ocp)
                    try:  # TODO: index may not slice, it can be integer
                        it_atoms = islice(
                            gen_all, index.start, index.stop, index.step
                        )
                    except ValueError:
                        # TODO: support negative index
                        raise ValueError('Negative index is not supported yet')
                    info_dct_f = {
                        **info_dct,
                        'file': os.path.abspath(expanded_filename),
                    }
                    for idx, o in enumerate(it_atoms):
                        try:
                            istep = index.start + idx * index.step  # type: ignore
                            atoms = o.build()
                            atoms.info = {**info_dct_f, 'ionic_step': istep}.copy()
                        except TypeError:  # it is not slice of ionic steps
                            atoms = o.build()
                            atoms.info = info_dct_f.copy()
                        stct_lists.append(atoms)
                    f_stream.close()
                else:
                    stct_lists += ase.io.read(
                        expanded_filename,
                        index=index_expr,
                        parallel=False,
                    )
        structures_dict[title] = stct_lists
    return {k: _set_atoms_y(v) for k, v in structures_dict.items()}
def dict_reader(data_dict: Dict):
    """Read structures described by a dict (typically parsed from input yaml).

    Args:
        data_dict: must contain 'file_list' (list of dicts with 'file' and
            optional 'data_format'); may contain a data-weight mapping and
            arbitrary extra keys, which are copied into each atoms.info.
    Returns:
        list of ase.Atoms with reference values assigned (via _set_atoms_y)
    Raises:
        KeyError: if 'file_list' is missing
        ValueError: if an unknown data_format is given
    """
    # deep copy so the caller's dict is never mutated by the pops below
    data_dict_cp = copy.deepcopy(data_dict)
    ret = []
    file_list = data_dict_cp.pop('file_list', None)
    if file_list is None:
        raise KeyError('file_list is not found')
    data_weight_default = {
        'energy': 1.0,
        'force': 1.0,
        'stress': 1.0,
    }
    data_weight = data_weight_default.copy()
    data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))
    for file_dct in file_list:
        ftype = file_dct.pop('data_format', 'ase')
        files = list(braceexpand(file_dct.pop('file')))
        if ftype == 'ase':
            # remaining keys of file_dct are forwarded to the ase reader
            ret.extend(chain(*[ase_reader(f, **file_dct) for f in files]))
        elif ftype == 'graph':
            # graph-format entries are handled elsewhere (see _read_dict of
            # SevenNetGraphDataset); skip them here
            continue
        else:
            # fixed: previous message was truncated (f'{ftype} yet')
            raise ValueError(f'data_format {ftype} is not supported yet')
    # leftover keys of data_dict_cp are treated as metadata for every atoms
    for atoms in ret:
        atoms.info.update(data_dict_cp)
        atoms.info.update({KEY.DATA_WEIGHT: data_weight})
    return _set_atoms_y(ret)
def match_reader(reader_name: str, **kwargs):
    """Resolve a reader callable and its metadata dict from a reader name.

    'structure_list' selects the structure_list parser; anything else
    falls back to the generic ase reader.
    """
    if reader_name == 'structure_list':
        return (
            partial(structure_list_reader, **kwargs),
            {'origin': 'structure_list'},
        )
    return partial(ase_reader, **kwargs), {'origin': 'ase_reader'}
def file_to_dataset(
    file: str,
    cutoff: float,
    cores: int = 1,
    reader: Callable = ase_reader,
    label: Optional[str] = None,
    transfer_info: bool = True,
    use_weight: bool = False,
    use_modality: bool = False,
):
    """
    Deprecated
    Read file by reader > get list of atoms or dict of atoms

    Args:
        file: path handed to the reader
        cutoff: edge cutoff used to build graphs
        cores: number of cores used by graph_build
        reader: callable(file) -> Atoms | list[Atoms] | dict[label, list]
        label: label to use when the reader does not provide one
        transfer_info: forward atoms.info into the graphs
        use_weight: parse 'w=' weight annotations from labels
        use_modality: parse 'm=' modality annotations from labels (required
            for every label if enabled)
    Returns:
        AtomGraphDataset keyed by label
    Raises:
        TypeError: if the reader output is not Atoms/list/dict
        ValueError: if a weight annotation is malformed, or a modality
            annotation is missing while use_modality is True
    """
    # expect label: atoms_list dct or atoms or list of atoms
    atoms = reader(file)
    if isinstance(atoms, list):
        if label is None:
            label = KEY.LABEL_NONE
        atoms_dct = {label: atoms}
    elif isinstance(atoms, ase.Atoms):
        if label is None:
            label = KEY.LABEL_NONE
        atoms_dct = {label: [atoms]}
    elif isinstance(atoms, dict):
        atoms_dct = atoms
    else:
        raise TypeError('The return of reader is not list or dict')
    graph_dct = {}
    for label, atoms_list in atoms_dct.items():
        graph_list = graph_build(
            atoms_list=atoms_list,
            cutoff=cutoff,
            num_cores=cores,
            transfer_info=transfer_info,
            y_from_calc=False,
        )
        # labels may carry annotations: "name : w=... : m=..."
        label_info = label.split(':')
        for graph in graph_list:
            graph[KEY.USER_LABEL] = label_info[0].strip()
            if use_weight:
                find_weight = False
                for info in label_info[1:]:
                    if 'w=' in info.lower():
                        weights = info.split('=')[1]
                        # fixed: was a bare `except:`; only conversion errors
                        # should be reported as a bad weight annotation
                        try:
                            if ',' in weights:
                                weight_list = list(map(float, weights.split(',')))
                            else:
                                weight_list = [float(weights)] * 3
                            weight_dict = {}
                            # missing trailing weights default to 1
                            for idx, loss_type in enumerate(LossType):
                                weight_dict[loss_type.value] = (
                                    weight_list[idx] if idx < len(weight_list) else 1
                                )
                            graph[KEY.DATA_WEIGHT] = weight_dict
                            find_weight = True
                            break
                        except (TypeError, ValueError):
                            raise ValueError(
                                'Weight must be a real number, but'
                                f' {weights} is given for {label}'
                            )
                if not find_weight:
                    # no annotation found: every loss term gets weight 1
                    weight_dict = {}
                    for loss_type in LossType:
                        weight_dict[loss_type.value] = 1
                    graph[KEY.DATA_WEIGHT] = weight_dict
            if use_modality:
                find_modality = False
                for info in label_info[1:]:
                    if 'm=' in info.lower():
                        graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip()
                        find_modality = True
                        break
                if not find_modality:
                    raise ValueError(f'Modality not given for {label}')
        graph_dct[label_info[0].strip()] = graph_list
    db = AtomGraphDataset(graph_dct, cutoff)
    return db
import itertools
import random
from collections import Counter
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from ase.data import chemical_symbols
from sklearn.linear_model import Ridge
import sevenn._keys as KEY
import sevenn.util as util
class AtomGraphDataset:
    """
    Deprecated
    class representing dataset of AtomGraphData
    the dataset is handled as dict, {label: data}
    if given data is List, it stores data as {KEY_DEFAULT: data}
    cutoff is for metadata of the graphs not used for some calc
    Every data expected to have one unique cutoff
    No validity or check of the condition is done inside the object
    attribute:
        dataset (Dict[str, List]): key is data label(str), value is list of data
        user_labels (List[str]): list of user labels same as dataset.keys()
        meta (Dict, Optional): metadata of dataset
    for now, metadata 'might' have following keys:
        KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict)
    """

    DATA_KEY_X = (
        KEY.NODE_FEATURE
    )  # atomic_number > one_hot_idx > one_hot_vector
    DATA_KEY_ENERGY = KEY.ENERGY
    DATA_KEY_FORCE = KEY.FORCE
    KEY_DEFAULT = KEY.LABEL_NONE

    def __init__(
        self,
        dataset: Union[Dict[str, List], List],
        cutoff: float,
        metadata: Optional[Dict] = None,
        x_is_one_hot_idx: bool = False,
    ):
        """
        Default constructor of AtomGraphDataset
        Args:
            dataset (Union[Dict[str, List], List]: dataset as dict or pure list
            metadata (Dict, Optional): metadata of data
            cutoff (float): cutoff radius of graphs inside the dataset
            x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z'
        'x' (node feature) of dataset can have 3 states, atomic_numbers,
        one_hot_idx, or one_hot_vector.
        atomic_numbers is general but cannot directly used for input
        one_hot_idx is can be input of the model but requires 'type_map'
        """
        self.cutoff = cutoff
        self.x_is_one_hot_idx = x_is_one_hot_idx
        if metadata is None:
            metadata = {KEY.CUTOFF: cutoff}
        self.meta = metadata
        if type(dataset) is list:
            self.dataset = {self.KEY_DEFAULT: dataset}
        else:
            self.dataset = dataset
        self.user_labels = list(self.dataset.keys())
        # group_by_key here? or not?

    def rewrite_labels_to_data(self):
        """
        Based on self.dataset dict's keys
        write data[KEY.USER_LABEL] to correspond to dict's keys
        Most of times, it is already correctly written
        But required to rewrite if someone rearrange dataset by their own way
        """
        for label, data_list in self.dataset.items():
            for data in data_list:
                data[KEY.USER_LABEL] = label

    def group_by_key(self, data_key: str = KEY.USER_LABEL):
        """
        group dataset list by given key and save it as dict
        and change in-place
        Args:
            data_key (str): data key to group by
            original use is USER_LABEL, but it can be used for other keys
            if someone established it from data[KEY.INFO]
        """
        data_list = self.to_list()
        self.dataset = {}
        for datum in data_list:
            key = datum[data_key]
            if key not in self.dataset:
                self.dataset[key] = []
            self.dataset[key].append(datum)
        self.user_labels = list(self.dataset.keys())

    def separate_info(self, data_key: str = KEY.INFO):
        """
        Separate info from data and save it as list of dict
        to make it compatible with torch_geometric and later training
        """
        data_list = self.to_list()
        info_list = []
        for datum in data_list:
            # fixed: was `if data_key in datum is False:` which is a chained
            # comparison `(data_key in datum) and (datum is False)` and thus
            # never True; data without info was not skipped as intended
            if data_key not in datum:
                continue
            info_list.append(datum[data_key])
            del datum[data_key]  # It does change the self.dataset
            datum[data_key] = len(info_list) - 1
        self.info_list = info_list
        return (data_list, info_list)

    def get_species(self):
        """
        You can also use get_natoms and extract keys from there instead of this
        (And it is more efficient)
        get chemical species of dataset
        return list of SORTED chemical species (as str)
        """
        if hasattr(self, 'type_map'):
            natoms = self.get_natoms(self.type_map)
        else:
            natoms = self.get_natoms()
        species = set()
        for natom_dct in natoms.values():
            species.update(natom_dct.keys())
        species = sorted(list(species))
        return species

    def get_modalities(self):
        """Return list of modalities found; [] if any label lacks one."""
        modalities = set()
        for data_list in self.dataset.values():
            datum = data_list[0].to_dict()
            if KEY.DATA_MODALITY in datum.keys():
                modalities.add(datum[KEY.DATA_MODALITY])
            else:
                return []
        return list(modalities)

    def write_modal_attr(
        self, modal_type_mapper: dict, write_modal_type: bool = False
    ):
        """Write one-hot modal attribute (and optionally modal type index)
        to every data point, based on the given {modality: index} mapper.
        'common' modality gets an all-zero attribute vector.
        """
        num_modalities = len(modal_type_mapper)
        for data_list in self.dataset.values():
            for data in data_list:
                tmp_tensor = torch.zeros(num_modalities)
                if data[KEY.DATA_MODALITY] != 'common':
                    modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]]
                    tmp_tensor[modal_idx] = 1.0
                    if write_modal_type:
                        data[KEY.MODAL_TYPE] = modal_idx
                data[KEY.MODAL_ATTR] = tmp_tensor

    def get_dict_sort_by_modality(self):
        """Return {modality: data_list}; raises ValueError if not modal."""
        dict_sort_by_modality = {}
        for data_list in self.dataset.values():
            # fixed: narrowed from a bare `except:`; only a missing modality
            # key means the dataset is not modal
            try:
                modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY]
            except KeyError:  # Dataset is not modal
                raise ValueError('This dataset has no modality.')
            if modal_key not in dict_sort_by_modality.keys():
                dict_sort_by_modality[modal_key] = []
            dict_sort_by_modality[modal_key].extend(data_list)
        return dict_sort_by_modality

    def len(self):
        """Return int length for a label-less dataset, else {label: len}."""
        if (
            len(self.dataset.keys()) == 1
            and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT
        ):
            return len(self.dataset[AtomGraphDataset.KEY_DEFAULT])
        else:
            return {k: len(v) for k, v in self.dataset.items()}

    def get(self, idx: int, key: Optional[str] = None):
        """Return idx-th datum of given label (default label if None)."""
        if key is None:
            key = self.KEY_DEFAULT
        return self.dataset[key][idx]

    def items(self):
        """Dict-like view over {label: data_list}."""
        return self.dataset.items()

    def to_dict(self):
        """Convert every datum to a plain dict, in-place; return self."""
        dct_dataset = {}
        for label, data_list in self.dataset.items():
            dct_dataset[label] = [datum.to_dict() for datum in data_list]
        self.dataset = dct_dataset
        return self

    def x_to_one_hot_idx(self, type_map: Dict[int, int]):
        """
        type_map is dict of {atomic_number: one_hot_idx}
        after this process, the dataset has dependency on type_map
        or chemical species user want to consider
        """
        assert self.x_is_one_hot_idx is False
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[self.DATA_KEY_X] = torch.LongTensor(
                    [type_map[z.item()] for z in datum[self.DATA_KEY_X]]
                )
        self.type_map = type_map
        self.x_is_one_hot_idx = True

    def toggle_requires_grad_of_data(
        self, key: str, requires_grad_value: bool
    ):
        """
        set requires_grad of specific key of data(pos, edge_vec, ...)
        """
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[key].requires_grad_(requires_grad_value)

    def divide_dataset(
        self,
        ratio: float,
        constant_ratio_btw_labels: bool = True,
        ignore_test: bool = True
    ):
        """
        divide dataset into 1-2*ratio : ratio : ratio
        return divided AtomGraphDataset
        returned value lost its dict key and became {KEY_DEFAULT: datalist}
        but KEY.USER_LABEL of each data is preserved
        """
        def divide(ratio: float, data_list: List, ignore_test=True):
            if ratio > 0.5:
                raise ValueError('Ratio must not exceed 0.5')
            data_len = len(data_list)
            random.shuffle(data_list)
            n_validation = int(data_len * ratio)
            if n_validation == 0:
                raise ValueError(
                    '# of validation set is 0, increase your dataset'
                )
            if ignore_test:
                test_list = []
                n_train = data_len - n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train:]
            else:
                n_train = data_len - 2 * n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train : n_train + n_validation]
                test_list = data_list[n_train + n_validation : data_len]
            return train_list, valid_list, test_list

        lists = ([], [], [])  # train, valid, test
        # fixed: the outer ignore_test argument was never forwarded, so
        # passing ignore_test=False had no effect
        if constant_ratio_btw_labels:
            for data_list in self.dataset.values():
                for store, divided in zip(
                    lists, divide(ratio, data_list, ignore_test)
                ):
                    store.extend(divided)
        else:
            lists = divide(ratio, self.to_list(), ignore_test)
        dbs = tuple(
            AtomGraphDataset(data, self.cutoff, self.meta) for data in lists
        )
        for db in dbs:
            db.group_by_key()
        return dbs

    def to_list(self):
        """Flatten {label: data_list} into a single list."""
        return list(itertools.chain(*self.dataset.values()))

    def get_natoms(self, type_map: Optional[Dict[int, int]] = None):
        """
        if x_is_one_hot_idx, type_map is required
        type_map: Z->one_hot_index(node_feature)
        return Dict{label: {symbol, natom}]}
        """
        assert not (self.x_is_one_hot_idx is True and type_map is None)
        natoms = {}
        for label, data in self.dataset.items():
            natoms[label] = Counter()
            for datum in data:
                if self.x_is_one_hot_idx and type_map is not None:
                    Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map)
                else:
                    Zs = [
                        chemical_symbols[z]
                        for z in datum[self.DATA_KEY_X].tolist()
                    ]
                cnt = Counter(Zs)
                natoms[label] += cnt
            natoms[label] = dict(natoms[label])
        return natoms

    def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS):
        """
        return per_atom mean of given data key
        """
        eng_list = torch.Tensor(
            [x[key] / x[key_num_atoms] for x in self.to_list()]
        )
        return float(torch.mean(eng_list))

    def get_per_atom_energy_mean(self):
        """
        alias for get_per_atom_mean(KEY.ENERGY)
        """
        return self.get_per_atom_mean(self.DATA_KEY_ENERGY)

    def get_species_ref_energy_by_linear_comb(self, num_chem_species: int):
        """
        Total energy as y, composition as c_i,
        solve linear regression of y = c_i*X
        sklearn LinearRegression as solver
        x should be one-hot-indexed
        give num_chem_species if possible
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()
        c = torch.zeros((len(data_list), num_chem_species))
        for idx, datum in enumerate(data_list):
            c[idx] = torch.bincount(
                datum[self.DATA_KEY_X], minlength=num_chem_species
            )
        y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list])
        c = c.numpy()
        y = y.numpy()
        # tweak to fine tune training from many-element to small element
        zero_indices = np.all(c == 0, axis=0)
        c_reduced = c[:, ~zero_indices]
        full_coeff = np.zeros(num_chem_species)
        coef_reduced = (
            Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
        )
        full_coeff[~zero_indices] = coef_reduced
        return full_coeff

    def get_force_rms(self):
        """Return RMS of all force components in the dataset."""
        force_list = []
        for x in self.to_list():
            force_list.extend(
                x[self.DATA_KEY_FORCE]
                .reshape(
                    -1,
                )
                .tolist()
            )
        force_list = torch.Tensor(force_list)
        return float(torch.sqrt(torch.mean(torch.pow(force_list, 2))))

    def get_species_wise_force_rms(self, num_chem_species: int):
        """
        Return force rms for each species
        Averaged by each components (x, y, z)
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()
        atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list])
        force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list])
        index = atomx.repeat_interleave(3, 0).reshape(force.shape)
        rms = torch.zeros(
            (num_chem_species, 3),
            dtype=force.dtype,
            device=force.device
        )
        rms.scatter_reduce_(
            0, index, force.square(),
            reduce='mean', include_self=False
        )
        return torch.sqrt(rms.mean(dim=1))

    def get_avg_num_neigh(self):
        """Return average number of neighbors over all atoms."""
        n_neigh = []
        for _, data_list in self.dataset.items():
            for data in data_list:
                n_neigh.extend(
                    np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1]
                )
        avg_num_neigh = np.average(n_neigh)
        return avg_num_neigh

    def get_statistics(self, key: str):
        """
        return dict of statistics of given key (energy, force, stress)
        key of dict is its label and _total for total statistics
        value of dict is dict of statistics (mean, std, median, max, min)
        """
        def _get_statistic_dict(tensor_list):
            data_list = torch.cat(
                [
                    tensor.reshape(
                        -1,
                    )
                    for tensor in tensor_list
                ]
            )
            data_list = data_list[~torch.isnan(data_list)]
            return {
                'mean': float(torch.mean(data_list)),
                'std': float(torch.std(data_list)),
                'median': float(torch.median(data_list)),
                'max': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.max(data_list))
                ),
                'min': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.min(data_list))
                ),
            }

        res = {}
        for label, values in self.dataset.items():
            # flatten list of torch.Tensor (values)
            tensor_list = [x[key] for x in values]
            res[label] = _get_statistic_dict(tensor_list)
        tensor_list = [x[key] for x in self.to_list()]
        res['Total'] = _get_statistic_dict(tensor_list)
        return res

    def augment(self, dataset, validator: Optional[Callable] = None):
        """check meta compatibility here
        dataset(AtomGraphDataset): data to augment
        validator(Callable, Optional): function(self, dataset) -> bool
        if validator is None, by default it checks
        whether cutoff & chemical_species are same before augment
        check consistent data type, float, double, long integer etc
        """
        def default_validator(db1, db2):
            cut_consis = db1.cutoff == db2.cutoff
            # compare unordered lists
            x_is_not_onehot = (not db1.x_is_one_hot_idx) and (
                not db2.x_is_one_hot_idx
            )
            return cut_consis and x_is_not_onehot

        if validator is None:
            validator = default_validator
        if not validator(self, dataset):
            raise ValueError('given datasets are not compatible check cutoffs')
        for key, val in dataset.items():
            if key in self.dataset:
                self.dataset[key].extend(val)
            else:
                self.dataset.update({key: val})
        self.user_labels = list(self.dataset.keys())

    def unify_dtypes(
        self,
        float_dtype: torch.dtype = torch.float32,
        int_dtype: torch.dtype = torch.int64
    ):
        """Cast every tensor in the dataset to the given dtypes, in-place."""
        data_list = self.to_list()
        for datum in data_list:
            for k, v in list(datum.items()):
                datum[k] = util.dtype_correct(v, float_dtype, int_dtype)

    def delete_data_key(self, key: str):
        """Remove the given key from every datum, in-place."""
        for data in self.to_list():
            del data[key]

    # TODO: this by_label is not straightforward
    def save(self, path: str, by_label: bool = False):
        """Save as .sevenn_data via torch.save; one file per label if
        by_label is True (path is then treated as a directory)."""
        if by_label:
            for label, data in self.dataset.items():
                torch.save(
                    AtomGraphDataset(
                        {label: data}, self.cutoff, metadata=self.meta
                    ),
                    f'{path}/{label}.sevenn_data',
                )
        else:
            if path.endswith('.sevenn_data') is False:
                path += '.sevenn_data'
            torch.save(self, path)
import os
import warnings
from collections import Counter
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.serialization
import torch.utils.data
import yaml
from ase.data import chemical_symbols
from torch_geometric.data import Data
from torch_geometric.data.in_memory_dataset import InMemoryDataset
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn import __version__
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
from sevenn.logger import Logger
# torch >= 2.4 restricts `torch.load` via weights_only; register AtomGraphData
# as a safe global so saved graphs load without error, and silence the
# PyG-triggered weights_only warning.
# fixed: the previous check compared version STRINGS ('2.10.0' >= '2.4.0' is
# False lexicographically) and used .split() which splits on whitespace, not
# dots. Parse major/minor numerically instead; non-numeric parts count as 0.
_torch_version = tuple(
    int(part) if part.isdigit() else 0
    for part in torch.__version__.split('+')[0].split('.')[:2]
)
if _torch_version >= (2, 4):
    # load graph without error
    torch.serialization.add_safe_globals([AtomGraphData])
    # warning from PyG, for later torch versions
    warnings.filterwarnings(
        'ignore',
        message='You are using `torch.load` with `weights_only=False`',
    )
def _tag_graphs(graph_list: List[AtomGraphData], tag: str):
    """Assign the given tag to every graph in-place and return the list.

    WIP: To be used
    """
    for graph in graph_list:
        graph[KEY.TAG] = tag
    return graph_list
def pt_to_args(pt_filename: str):
    """
    Return arg dict of root and processed_name from path to .pt
    Usage:
        dataset = SevenNetGraphDataset(
            **pt_to_args({path}/sevenn_data/dataset.pt)
        )
    """
    sevenn_data_dir = os.path.dirname(pt_filename)
    # root is the parent of the sevenn_data directory; processed_name is
    # just the file name itself
    return {
        'root': os.path.dirname(sevenn_data_dir),
        'processed_name': os.path.basename(pt_filename),
    }
def _run_stat(
    graph_list,
    y_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Loop over dataset and init any statistics might need

    Args:
        graph_list: list of AtomGraphData to scan
        y_keys: data keys to gather statistics for; defaults to energy,
            per-atom energy, force and stress
    Returns:
        dict mapping each key (plus 'num_neighbor') to its statistics
        (mean/std/median/max/min/count and the raw '_array'), along with
        '_composition' (per-graph element counts) and 'natoms'
    """
    # fixed: y_keys previously had a mutable list as its default argument;
    # use None sentinel and resolve the default here
    if y_keys is None:
        y_keys = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
    n_neigh = []
    natoms_counter = Counter()
    composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT))
    stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys}
    for i, graph in tqdm(
        enumerate(graph_list), desc='run_stat', total=len(graph_list)
    ):
        z_tensor = graph[KEY.ATOMIC_NUMBERS]
        natoms_counter.update(z_tensor.tolist())
        composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT)
        n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1])
        for y, dct in stats.items():
            dct['_array'].append(
                graph[y].reshape(
                    -1,
                )
            )
    stats.update({'num_neighbor': {'_array': n_neigh}})
    for y, dct in stats.items():
        array = torch.cat(dct['_array'])
        if array.dtype == torch.int64:  # because of n_neigh
            array = array.to(torch.float)
        try:
            median = torch.quantile(array, q=0.5)
        except RuntimeError:
            # torch.quantile has a tensor-size limit; fall back to nan
            warnings.warn(f'skip median due to too large tensor size: {y}')
            median = torch.nan
        dct.update(
            {
                'mean': float(torch.mean(array)),
                'std': float(torch.std(array, correction=0)),
                'median': float(median),
                'max': float(torch.max(array)),
                'min': float(torch.min(array)),
                'count': array.numel(),
                '_array': array,
            }
        )
    natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
    natoms['total'] = sum(list(natoms.values()))
    stats.update({'_composition': composition, 'natoms': natoms})
    return stats
def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray):
    """Fit per-element reference energies by ridge regression on composition.

    Returns a list of length NUM_UNIV_ELEMENT; entries for elements absent
    from the dataset stay zero.
    """
    from sklearn.linear_model import Ridge

    # drop columns of elements that never appear before fitting
    present = ~np.all(composition == 0, axis=0)
    # will not 100% reproduce, as it is sorted by Z
    # train/dataset.py was sorted by alphabets of chemical species
    model = Ridge(alpha=0.1, fit_intercept=False)
    model.fit(composition[:, present], energies)
    coeffs = np.zeros(NUM_UNIV_ELEMENT)
    coeffs[present] = model.coef_
    return coeffs.tolist()  # ex: coeffs[1] = H_reference_energy
class SevenNetGraphDataset(InMemoryDataset):
    """
    Replacement of AtomGraphDataset. (and .sevenn_data)
    Extends InMemoryDataset of PyG. From given 'files', and 'cutoff',
    build graphs for training SevenNet model. Preprocessed graphs are saved to
    f'{root}/sevenn_data/{processed_name}.pt
    TODO: Save meta info (cutoff) by overriding .save and .load
    TODO: 'tag' is not used yet, but initialized
    'tag' is replacement for 'label', and each datapoint has it as integer
    'tag' is usually parsed from if the structure_list of load_dataset
    Args:
        root: path to save/load processed PyG dataset
        cutoff: edge cutoff of given AtomGraphData
        files: list of filenames or dict describing how to parse the file
            ASE readable (with proper extension), structure_list, .sevenn_data,
            dict containing file_list (see dict_reader of train/dataload.py)
        process_num_cores: # of cpu cores to build graph
        processed_name: save as {root}/sevenn_data/{processed_name}.pt
        pre_transform: optional transform for each graph: def (graph) -> graph
        pre_filter: optional filtering function for each graph: def (graph) -> graph
        force_reload: if True, reload dataset from files even if there exist
            {root}/sevenn_data/{processed_name}
        **process_kwargs: keyword arguments that will be passed into ase.io.read
    """

    def __init__(
        self,
        cutoff: float,
        root: Optional[str] = None,
        files: Optional[Union[str, List[Any]]] = None,
        process_num_cores: int = 1,
        processed_name: str = 'graph.pt',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        use_data_weight: bool = False,
        log: bool = True,
        force_reload: bool = False,
        drop_info: bool = True,
        **process_kwargs,
    ):
        self.cutoff = cutoff
        if files is None:
            files = []
        elif isinstance(files, str):
            files = [files]  # user convenience
        _files = []
        for f in files:
            if isinstance(f, str):
                f = os.path.abspath(f)
            _files.append(f)
        self._files = _files
        self._full_file_list = []
        if not processed_name.endswith('.pt'):
            processed_name += '.pt'
        self._processed_names = [
            processed_name,  # {root}/sevenn_data/{name}.pt
            processed_name.replace('.pt', '.yaml'),
        ]
        root = root or './'
        _pdir = os.path.join(root, 'sevenn_data')
        _pt = os.path.join(_pdir, self._processed_names[0])
        if not os.path.exists(_pt) and len(self._files) == 0:
            raise ValueError(
                (
                    f'{_pt} not found and no files to process. '
                    + 'If you copied only .pt file, please copy '
                    + 'whole sevenn_data dir without changing its name.'
                    + ' They all work together.'
                )
            )
        _yam = os.path.join(_pdir, self._processed_names[1])
        if not os.path.exists(_yam) and len(self._files) == 0:
            raise ValueError(f'{_yam} not found and no files to process')
        self.process_num_cores = process_num_cores
        self.process_kwargs = process_kwargs
        self.use_data_weight = use_data_weight
        self.drop_info = drop_info
        self.tag_map = {}
        self.statistics = {}
        self.finalized = False
        super().__init__(
            root,
            transform,
            pre_transform,
            pre_filter,
            log=log,
            force_reload=force_reload,
        )  # Internally calls 'process'
        self.load(self.processed_paths[0])  # load pt, saved after process

    def load(self, path: str, data_cls=Data) -> None:
        """Load processed graphs; restore metadata if not set already."""
        super().load(path, data_cls)
        if len(self) == 0:
            warnings.warn(f'No graphs found {self.processed_paths[0]}')
        if len(self.statistics) == 0:
            # dataset is loaded from existing pt file.
            self._load_meta()

    def _load_meta(self) -> None:
        """Read the side-car yaml and restore cutoff/files/statistics."""
        with open(self.processed_paths[1], 'r') as f:
            meta = yaml.safe_load(f)
        if meta['sevennet_version'] == '0.10.0':
            # old format: recompute and rewrite meta from loaded graphs
            self._save_meta(list(self))
            with open(self.processed_paths[1], 'r') as f:
                meta = yaml.safe_load(f)
        cutoff = float(meta['cutoff'])
        if float(meta['cutoff']) != self.cutoff:
            warnings.warn(
                (
                    'Loaded dataset is built with different cutoff length: '
                    + f'{cutoff} != {self.cutoff}, dataset cutoff will be'
                    + f' overwritten to {cutoff}'
                )
            )
        self.cutoff = cutoff
        self._files = meta['files']
        self.statistics = meta['statistics']

    def __getitem__(self, idx):
        graph = super().__getitem__(idx)
        if self.drop_info:
            graph.pop(KEY.INFO, None)  # type: ignore
        return graph

    @property
    def raw_file_names(self) -> List[Any]:
        return self._files

    @property
    def processed_file_names(self) -> List[str]:
        return self._processed_names

    @property
    def processed_dir(self) -> str:
        return os.path.join(self.root, 'sevenn_data')

    @property
    def full_file_list(self) -> Union[List[str], None]:
        return self._full_file_list

    def process(self):
        """Build graphs from raw files, filter/transform, and save."""
        graph_list: List[AtomGraphData] = []
        for file in self.raw_file_names:
            tmplist = SevenNetGraphDataset.file_to_graph_list(
                file=file,
                cutoff=self.cutoff,
                num_cores=self.process_num_cores,
                **self.process_kwargs,
            )
            if isinstance(file, str) and self._full_file_list is not None:
                self._full_file_list.extend([os.path.abspath(file)] * len(tmplist))
            else:
                # dict entries have no single source file; drop the tracking
                self._full_file_list = None
            graph_list.extend(tmplist)
        processed_graph_list = []
        for data in graph_list:
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            if self.use_data_weight:
                # pop data weight from info, and assign to graph
                weight = data[KEY.INFO].pop(
                    KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
                )
                data[KEY.DATA_WEIGHT] = weight
            processed_graph_list.append(data)
        if len(processed_graph_list) == 0:
            # Can not save at all if there is no graph (error in PyG), raise an error
            raise ValueError('Zero graph found after filtering')
        # save graphs, handled by torch_geometrics
        self.save(processed_graph_list, self.processed_paths[0])
        self._save_meta(processed_graph_list)
        if self.log:
            Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}')

    def _save_meta(self, graph_list) -> None:
        """Compute statistics and write the side-car yaml."""
        stats = _run_stat(graph_list)
        stats['elemwise_reference_energies'] = _elemwise_reference_energies(
            stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy()
        )
        self.statistics = stats
        # keys starting with '_' hold raw tensors and are not serialized
        stats_save = {}
        for label, dct in self.statistics.items():
            if label.startswith('_'):
                continue
            stats_save[label] = {}
            if not isinstance(dct, dict):
                stats_save[label] = dct
            else:
                for k, v in dct.items():
                    if k.startswith('_'):
                        continue
                    stats_save[label][k] = v
        meta = {
            'sevennet_version': __version__,
            'cutoff': self.cutoff,
            'when': datetime.now().strftime('%Y-%m-%d %H:%M'),
            'files': self._files,
            'statistics': stats_save,
            'species': self.species,
            'num_graphs': self.statistics[KEY.ENERGY]['count'],
            'per_atom_energy_mean': self.per_atom_energy_mean,
            'force_rms': self.force_rms,
            'per_atom_energy_std': self.per_atom_energy_std,
            'avg_num_neigh': self.avg_num_neigh,
            'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh,
        }
        with open(self.processed_paths[1], 'w') as f:
            yaml.dump(meta, f, default_flow_style=False)

    @property
    def species(self):
        return [z for z in self.statistics['natoms'].keys() if z != 'total']

    @property
    def natoms(self):
        return self.statistics['natoms']

    @property
    def per_atom_energy_mean(self):
        return self.statistics[KEY.PER_ATOM_ENERGY]['mean']

    @property
    def elemwise_reference_energies(self):
        return self.statistics['elemwise_reference_energies']

    @property
    def force_rms(self):
        mean = self.statistics[KEY.FORCE]['mean']
        std = self.statistics[KEY.FORCE]['std']
        return float((mean**2 + std**2) ** (0.5))

    @property
    def per_atom_energy_std(self):
        # NOTE(review): uses the literal key; presumably
        # KEY.PER_ATOM_ENERGY == 'per_atom_energy' — confirm
        return self.statistics['per_atom_energy']['std']

    @property
    def avg_num_neigh(self):
        return self.statistics['num_neighbor']['mean']

    @property
    def sqrt_avg_num_neigh(self):
        return self.avg_num_neigh**0.5

    @staticmethod
    def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]:
        # backward compatibility
        from sevenn.train.dataset import AtomGraphDataset

        dataset = torch.load(filename, map_location='cpu', weights_only=False)
        if isinstance(dataset, AtomGraphDataset):
            graph_list = []
            for _, graphs in dataset.dataset.items():  # type: ignore
                # TODO: transfer label to tag (who gonna need this?)
                graph_list.extend(graphs)
            return graph_list, dataset.cutoff
        else:
            raise ValueError(f'Not sevenn_data type: {type(dataset)}')

    @staticmethod
    def _read_structure_list(
        filename: str, cutoff: float, num_cores: int = 1
    ) -> List[AtomGraphData]:
        """Build graphs from a structure_list file; labels become tags."""
        datadct = dataload.structure_list_reader(filename)
        graph_list = []
        for tag, atoms_list in datadct.items():
            tmp = dataload.graph_build(atoms_list, cutoff, num_cores)
            graph_list.extend(_tag_graphs(tmp, tag))
        return graph_list

    @staticmethod
    def _read_ase_readable(
        filename: str,
        cutoff: float,
        num_cores: int = 1,
        tag: str = '',
        transfer_info: bool = True,
        allow_unlabeled: bool = False,
        **ase_kwargs,
    ) -> List[AtomGraphData]:
        """Build graphs from any ase-readable file."""
        # optional pbc override applied to every structure before building
        pbc_override = ase_kwargs.pop('pbc', None)
        atoms_list = dataload.ase_reader(filename, **ase_kwargs)
        for atoms in atoms_list:
            if pbc_override is not None:
                atoms.pbc = pbc_override
        graph_list = dataload.graph_build(
            atoms_list,
            cutoff,
            num_cores,
            transfer_info=transfer_info,
            allow_unlabeled=allow_unlabeled,
        )
        if tag != '':
            graph_list = _tag_graphs(graph_list, tag)
        return graph_list

    @staticmethod
    def _read_graph_dataset(
        filename: str, cutoff: float, **kwargs
    ) -> List[AtomGraphData]:
        """Load an already-processed .pt graph dataset as a list of graphs."""
        meta_f = filename.replace('.pt', '.yaml')
        orig_cutoff = cutoff
        if not os.path.exists(filename):
            # fixed: message previously contained the literal '(unknown)'
            # instead of the file name
            raise FileNotFoundError(f'No such file: {filename}')
        if not os.path.exists(meta_f):
            warnings.warn('No meta info found, beware of cutoff...')
        else:
            with open(meta_f, 'r') as f:
                meta = yaml.safe_load(f)
            orig_cutoff = float(meta['cutoff'])
            if orig_cutoff != cutoff:
                # fixed: message previously contained the literal '(unknown)'
                warnings.warn(
                    f'{filename} has different cutoff length: '
                    + f'{cutoff} != {orig_cutoff}'
                )
        ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff})
        ds_args.update(pt_to_args(filename))
        ds_args.update(kwargs)
        dataset = SevenNetGraphDataset(**ds_args)
        # TODO: hard coded. consult with inference.py
        glist = [g.fit_dimension() for g in dataset]  # type: ignore
        for g in glist:
            if KEY.STRESS in g:
                # (1, 6) is what we want
                g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0)
        return glist

    @staticmethod
    def _read_dict(
        data_dict: dict,
        cutoff: float,
        num_cores: int = 1,
    ):
        # logic same as the dataload dict_reader, but handles graphs
        data_dict_cp = deepcopy(data_dict)
        file_list = data_dict_cp.get('file_list', None)
        if file_list is None:
            raise KeyError('file_list is not found')
        data_weight_default = {
            'energy': 1.0,
            'force': 1.0,
            'stress': 1.0,
        }
        data_weight = data_weight_default.copy()
        data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))
        graph_list = []
        for file_dct in file_list:
            ftype = file_dct.pop('data_format', 'ase')
            if ftype != 'graph':
                continue  # non-graph entries are handled by dict_reader below
            graph_list.extend(
                SevenNetGraphDataset._read_graph_dataset(
                    file_dct.get('file'), cutoff=cutoff
                )
            )
        for graph in graph_list:
            if KEY.INFO not in graph:
                graph[KEY.INFO] = {}
            graph[KEY.INFO].update(data_dict_cp)
            graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight})
        # dict_reader deep-copies, so passing the original dict is safe here
        atoms_list = dataload.dict_reader(data_dict)
        graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores))
        return graph_list

    @staticmethod
    def file_to_graph_list(
        file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs
    ) -> List[AtomGraphData]:
        """
        kwargs: if file is ase readable, passed to ase.io.read
        """
        if isinstance(file, str) and not os.path.isfile(file):
            raise ValueError(f'No such file: {file}')
        graph_list: List[AtomGraphData]
        if isinstance(file, dict):
            graph_list = SevenNetGraphDataset._read_dict(
                file, cutoff, num_cores, **kwargs
            )
        elif file.endswith('.pt'):
            graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff)
        elif file.endswith('.sevenn_data'):
            graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file)
            if cutoff_other != cutoff:
                warnings.warn(f'Given {file} has different {cutoff_other}!')
            cutoff = cutoff_other
        elif 'structure_list' in file:
            graph_list = SevenNetGraphDataset._read_structure_list(
                file, cutoff, num_cores
            )
        else:
            graph_list = SevenNetGraphDataset._read_ase_readable(
                file, cutoff, num_cores, **kwargs
            )
        return graph_list
def from_single_path(
    path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs
) -> Union[SevenNetGraphDataset, None]:
    """
    Convenient routine for loading a single .pt dataset.
    If given dict and it has data_weight, apply it using transform
    """
    weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
    single = _extract_single_path(path)
    if single is None:
        return None

    # resolve the .pt file (and a possible user data weight) from the input
    if isinstance(single, str):
        pt_file = single if single.endswith('.pt') else None
    elif isinstance(single, dict):
        pt_file = _extract_file_from_dict(single)
        if pt_file is not None and not pt_file.endswith('.pt'):
            pt_file = None
        user_weight = single.get(KEY.DATA_WEIGHT, None)
        if user_weight is not None:
            weight.update(user_weight)
    else:
        return None

    if pt_file is None:
        return None
    dataset_kwargs.update(pt_to_args(pt_file))
    if override_data_weight:
        dataset_kwargs['transform'] = _chain_data_weight_override(
            dataset_kwargs.get('transform'), weight
        )
    return SevenNetGraphDataset(**dataset_kwargs)
def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]:
"""Extracts a single path from the input,
ensuring it's either a single string or list with one item."""
if isinstance(path, list):
return path[0] if len(path) == 1 else None
return path if isinstance(path, (str, dict)) else None
def _extract_file_from_dict(path_dict: dict) -> Union[str, None]:
"""Extracts a single file path from the dictionary, ensuring it's valid."""
file_list = path_dict.get('file_list', None)
if file_list and len(file_list) == 1:
file = file_list[0].get('file', None)
return file if isinstance(file, str) else None
return None
def _chain_data_weight_override(transform_func, data_weight):
    """Wrap *transform_func* so the produced transform forces *data_weight*.

    The returned callable first applies transform_func (when provided),
    then discards any weight stored under the graph's info dict and installs
    data_weight at the top level of the graph.
    """

    def _override(graph):
        if transform_func is not None:
            graph = transform_func(graph)
        # Drop a possibly stale per-structure weight before overriding.
        graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None)
        graph[KEY.DATA_WEIGHT] = data_weight
        return graph

    return _override
# script, return dict of SevenNetGraphDataset
def from_config(
    config: Dict[str, Any],
    working_dir: Optional[str] = None,
    dataset_keys: Optional[List[str]] = None,
):
    """Build SevenNetGraphDataset instances for each dataset path in config.

    Args:
        config: parsed input config; must contain KEY.LOAD_TRAINSET.
        working_dir: root dir for processed data. Defaults to the current
            working directory at call time (a literal os.getcwd() default
            would be frozen at import time).
        dataset_keys: explicit 'load_*_path' keys to read; autodetected
            from config when omitted.

    Returns:
        dict mapping dataset name ('trainset', 'validset', ...) to dataset.

    Side effects: may update config in place (chemical species and
    shift/scale/conv_denominator statistics) and writes to the Logger.
    """
    log = Logger()
    if working_dir is None:
        working_dir = os.getcwd()
    if dataset_keys is None:
        dataset_keys = []
        for k in config:
            if k.startswith('load_') and k.endswith('_path'):
                dataset_keys.append(k)

    if KEY.LOAD_TRAINSET not in dataset_keys:
        raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')

    # common arguments shared by every dataset load
    base_args = {
        'cutoff': config[KEY.CUTOFF],
        'root': working_dir,
        'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1),
        'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config.get(KEY.DATA_FORMAT_ARGS, {}),
    }

    datasets = {}
    for dk in dataset_keys:
        if not (paths := config[dk]):
            continue
        if isinstance(paths, str):
            paths = [paths]
        name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
        # Fresh copy per dataset: previously a shared dict was mutated, so
        # 'files'/'processed_name' leaked into the next iteration's
        # from_single_path call.
        dataset_args = dict(base_args)
        if (dataset := from_single_path(paths, **dataset_args)) is not None:
            datasets[name] = dataset
        else:
            dataset_args.update({'files': paths, 'processed_name': name})
            dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt')
            if os.path.exists(dataset_path) and 'force_reload' not in dataset_args:
                log.writeline(
                    f'Dataset will be loaded from {dataset_path}, without update. '
                    + 'If you have changed your files to read, put force_reload=True'
                    + ' under the data_format_args key'
                )
            datasets[name] = SevenNetGraphDataset(**dataset_args)

    train_set = datasets['trainset']
    chem_species = set(train_set.species)
    # print statistics of each dataset
    for name, dataset in datasets.items():
        log.bar()
        log.writeline(f'{name} distribution:')
        log.statistic_write(dataset.statistics)
        log.format_k_v('# structures (graph)', len(dataset), write=True)
        chem_species.update(dataset.species)
    log.bar()

    # initialize known species from dataset if 'auto'
    # sorted to alphabetical order (which is same as before)
    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
    if all(config[ck] == 'auto' for ck in chem_keys):  # see parse_input.py
        log.writeline('Known species are obtained from the dataset')
        config.update(util.chemical_species_preprocess(sorted(chem_species)))

    # retrieve shift, scale, conv_denominators from user input (keyword)
    init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
    for k in init_from_stats:
        stat_input = config[k]  # statistic key or numbers; renamed from
        # 'input', which shadowed the builtin.
        # If it is not 'str', 1: It is 'continue' training
        # 2: User manually inserted numbers
        if isinstance(stat_input, str) and hasattr(train_set, stat_input):
            config.update({k: getattr(train_set, stat_input)})
            log.writeline(f'{k} is obtained from statistics')
        elif isinstance(stat_input, str) and not hasattr(train_set, stat_input):
            raise NotImplementedError(stat_input)

    if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0:
        log.writeline('Use validation set as random split from the training set')
        log.writeline(
            'Note that statistics, shift, scale, and conv_denominator are '
            + 'computed before random split.\n If you want these after random '
            + 'split, please preprocess dataset and set it as load_trainset_path '
            + 'and load_validset_path explicitly.'
        )
        ratio = float(config[KEY.RATIO])
        train, valid = torch.utils.data.random_split(
            datasets['trainset'], (1.0 - ratio, ratio)
        )
        datasets['trainset'] = train
        datasets['validset'] = valid

    return datasets
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment