import math
from typing import List
import torch
import torch.nn as nn
from e3nn.o3 import Irreps, Linear
import sevenn._keys as KEY
from sevenn.model_build import build_E3_equivariant_model
modal_module_dict = {
KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x',
KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1',
KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2',
KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden',
}
def _get_scalar_index(irreps: Irreps):
scalar_indices = []
for idx, (_, (l, p)) in enumerate(irreps): # noqa
if (
l == 0 and p == 1
): # get index of parameter for scalar (0e), which is used for modality
scalar_indices.append(idx)
return scalar_indices
def _reshape_weight_of_linear(
irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor
) -> List[torch.Tensor]:
linear = Linear(irreps_in, irreps_out)
linear.weight = nn.Parameter(weight)
return list(linear.weight_views())
def _erase_linear_modal_params(
model_state_dct: dict,
erase_modal_indices: List[int],
key: str,
irreps_in: Irreps,
irreps_out: Irreps,
):
orig_input_dim = irreps_in.count('0e')
new_input_dim = orig_input_dim - len(erase_modal_indices)
orig_weight = model_state_dct[key + '.linear.weight']
scalar_idx = _get_scalar_index(irreps_in)
linear_weight_list = _reshape_weight_of_linear(
irreps_in, irreps_out, orig_weight
)
new_weight_list = []
for idx, l_p_weight in enumerate(linear_weight_list[:-1]):
new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
if idx in scalar_idx:
new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim)
new_weight_list.append(new_weight)
"""
The following works for normalization = `path`, which is not used in SEVENNet
for l_p_weight in linear_weight_list[:-1]:
new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze())
"""
flattened_weight = torch.cat(new_weight_list)
return flattened_weight
def _get_modal_weight_as_bias(
model_state_dct: dict,
key: str,
ref_index: int,
irreps_in: Irreps,
irreps_out: Irreps,
):
assert ref_index != -1
input_dim = irreps_in.count('0e')
output_dim = irreps_out.count('0e')
orig_weight = model_state_dct[key + '.linear.weight']
orig_bias = model_state_dct[key + '.linear.bias']
if len(orig_bias) == 0:
orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype)
modal_weight = _reshape_weight_of_linear(
irreps_in, irreps_out, orig_weight
)[-1]
new_bias = orig_bias + modal_weight[ref_index] / math.sqrt(input_dim)
return new_bias
def _append_modal_weight(
model_state_dct: dict, # state dict to be targeted
key: str, # linear weight module name
irreps_in: Irreps, # irreps_in before modality append
irreps_out: Irreps,
append_number: int,
):
# This works for normalization = `element`, default in SEVENNet.
# (normalization = `path` is currently deprecated in SEVENNet.)
input_dim = irreps_in.count('0e')
output_dim = irreps_out.count('0e')
new_input_dim = input_dim + append_number
orig_weight = model_state_dct[key + '.linear.weight']
scalar_idx = _get_scalar_index(irreps_in)
linear_weight_list = _reshape_weight_of_linear(
irreps_in, irreps_out, orig_weight
)
new_weight_list = []
# TODO: combine following as function with _erase_linear_modal_params
for idx, l_p_weight in enumerate(linear_weight_list):
new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
if idx in scalar_idx:
new_weight = new_weight * math.sqrt(new_input_dim / input_dim)
new_weight_list.append(new_weight)
flattened_weight_list = []
for l_p_weight in new_weight_list:
flattened_weight_list.append(
torch.reshape(l_p_weight, (1, -1)).squeeze()
)
flattened_weight = torch.cat(flattened_weight_list)
append_weight = torch.cat([
flattened_weight,
torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype),
]) # zeros: starting from common model
return append_weight
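# A hedged note on the sqrt factor used above (an assumption about e3nn's
# normalization='element'): a scalar output of Linear behaves roughly like
# sum_i(w_i * x_i) / sqrt(fan_in). When the scalar fan-in changes from
# `input_dim` to `new_input_dim`, multiplying the surviving weights by
# sqrt(new_input_dim / input_dim) keeps their effective contribution unchanged.
# Minimal arithmetic sketch (hypothetical helper, not used by SEVENNet):
def _rescale_factor_sketch(input_dim: int = 4, new_input_dim: int = 6) -> float:
    w = 0.5  # an arbitrary existing weight
    w_rescaled = w * math.sqrt(new_input_dim / input_dim)
    # contribution before: w / sqrt(input_dim); after: w_rescaled / sqrt(new_input_dim)
    assert math.isclose(w / math.sqrt(input_dim), w_rescaled / math.sqrt(new_input_dim))
    return w_rescaled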
def get_single_modal_model_dct(
model_state_dct: dict,
config: dict,
ref_modal: str,
from_processing_cp: bool = False,
is_deploy: bool = False,
):
"""
Convert a multimodal model state dictionary to a single-modal model.
The modal is selected by `ref_modal`.
`model_state_dct`: model state dictionary from a multimodal checkpoint file
`config`: dictionary containing the configuration of the checkpoint model
`ref_modal`: the modal to convert to
`from_processing_cp`: if True, use modal_map of the checkpoint file
`is_deploy`: if True, the model is built with single-modal shift and scale
"""
if (
not from_processing_cp and not config[KEY.USE_MODALITY]
): # model is already single modal
return model_state_dct
config[KEY.USE_BIAS_IN_LINEAR] = True
config['_deploy'] = is_deploy
model = build_E3_equivariant_model(config)
del config['_deploy']
key_add = '_cp' if from_processing_cp else ''
modal_type_dict = config[KEY.MODAL_MAP + key_add]
erase_modal_indices = range(len(modal_type_dict.keys())) # starts with 0
if ref_modal != 'common':
try:
ref_modal_index = modal_type_dict[ref_modal]
except KeyError:
raise KeyError(
f'{ref_modal} not in modal type. Use one of'
f' {modal_type_dict.keys()}.'
)
for module_key in model._modules.keys():
for (
use_modal_module_key,
modal_module_name,
) in modal_module_dict.items():
irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out'))
# TODO: directly using "irreps_in" might not be compatible
# when changing `nn/linear.py`
output_dim = irreps_out.count('0e')
if (
config[use_modal_module_key]
and modal_module_name in module_key
): # this module is used for giving modality
irreps_in = Irreps(
model.get_irreps_in(module_key, 'irreps_in')
)
new_bias = (
torch.zeros(output_dim)
if ref_modal == 'common'
else _get_modal_weight_as_bias(
model_state_dct,
module_key,
ref_modal_index,
irreps_in, # type: ignore
irreps_out, # type: ignore
)
)
erased_modal_weight = _erase_linear_modal_params(
model_state_dct,
erase_modal_indices,
module_key,
irreps_in, # type: ignore
irreps_out, # type: ignore
)
model_state_dct[module_key + '.linear.weight'] = (
erased_modal_weight
)
model_state_dct[module_key + '.linear.bias'] = new_bias
elif modal_module_name in module_key:
model_state_dct[module_key + '.linear.bias'] = torch.zeros(
output_dim,
dtype=model_state_dct[module_key + '.linear.weight'].dtype,
)
final_block_key = 'reduce_hidden_to_energy'
model_state_dct[final_block_key + '.linear.bias'] = torch.tensor(
[0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype
)
if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE]:
rescaler_names = []
if config[KEY.USE_MODAL_WISE_SHIFT]:
rescaler_names.append('shift')
if config[KEY.USE_MODAL_WISE_SCALE]:
rescaler_names.append('scale')
config[KEY.USE_MODAL_WISE_SHIFT] = False
config[KEY.USE_MODAL_WISE_SCALE] = False
for rescaler_name in rescaler_names:
rescaler_key = 'rescale_atomic_energy.' + rescaler_name
rescaler = model_state_dct[rescaler_key][ref_modal_index]
model_state_dct.update({rescaler_key: rescaler})
config.update({rescaler_name: rescaler})
config[KEY.USE_MODALITY] = False
return model_state_dct
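# Hedged usage sketch (hypothetical checkpoint layout; assumes the checkpoint
# stores 'model_state_dict' and 'config' entries, which may differ per version):
def _example_single_modal_conversion(checkpoint_path: str = 'multimodal_cp.pth') -> dict:
    cp = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # Collapse the multimodal checkpoint onto the 'common' modal for deployment.
    return get_single_modal_model_dct(
        cp['model_state_dict'], cp['config'], ref_modal='common', is_deploy=True
    )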
def append_modality_to_model_dct(
model_state_dct: dict,
config: dict,
orig_num_modal: int,
append_modal_length: int,
):
"""
Append modal-wise parameters to the original linear layers.
This enables adding modalities to a single-/multi-modal model checkpoint.
`model_state_dct`: model state dictionary from the original checkpoint file
`config`: dictionary containing configuration of the checkpoint model
+ modality appended
`orig_num_modal`: number of modalities used in the original checkpoint
`append_modal_length`: number of modalities to be appended in the new checkpoint.
"""
config_num_modal = config[KEY.NUM_MODALITIES]
config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True})
model = build_E3_equivariant_model(config)
for module_key in model._modules.keys():
for (
use_modal_module_key,
modal_module_name,
) in modal_module_dict.items():
if (
config[use_modal_module_key]
and modal_module_name in module_key
): # this module is used for giving modality
irreps_in = model.get_irreps_in(
module_key, 'irreps_in'
)
# TODO: directly using "irreps_in" might not be compatible
# when changing `nn/linear.py`
irreps_out = model.get_irreps_in(module_key, 'irreps_out')
irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out)
append_weight = _append_modal_weight(
model_state_dct,
module_key,
irreps_in, # type: ignore
irreps_out, # type: ignore
append_modal_length,
)
model_state_dct[module_key + '.linear.weight'] = append_weight
config[KEY.NUM_MODALITIES] = config_num_modal
return model_state_dct
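# Hedged usage sketch (hypothetical numbers): make room for one extra modality
# in a 2-modal checkpoint; the appended weights start at zero, so the new modal
# initially behaves like the common model.
def _example_append_modality(model_state_dct: dict, config: dict) -> dict:
    return append_modality_to_model_dct(
        model_state_dct, config, orig_num_modal=2, append_modal_length=1
    )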
import os
from datetime import datetime
from typing import Optional
import e3nn.util.jit
import torch
import torch.nn
from ase.data import chemical_symbols
import sevenn._keys as KEY
from sevenn import __version__
from sevenn.model_build import build_E3_equivariant_model
from sevenn.util import load_checkpoint
def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None):
"""
This method is messy in order to avoid changes in pair_e3gnn.cpp
while refactoring the Python part.
If we change its behavior (and accordingly pair_e3gnn.cpp),
we have to recompile LAMMPS (which I always want to procrastinate)
"""
from sevenn.nn.edge_embedding import EdgePreprocess
from sevenn.nn.force_output import ForceStressOutput
cp = load_checkpoint(checkpoint)
model, config = cp.build_model('e3nn'), cp.config
model.prepand_module('edge_preprocess', EdgePreprocess(True))
grad_module = ForceStressOutput()
model.replace_module('force_output', grad_module)
new_grad_key = grad_module.get_grad_key()
model.key_grad = new_grad_key
if hasattr(model, 'eval_type_map'):
setattr(model, 'eval_type_map', False)
if modal:
model.prepare_modal_deploy(modal)
elif model.modal_map is not None and len(model.modal_map) >= 1:
raise ValueError(
f'Modal is not given. It has: {list(model.modal_map.keys())}'
)
model.set_is_batch_data(False)
model.eval()
model = e3nn.util.jit.script(model)
model = torch.jit.freeze(model)
# make some config need for md
md_configs = {}
type_map = config[KEY.TYPE_MAP]
chem_list = ''
for Z in type_map.keys():
chem_list += chemical_symbols[Z] + ' '
chem_list = chem_list.strip()
md_configs.update({'chemical_symbols_to_index': chem_list})
md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
md_configs.update(
{'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
)
md_configs.update({'version': __version__})
md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})
if fname.endswith('.pt') is False:
fname += '.pt'
torch.jit.save(model, fname, _extra_files=md_configs)
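# Hedged usage sketch (hypothetical file names): serialize a checkpoint for the
# serial LAMMPS pair style; `modal` is only required for multimodal checkpoints.
def _example_deploy():
    deploy('checkpoint_best.pth', fname='deployed_serial.pt', modal=None)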
# TODO: build model only once
def deploy_parallel(
checkpoint, fname='deployed_parallel', modal: Optional[str] = None
):
# Additional layer for ghost atom (and copy parameters from original)
GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1']
cp = load_checkpoint(checkpoint)
model, config = cp.build_model('e3nn'), cp.config
config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False}
model_state_dct = model.state_dict()
model_list = build_E3_equivariant_model(config, parallel=True)
dct_temp = {}
copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS}
for ghost_layer_key in GHOST_LAYERS_KEYS:
for key, val in model_state_dct.items():
if not key.startswith(ghost_layer_key):
continue
dct_temp.update({f'ghost_{key}': val})
copy_counter[ghost_layer_key] += 1
# Ensure reference weights are copied from state dict
assert all(x > 0 for x in copy_counter.values())
model_state_dct.update(dct_temp)
for model_part in model_list:
missing, _ = model_part.load_state_dict(model_state_dct, strict=False)
if hasattr(model_part, 'eval_type_map'):
setattr(model_part, 'eval_type_map', False)
# Ensure all values are inserted
assert len(missing) == 0, missing
if modal:
model_list[0].prepare_modal_deploy(modal)
elif model_list[0].modal_map is not None:
raise ValueError(
f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}'
)
# prepare some extra information for MD
md_configs = {}
type_map = config[KEY.TYPE_MAP]
chem_list = ''
for Z in type_map.keys():
chem_list += chemical_symbols[Z] + ' '
chem_list = chem_list.strip()
comm_size = max(
[
seg._modules[f'{t}_convolution']._comm_size # type: ignore
for t, seg in enumerate(model_list)
]
)
md_configs.update({'chemical_symbols_to_index': chem_list})
md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
md_configs.update({'comm_size': str(comm_size)})
md_configs.update(
{'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
)
md_configs.update({'version': __version__})
md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})
os.makedirs(fname)
for idx, model in enumerate(model_list):
fname_full = f'{fname}/deployed_parallel_{idx}.pt'
model.set_is_batch_data(False)
model.eval()
model = e3nn.util.jit.script(model)
model = torch.jit.freeze(model)
torch.jit.save(model, fname_full, _extra_files=md_configs)
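# Hedged usage sketch (hypothetical file names): write the per-segment models
# consumed by the parallel LAMMPS pair style into the 'deployed_parallel' directory.
def _example_deploy_parallel():
    deploy_parallel('checkpoint_best.pth', fname='deployed_parallel', modal=None)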
import os
from typing import List, Optional
from sevenn.logger import Logger
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import unique_filepath
def build_sevennet_graph_dataset(
source: List[str],
cutoff: float,
num_cores: int,
out: str,
filename: str,
metadata: Optional[dict] = None,
**fmt_kwargs,
):
from sevenn.train.graph_dataset import SevenNetGraphDataset
log = Logger()
if metadata is None:
metadata = {}
log.timer_start('graph_build')
db = SevenNetGraphDataset(
cutoff=cutoff,
root=out,
files=source,
processed_name=filename,
process_num_cores=num_cores,
**fmt_kwargs,
)
log.timer_end('graph_build', 'graph build time')
log.writeline(f'Graph saved: {db.processed_paths[0]}')
log.bar()
for k, v in metadata.items():
log.format_k_v(k, v, write=True)
log.bar()
log.writeline('Distribution:')
log.statistic_write(db.statistics)
log.format_k_v('# atoms (node)', db.natoms, write=True)
log.format_k_v('# structures (graph)', len(db), write=True)
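# Hedged usage sketch (hypothetical paths): preprocess an extxyz trajectory into
# a graph dataset rooted at ./graphs; the exact processed file location is
# decided by SevenNetGraphDataset itself.
def _example_build_graph_dataset():
    build_sevennet_graph_dataset(
        source=['./train.extxyz'],
        cutoff=5.0,
        num_cores=4,
        out='./graphs',
        filename='train_graph.pt',
        metadata={'purpose': 'example'},
    )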
def dataset_finalize(dataset, metadata, out):
"""
Deprecated
"""
natoms = dataset.get_natoms()
species = dataset.get_species()
metadata = {
**metadata,
'natoms': natoms,
'species': species,
}
dataset.meta = metadata
if os.path.isdir(out):
out = os.path.join(out, 'graph_built.sevenn_data')
elif out.endswith('.sevenn_data') is False:
out = out + '.sevenn_data'
out = unique_filepath(out)
log = Logger()
log.writeline('The metadata of the dataset is...')
for k, v in metadata.items():
log.format_k_v(k, v, write=True)
dataset.save(out)
log.writeline(f'dataset is saved to {out}')
return dataset
def build_script(
source: List[str],
cutoff: float,
num_cores: int,
out: str,
metadata: Optional[dict] = None,
**fmt_kwargs,
):
"""
Deprecated
"""
from sevenn.train.dataload import file_to_dataset, match_reader
if metadata is None:
metadata = {}
log = Logger()
dataset = AtomGraphDataset({}, cutoff)
common_args = {
'cutoff': cutoff,
'cores': num_cores,
'label': 'graph_build',
}
log.timer_start('graph_build')
for path in source:
if os.path.isdir(path):
continue
log.writeline(f'Read: {path}')
basename = os.path.basename(path)
if 'structure_list' in basename:
fmt = 'structure_list'
else:
fmt = 'ase'
reader, rmeta = match_reader(fmt, **fmt_kwargs)
metadata.update(**rmeta)
dataset.augment(
file_to_dataset(
file=path,
reader=reader,
**common_args,
)
)
log.timer_end('graph_build', 'graph build time')
dataset_finalize(dataset, metadata, out)
import csv
import os
from typing import Iterable, List, Optional, Union
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.graph_dataset import SevenNetGraphDataset
from sevenn.train.modal_dataset import SevenNetMultiModalDataset
def write_inference_csv(output_list, out):
for i, output in enumerate(output_list):
output = output.fit_dimension()
output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208
output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208
output_list[i] = output.to_numpy_dict()
per_graph_keys = [
KEY.NUM_ATOMS,
KEY.USER_LABEL,
KEY.ENERGY,
KEY.PRED_TOTAL_ENERGY,
KEY.STRESS,
KEY.PRED_STRESS,
]
per_atom_keys = [
KEY.ATOMIC_NUMBERS,
KEY.ATOMIC_ENERGY,
KEY.POS,
KEY.FORCE,
KEY.PRED_FORCE,
]
def unfold_dct_val(dct, keys, suffix_list=None):
res = {}
if suffix_list is None:
suffix_list = range(100)
for k in keys:
if k not in dct:
res[k] = '-'
elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0:
res.update(
{f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])}
)
else:
res[k] = dct[k]
return res
def per_atom_dct_list(dct, keys):
sfx_list = ['x', 'y', 'z']
res = []
natoms = dct[KEY.NUM_ATOMS]
extracted = {k: dct[k] for k in keys}
for i in range(natoms):
raw = {}
raw.update({k: v[i] for k, v in extracted.items()})
per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list)
res.append(per_atom_dct)
return res
try:
with open(f'{out}/info.csv', 'w', newline='') as f:
header = output_list[0][KEY.INFO].keys()
writer = csv.DictWriter(f, fieldnames=header)
writer.writeheader()
for output in output_list:
writer.writerow(output[KEY.INFO])
except (KeyError, TypeError, AttributeError, csv.Error) as e:
print(e)
print('failed to write meta data, info.csv is not written')
with open(f'{out}/per_graph.csv', 'w', newline='') as f:
sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] # for stress
writer = None
for output in output_list:
cell_dct = {KEY.CELL: output[KEY.CELL]}
cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c'])
data = {
**unfold_dct_val(output, per_graph_keys, sfx_list),
**cell_dct,
}
if writer is None:
writer = csv.DictWriter(f, fieldnames=data.keys())
writer.writeheader()
writer.writerow(data)
with open(f'{out}/per_atom.csv', 'w', newline='') as f:
writer = None
for i, output in enumerate(output_list):
list_of_dct = per_atom_dct_list(output, per_atom_keys)
for j, dct in enumerate(list_of_dct):
idx_dct = {'stct_id': i, 'atom_id': j}
data = {**idx_dct, **dct}
if writer is None:
writer = csv.DictWriter(f, fieldnames=data.keys())
writer.writeheader()
writer.writerow(data)
def _patch_data_info(
graph_list: Iterable[AtomGraphData], full_file_list: List[str]
) -> None:
keys = set()
for graph, path in zip(graph_list, full_file_list):
if KEY.INFO not in graph:
graph[KEY.INFO] = {}
graph[KEY.INFO].update({'file': os.path.abspath(path)})
keys.update(graph[KEY.INFO].keys())
# save only safe subset of info (for batching)
for graph in graph_list:
info_dict = graph[KEY.INFO]
info_dict.update({k: '' for k in keys if k not in info_dict})
def inference(
checkpoint: str,
targets: Union[str, List[str]],
output_dir: str,
num_workers: int = 1,
device: str = 'cpu',
batch_size: int = 4,
save_graph: bool = False,
allow_unlabeled: bool = False,
modal: Optional[str] = None,
**data_kwargs,
) -> None:
"""
Run inference of the model on the target dataset, and write
per_graph, per_atom inference results in csv format
to the output_dir.
If a given target doesn't have EFS keys, it puts dummy
values.
Args:
checkpoint: model checkpoint path
targets: path, or list of paths, to evaluate. Supports
ASE readable, sevenn_data/*.pt, .sevenn_data, and
structure_list
output_dir: directory to write results
num_workers: number of workers to build graphs
device: device to evaluate on, defaults to 'cpu'
batch_size: batch size for inference
save_graph: if True, save the preprocessed graph to the output dir
data_kwargs: keyword arguments used when reading targets,
for example, given index='-1', only the last snapshot
will be evaluated if it was ASE readable.
While this function can handle different types of targets
at once, it will not work smoothly with data_kwargs
"""
model, _ = util.model_from_checkpoint(checkpoint)
cutoff = model.cutoff
if modal:
if model.modal_map is None:
raise ValueError('Modality given, but model has no modal_map')
if modal not in model.modal_map:
_modals = list(model.modal_map.keys())
raise ValueError(f'Unknown modal {modal} (not in {_modals})')
if isinstance(targets, str):
targets = [targets]
full_file_list = []
if save_graph:
dataset = SevenNetGraphDataset(
cutoff=cutoff,
root=output_dir,
files=targets,
process_num_cores=num_workers,
processed_name='saved_graph.pt',
**data_kwargs,
)
full_file_list = dataset.full_file_list # TODO: not used currently
else:
dataset = []
for file in targets:
tmplist = SevenNetGraphDataset.file_to_graph_list(
file,
cutoff=cutoff,
num_cores=num_workers,
allow_unlabeled=allow_unlabeled,
**data_kwargs,
)
dataset.extend(tmplist)
full_file_list.extend([os.path.abspath(file)] * len(tmplist))
if (
full_file_list is not None
and len(full_file_list) == len(dataset)
and not isinstance(dataset, SevenNetGraphDataset)
):
_patch_data_info(dataset, full_file_list) # type: ignore
if modal:
dataset = SevenNetMultiModalDataset({modal: dataset}) # type: ignore
loader = DataLoader(dataset, batch_size, shuffle=False) # type: ignore
model.to(device)
model.set_is_batch_data(True)
model.eval()
rec = util.get_error_recorder()
output_list = []
for batch in tqdm(loader):
batch = batch.to(device)
output = model(batch).detach().cpu()
rec.update(output)
output_list.extend(util.to_atom_graph_list(output))
errors = rec.epoch_forward()
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f:
for key, val in errors.items():
f.write(f'{key}: {val}\n')
write_inference_csv(output_list, output_dir)
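# Hedged usage sketch (hypothetical paths): evaluate a checkpoint on an extxyz
# file and write errors.txt, per_graph.csv, and per_atom.csv under ./infer_out.
def _example_inference():
    inference(
        checkpoint='checkpoint_best.pth',
        targets='./test.extxyz',
        output_dir='./infer_out',
        device='cpu',
        batch_size=4,
    )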
import os
import warnings
import torch
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.logger import Logger
from sevenn.scripts.convert_model_modality import (
append_modality_to_model_dct,
get_single_modal_model_dct,
)
def processing_continue_v2(config): # simpler
"""
Replacement of processing_continue.
Skips the model compatibility check.
"""
log = Logger()
continue_dct = config[KEY.CONTINUE]
log.write('\nContinue found, loading checkpoint\n')
checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT])
model_cp = checkpoint.build_model()
config_cp = checkpoint.config
model_state_dict_cp = model_cp.state_dict()
optimizer_state_dict_cp = (
checkpoint.optimizer_state_dict
if not continue_dct[KEY.RESET_OPTIMIZER]
else None
)
scheduler_state_dict_cp = (
checkpoint.scheduler_state_dict
if not continue_dct[KEY.RESET_SCHEDULER]
else None
)
# use_statistic_value_of_checkpoint always True
# Overwrite config from model state dict, so graph_dataset.from_config
# will not put statistic values to shift, scale, and conv_denominator
config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist()
config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist()
conv_denom = []
for i in range(config_cp[KEY.NUM_CONVOLUTION]):
conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item())
config[KEY.CONV_DENOMINATOR] = conv_denom
log.writeline(
f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are '
+ 'overwritten by model_state_dict of checkpoint'
)
chem_keys = [
KEY.TYPE_MAP,
KEY.NUM_SPECIES,
KEY.CHEMICAL_SPECIES,
KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
]
config.update({k: config_cp[k] for k in chem_keys})
log.writeline(
'chemical_species are overwritten by checkpoint. '
+ f'This model knows {config[KEY.NUM_SPECIES]} species'
)
if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY):
raise ValueError('use_modality is not same. Check sevenn_cp')
modal_map = config_cp.get(KEY.MODAL_MAP, None) # dict | None
if modal_map and len(modal_map) > 0:
modalities = list(modal_map.keys())
log.writeline(f'Multimodal model found: {modalities}')
log.writeline('use_modality: True')
config[KEY.USE_MODALITY] = True
from_epoch = checkpoint.epoch or 0
log.writeline(f'Checkpoint previous epoch was: {from_epoch}')
epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1
log.writeline(f'epoch start from {epoch}')
log.writeline('checkpoint loading successful')
state_dicts = [
model_state_dict_cp,
optimizer_state_dict_cp,
scheduler_state_dict_cp,
]
return state_dicts, epoch
def check_config_compatible(config, config_cp):
# TODO: check more
SHOULD_BE_SAME = [
KEY.NODE_FEATURE_MULTIPLICITY,
KEY.LMAX,
KEY.IS_PARITY,
KEY.RADIAL_BASIS,
KEY.CUTOFF_FUNCTION,
KEY.CUTOFF,
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS,
KEY.NUM_CONVOLUTION,
KEY.USE_BIAS_IN_LINEAR,
KEY.SELF_CONNECTION_TYPE,
]
for sbs in SHOULD_BE_SAME:
if config[sbs] == config_cp[sbs]:
continue
if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE':
warnings.warn(
'We do not support continuing from this version of checkpoints. '
"Please use self_connection_type='linear' in input.yaml "
'and train from scratch',
UserWarning,
)
raise ValueError(
f'Value of {sbs} should be same. {config[sbs]} != {config_cp[sbs]}'
)
try:
cntdct = config[KEY.CONTINUE]
except KeyError:
return
TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE]
if (
any((not cntdct[KEY.RESET_SCHEDULER], not cntdct[KEY.RESET_OPTIMIZER]))
and all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS) is False
):
raise ValueError(
'reset optimizer and scheduler if you want to change '
+ 'trainable configs'
)
# TODO add condition for changed optim/scheduler but not reset
def processing_continue(config):
log = Logger()
continue_dct = config[KEY.CONTINUE]
log.write('\nContinue found, loading checkpoint\n')
checkpoint = torch.load(
continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False
)
config_cp = checkpoint['config']
model_cp, config_cp = util.model_from_checkpoint(checkpoint)
model_state_dict_cp = model_cp.state_dict()
# it will raise error if not compatible
check_config_compatible(config, config_cp)
log.write('Checkpoint config is compatible\n')
# for backward compat.
config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]})
from_epoch = checkpoint['epoch']
optimizer_state_dict_cp = (
checkpoint['optimizer_state_dict']
if not continue_dct[KEY.RESET_OPTIMIZER]
else None
)
scheduler_state_dict_cp = (
checkpoint['scheduler_state_dict']
if not continue_dct[KEY.RESET_SCHEDULER]
else None
)
# These could be changed based on given continue_input.yaml
# ex) adapt to statistics of fine-tuning dataset
shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy()
del model_state_dict_cp['rescale_atomic_energy.shift']
scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy()
del model_state_dict_cp['rescale_atomic_energy.scale']
conv_denominators = []
for i in range(config_cp[KEY.NUM_CONVOLUTION]):
conv_denominators.append(
(model_state_dict_cp[f'{i}_convolution.denominator']).item()
)
del model_state_dict_cp[f'{i}_convolution.denominator']
# Further handled by processing_dataset.py
config.update({
KEY.SHIFT + '_cp': shift_cp,
KEY.SCALE + '_cp': scale_cp,
KEY.CONV_DENOMINATOR + '_cp': conv_denominators,
})
chem_keys = [
KEY.TYPE_MAP,
KEY.NUM_SPECIES,
KEY.CHEMICAL_SPECIES,
KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
]
config.update({k: config_cp[k] for k in chem_keys})
if (
KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY]
): # checkpoint model is multimodal
config.update({
KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP],
KEY.USE_MODALITY + '_cp': True,
KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]),
})
else:
config.update({
KEY.MODAL_MAP + '_cp': {},
KEY.USE_MODALITY + '_cp': False,
KEY.NUM_MODALITIES + '_cp': 0,
})
log.write(f'checkpoint previous epoch was: {from_epoch}\n')
# decide start epoch
reset_epoch = continue_dct[KEY.RESET_EPOCH]
if reset_epoch:
start_epoch = 1
log.write('epoch reset to 1\n')
else:
start_epoch = from_epoch + 1
log.write(f'epoch start from {start_epoch}\n')
# decide csv file to continue
init_csv = True
csv_fname = config_cp[KEY.CSV_LOG]
if os.path.isfile(csv_fname):
# I hope Python compares dicts well
if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]:
log.writeline('Same metric, csv file will be appended')
init_csv = False
else:
log.writeline(f'{csv_fname} file not found, new csv file will be created')
log.writeline('checkpoint loading was successful')
state_dicts = [
model_state_dict_cp,
optimizer_state_dict_cp,
scheduler_state_dict_cp,
]
return state_dicts, start_epoch, init_csv
def convert_modality_of_checkpoint_state_dct(config, state_dicts):
# TODO: this requires updating model state dict after seeing dataset
model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = (
state_dicts
)
if config[KEY.USE_MODALITY]: # current model is multimodal
num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp'])
append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp
model_state_dict_cp = append_modality_to_model_dct(
model_state_dict_cp, config, num_modalities_cp, append_modal_length
)
else: # current model is single modal
if config[KEY.USE_MODALITY + '_cp']: # checkpoint model is multimodal
# change model state dict to single modal, default = "common"
model_state_dict_cp = get_single_modal_model_dct(
model_state_dict_cp,
config,
config[KEY.DEFAULT_MODAL],
from_processing_cp=True,
)
state_dicts = (
model_state_dict_cp,
optimizer_state_dict_cp,
scheduler_state_dict_cp,
)
return state_dicts
import os
import torch
import torch.distributed as dist
import sevenn._const as CONST
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.train.dataload import file_to_dataset, match_reader
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import chemical_species_preprocess, onehot_to_chem
def dataset_load(file: str, config):
"""
Wrapper around dataload.file_to_dataset to support
graph pre-built sevenn_data
"""
log = Logger()
log.write(f'Loading {file}\n')
log.timer_start('loading dataset')
if file.endswith('.sevenn_data'):
dataset = torch.load(file, map_location='cpu', weights_only=False)
else:
reader, _ = match_reader(
config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS]
)
dataset = file_to_dataset(
file,
config[KEY.CUTOFF],
config[KEY.PREPROCESS_NUM_CORES],
reader=reader,
use_modality=config[KEY.USE_MODALITY],
use_weight=config[KEY.USE_WEIGHT],
)
log.format_k_v('loaded dataset size is', dataset.len(), write=True)
log.timer_end('loading dataset', 'data set loading time')
return dataset
def calculate_shift_or_scale_from_key(
train_set: AtomGraphDataset, key_given, n_chem
):
_expand = True
use_species_wise_shift_scale = False
if key_given == 'per_atom_energy_mean':
shift_or_scale = train_set.get_per_atom_energy_mean()
elif key_given == 'elemwise_reference_energies':
shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem)
_expand = False
use_species_wise_shift_scale = True
elif key_given == 'force_rms':
shift_or_scale = train_set.get_force_rms()
elif key_given == 'per_atom_energy_std':
shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][
'std'
]
elif key_given == 'elemwise_force_rms':
shift_or_scale = train_set.get_species_wise_force_rms(n_chem)
_expand = False
use_species_wise_shift_scale = True
return shift_or_scale, _expand, use_species_wise_shift_scale
def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given):
"""
Priority (listed highest first; higher-priority values are applied later in the code, so they overwrite the others):
1. Float given in yaml
2. Use statistic values of checkpoint == True
3. Plain options (provided as string)
"""
log = Logger()
shift, scale, conv_denominator = None, None, None
type_map = config[KEY.TYPE_MAP]
n_chem = len(type_map)
chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
log.writeline('\nCalculating statistic values from dataset')
shift_given = config[KEY.SHIFT]
scale_given = config[KEY.SCALE]
_expand_shift = True
_expand_scale = True
use_species_wise_shift = False
use_species_wise_scale = False
use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT]
use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE]
if shift_given in CONST.IMPLEMENTED_SHIFT:
shift, _expand_shift, use_species_wise_shift = (
calculate_shift_or_scale_from_key(train_set, shift_given, n_chem)
)
if scale_given in CONST.IMPLEMENTED_SCALE:
scale, _expand_scale, use_species_wise_scale = (
calculate_shift_or_scale_from_key(train_set, scale_given, n_chem)
)
if use_modal_wise_shift or use_modal_wise_scale:
atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality()
modal_map = config[KEY.MODAL_MAP]
n_modal = len(modal_map)
cutoff = config[KEY.CUTOFF]
if use_modal_wise_shift:
shift = torch.zeros((n_modal, n_chem))
if use_modal_wise_scale:
scale = torch.zeros((n_modal, n_chem))
for modal_key, data_list in atomdata_dict_sort_by_modal.items():
modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True)
if use_modal_wise_shift:
if shift_given == 'elemwise_reference_energies':
modal_shift, _expand_shift, use_species_wise_shift = (
calculate_shift_or_scale_from_key(
modal_set, shift_given, n_chem
)
)
shift[modal_map[modal_key]] = torch.tensor(
modal_shift
) # this is np.array
elif shift_given in CONST.IMPLEMENTED_SHIFT:
raise NotImplementedError(
'Currently, modal-wise shift is implemented for '
'the species-dependent case only.'
)
if use_modal_wise_scale:
if scale_given == 'elemwise_force_rms':
modal_scale, _expand_scale, use_species_wise_scale = (
calculate_shift_or_scale_from_key(
modal_set, scale_given, n_chem
)
)
scale[modal_map[modal_key]] = modal_scale
elif scale_given in CONST.IMPLEMENTED_SCALE:
raise NotImplementedError(
'Currently, modal-wise scale is implemented for '
'the species-dependent case only.'
)
avg_num_neigh = train_set.get_avg_num_neigh()
log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True)
if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh':
conv_denominator = avg_num_neigh
elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh':
conv_denominator = avg_num_neigh ** (0.5)
if (
checkpoint_given
and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT]
):
log.writeline(
'Overwrite shift, scale, conv_denominator from model checkpoint'
)
# TODO: This needs refactoring
conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp']
if not (use_modal_wise_shift or use_modal_wise_scale):
# Values extracted from checkpoint in processing_continue.py
if len(list(shift)) > 1:
use_species_wise_shift = True
use_species_wise_scale = True
_expand_shift = _expand_scale = False
else:
shift = shift.item()
scale = scale.item()
else:
# Case of modal wise shift scale
shift_cp = config[KEY.SHIFT + '_cp']
scale_cp = config[KEY.SCALE + '_cp']
if not use_modal_wise_shift:
shift = shift_cp
if not use_modal_wise_scale:
scale = scale_cp
modal_map = config[KEY.MODAL_MAP]
modal_map_cp = config[KEY.MODAL_MAP + '_cp']
# Extracting shift, scale for modal in checkpoint model.
if config[KEY.USE_MODALITY + '_cp']: # cp model is multimodal
for modal_key_cp, modal_idx_cp in modal_map_cp.items():
modal_idx = modal_map[modal_key_cp]
if use_modal_wise_shift:
shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp])
if use_modal_wise_scale:
scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp])
else: # cp model is single modal
try:
modal_idx = modal_map[config[KEY.DEFAULT_MODAL]]
except KeyError:
raise KeyError(
f'{config[KEY.DEFAULT_MODAL]} should be one of'
f' {modal_map.keys()}'
)
if use_modal_wise_shift:
shift[modal_idx] = torch.tensor(shift_cp)
if use_modal_wise_scale:
scale[modal_idx] = torch.tensor(scale_cp)
if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]:
# Also overwrite values of new modal to reference value
# For multimodal, set reference modal with KEY.DEFAULT_MODAL
shift_ref = shift_cp
scale_ref = scale_cp
if config[KEY.USE_MODALITY + '_cp']:
try:
modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]]
except KeyError:
raise KeyError(
f'{config[KEY.DEFAULT_MODAL]} should be one of'
f' {modal_map_cp.keys()}'
)
shift_ref = shift_cp[modal_idx_cp]
scale_ref = scale_cp[modal_idx_cp]
for modal_key, modal_idx in modal_map.items():
if modal_key not in modal_map_cp.keys():
if use_modal_wise_shift:
shift[modal_idx] = shift_ref
if use_modal_wise_scale:
scale[modal_idx] = scale_ref
# overwrite shift scale anyway if defined in yaml.
if type(shift_given) in [list, float]:
log.writeline('Overwrite shift to value(s) given in yaml')
_expand_shift = isinstance(shift_given, float)
shift = shift_given
if type(scale_given) in [list, float]:
log.writeline('Overwrite scale to value(s) given in yaml')
_expand_scale = isinstance(scale_given, float)
scale = scale_given
if isinstance(config[KEY.CONV_DENOMINATOR], float):
log.writeline('Overwrite conv_denominator to value given in yaml')
conv_denominator = config[KEY.CONV_DENOMINATOR]
if isinstance(conv_denominator, float):
conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION]
use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale
if use_species_wise_shift_scale:
chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
if _expand_shift:
if use_modal_wise_shift:
shift = torch.full((n_modal, n_chem), shift)
else:
shift = [shift] * n_chem
if _expand_scale:
if use_modal_wise_scale:
scale = torch.full((n_modal, n_chem), scale)
else:
scale = [scale] * n_chem
Logger().write('Use element-wise shift, scale\n')
if use_modal_wise_shift or use_modal_wise_scale:
for modal_key, modal_idx in modal_map.items():
Logger().writeline(f'For modal = {modal_key}')
print_shift = shift[modal_idx] if use_modal_wise_shift else shift
print_scale = scale[modal_idx] if use_modal_wise_scale else scale
for cstr, sh, sc in zip(chem_strs, print_shift, print_scale):
Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
else:
for cstr, sh, sc in zip(chem_strs, shift, scale):
Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
else:
log.write('Use global shift, scale\n')
log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True)
assert isinstance(conv_denominator, list) and all(
isinstance(deno, float) for deno in conv_denominator
)
log.format_k_v(
'(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True
)
config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale
return shift, scale, conv_denominator
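# Hedged, standalone sketch of the priority described in handle_shift_scale's
# docstring (hypothetical helper, not used by SEVENNet): candidates are applied
# from lowest to highest priority, so a float given in yaml overwrites checkpoint
# statistics, which overwrite values computed from string options.
def _example_shift_priority(from_string=None, from_checkpoint=None, from_yaml=None):
    shift = None
    for candidate in (from_string, from_checkpoint, from_yaml):
        if candidate is not None:
            shift = candidate
    return shift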
# TODO: This is too long
def processing_dataset(config, working_dir):
log = Logger()
prefix = f'{os.path.abspath(working_dir)}/'
is_stress = config[KEY.IS_TRAIN_STRESS]
checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False
cutoff = config[KEY.CUTOFF]
log.write('\nInitializing dataset...\n')
dataset = AtomGraphDataset({}, cutoff)
load_dataset = config[KEY.LOAD_DATASET]
if type(load_dataset) is str:
load_dataset = [load_dataset]
for file in load_dataset:
dataset.augment(dataset_load(file, config))
dataset.group_by_key() # apply labels inside original datapoint
dataset.unify_dtypes() # unify dtypes of all data points
# TODO: I think manual chemical species input is redundant
chem_in_db = dataset.get_species()
if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given:
log.writeline('Auto detect chemical species from dataset')
config.update(chemical_species_preprocess(chem_in_db))
elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given:
pass # copied from checkpoint in processing_continue.py
elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given:
pass # processed in parse_input.py
else: # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given
log.writeline('Ignore chemical species in yaml, use checkpoint')
# already processed in processing_continue.py
# basic dataset compatibility check with previous model
if checkpoint_given:
chem_from_cp = config[KEY.CHEMICAL_SPECIES]
if not all(chem in chem_from_cp for chem in chem_in_db):
raise ValueError('Chemical species in checkpoint is not compatible')
# check what modalities are used in dataset
if config[KEY.USE_MODALITY]:
modalities = dataset.get_modalities()
num_modalities = len(modalities)
if num_modalities < 2:
Logger().writeline('Only one modal is given, ignore modality')
config.update({KEY.USE_MODALITY: False})
else:
modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {}
modal_map = modal_map_cp.copy()
current_idx = len(modal_map_cp)
for modal_key in modalities:
if modal_key not in modal_map.keys():
modal_map[modal_key] = current_idx
current_idx += 1
if config[KEY.IS_DDP]:
# Synchronize modal_map
torch.cuda.set_device(config[KEY.LOCAL_RANK])
modal_map_bcast = [modal_map]
dist.broadcast_object_list(modal_map_bcast, src=0)
modal_map = modal_map_bcast[0]
config.update(
{
KEY.NUM_MODALITIES: len(modal_map),
KEY.MODAL_MAP: modal_map,
KEY.MODAL_LIST: list(modal_map.keys()),
}
)
dataset.write_modal_attr(
modal_map,
config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
)
# --------------- save dataset regardless of train/valid--------------#
save_dataset = config[KEY.SAVE_DATASET]
save_by_label = config[KEY.SAVE_BY_LABEL]
if save_dataset:
if save_dataset.endswith('.sevenn_data') is False:
save_dataset += '.sevenn_data'
if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False:
save_dataset = prefix + save_dataset # save_data set is plain file name
dataset.save(save_dataset)
log.format_k_v('Dataset saved to', save_dataset, write=True)
# log.write(f"Loaded full dataset saved to : {save_dataset}\n")
if save_by_label:
dataset.save(prefix, by_label=True)
log.format_k_v('Dataset saved by label', prefix, write=True)
# --------------------------------------------------------------------#
# TODO: testset is not used
ignore_test = not config.get(KEY.USE_TESTSET, False)
if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]:
train_set = dataset
test_set = AtomGraphDataset([], config[KEY.CUTOFF])
log.write('Loading validset from load_validset\n')
valid_set = AtomGraphDataset({}, cutoff)
for file in config[KEY.LOAD_VALIDSET]:
valid_set.augment(dataset_load(file, config))
valid_set.group_by_key()
valid_set.unify_dtypes()
# condition: validset labels should be subset of trainset labels
valid_labels = valid_set.user_labels
train_labels = train_set.user_labels
if set(valid_labels).issubset(set(train_labels)) is False:
valid_set = AtomGraphDataset(valid_set.to_list(), cutoff)
valid_set.rewrite_labels_to_data()
train_set = AtomGraphDataset(train_set.to_list(), cutoff)
train_set.rewrite_labels_to_data()
Logger().write('WARNING! validset labels are not a subset of trainset\n')
Logger().write('We overwrite all the train and valid labels to default.\n')
Logger().write('Please create the validset by sevenn_graph_build with -l\n')
Logger().write('The validset is loaded; load_dataset is now the train_set\n')
Logger().write('The ratio will be ignored\n')
# condition: validset modalities should be subset of trainset modalities
if config[KEY.USE_MODALITY]:
config_modality = config[KEY.MODAL_LIST]
valid_modality = valid_set.get_modalities()
if set(valid_modality).issubset(set(config_modality)) is False:
raise ValueError('validset modality is not subset of trainset')
valid_set.write_modal_attr(
config[KEY.MODAL_MAP],
config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
)
else:
train_set, valid_set, test_set = dataset.divide_dataset(
config[KEY.RATIO], ignore_test=ignore_test
)
log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n')
log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
log.format_k_v('\nloaded validset size is', valid_set.len(), write=True)
log.write('Dataset initialization was successful\n')
log.write('\nNumber of atoms in the train_set:\n')
log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))
log.bar()
log.write('Per atom energy(eV/atom) distribution:\n')
log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
log.bar()
log.write('Force(eV/Angstrom) distribution:\n')
log.statistic_write(train_set.get_statistics(KEY.FORCE))
log.bar()
log.write('Stress(eV/Angstrom^3) distribution:\n')
try:
log.statistic_write(train_set.get_statistics(KEY.STRESS))
except KeyError:
log.write('\n Stress is not included in the train_set\n')
if is_stress:
is_stress = False
log.write('Turn off stress training\n')
log.bar()
# saved data must have atomic numbers as X not one hot idx
if config[KEY.SAVE_BY_TRAIN_VALID]:
train_set.save(prefix + 'train')
valid_set.save(prefix + 'valid')
log.format_k_v('Dataset saved by train, valid', prefix, write=True)
# inconsistent .info dict give error when collate
_, _ = train_set.separate_info()
_, _ = valid_set.separate_info()
if train_set.x_is_one_hot_idx is False:
train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
if valid_set.x_is_one_hot_idx is False:
valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
log.format_k_v('training_set size', train_set.len(), write=True)
log.format_k_v('validation_set size', valid_set.len(), write=True)
shift, scale, conv_denominator = handle_shift_scale(
config, train_set, checkpoint_given
)
config.update(
{
KEY.SHIFT: shift,
KEY.SCALE: scale,
KEY.CONV_DENOMINATOR: conv_denominator,
}
)
data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())
return data_lists
import os
from copy import deepcopy
from typing import Optional
import torch
from torch.utils.data.distributed import DistributedSampler
import sevenn._keys as KEY
from sevenn.error_recorder import ErrorRecorder
from sevenn.logger import Logger
from sevenn.train.trainer import Trainer
def processing_epoch_v2(
config: dict,
trainer: Trainer,
loaders: dict, # dict[str, Dataset]
start_epoch: int = 1,
train_loader_key: str = 'trainset',
error_recorder: Optional[ErrorRecorder] = None,
total_epoch: Optional[int] = None,
per_epoch: Optional[int] = None,
best_metric_loader_key: str = 'validset',
best_metric: Optional[str] = None,
write_csv: bool = True,
working_dir: Optional[str] = None,
):
from sevenn.util import unique_filepath
log = Logger()
write_csv = write_csv and log.rank == 0
working_dir = working_dir or os.getcwd()
prefix = f'{os.path.abspath(working_dir)}/'
total_epoch = total_epoch or config[KEY.EPOCH]
per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
recorder = error_recorder or ErrorRecorder.from_config(
config, trainer.loss_functions
)
recorders = {k: deepcopy(recorder) for k in loaders}
best_val = float('inf')
best_key = None
if best_metric_loader_key in recorders:
best_key = recorders[best_metric_loader_key].get_key_str(best_metric)
if best_key is None:
log.writeline(
f'Failed to get error recorder key: {best_metric} or '
+ f'{best_metric_loader_key} is missing. There will be no best '
+ 'checkpoint.'
)
csv_path = unique_filepath(f'{prefix}/lc.csv')
if write_csv:
head = ['epoch', 'lr']
for k, rec in recorders.items():
head.extend(list(rec.get_dct(prefix=k)))
with open(csv_path, 'w') as f:
f.write(','.join(head) + '\n')
if start_epoch == 1:
path = f'{prefix}/checkpoint_0.pth' # save first epoch
trainer.write_checkpoint(path, config=config, epoch=0)
for epoch in range(start_epoch, total_epoch + 1): # one indexing
log.timer_start('epoch')
lr = trainer.get_lr()
log.bar()
log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n')
log.bar()
csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'}
errors = {}
for k, loader in loaders.items():
is_train = k == train_loader_key
if (
trainer.distributed
and isinstance(loader.sampler, DistributedSampler)
and is_train
and config.get('train_shuffle', True)
):
loader.sampler.set_epoch(epoch)
rec = recorders[k]
trainer.run_one_epoch(loader, is_train, rec)
csv_dct.update(rec.get_dct(prefix=k))
errors[k] = rec.epoch_forward()
log.write_full_table(list(errors.values()), list(errors))
trainer.scheduler_step(best_val)
if write_csv:
with open(csv_path, 'a') as f:
f.write(','.join(list(csv_dct.values())) + '\n')
if best_key and errors[best_metric_loader_key][best_key] < best_val:
path = f'{prefix}/checkpoint_best.pth'
trainer.write_checkpoint(path, config=config, epoch=epoch)
best_val = errors[best_metric_loader_key][best_key]
log.writeline('Best checkpoint written')
if epoch % per_epoch == 0:
path = f'{prefix}/checkpoint_{epoch}.pth'
trainer.write_checkpoint(path, config=config, epoch=epoch)
log.timer_end('epoch', message=f'Epoch {epoch} elapsed')
return trainer
def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir):
log = Logger()
prefix = f'{os.path.abspath(working_dir)}/'
train_loader, valid_loader = loaders
is_distributed = config[KEY.IS_DDP]
rank = config[KEY.RANK]
total_epoch = config[KEY.EPOCH]
per_epoch = config[KEY.PER_EPOCH]
train_recorder = ErrorRecorder.from_config(config)
valid_recorder = ErrorRecorder.from_config(config)
best_metric = config[KEY.BEST_METRIC]
csv_fname = f'{prefix}{config[KEY.CSV_LOG]}'
current_best = float('inf')
if init_csv:
csv_header = ['Epoch', 'Learning_rate']
# Assume train valid have the same metrics
for metric in train_recorder.get_metric_dict().keys():
csv_header.append(f'Train_{metric}')
csv_header.append(f'Valid_{metric}')
log.init_csv(csv_fname, csv_header)
def write_checkpoint(epoch, is_best=False):
if is_distributed and rank != 0:
return
suffix = '_best' if is_best else f'_{epoch}'
checkpoint = trainer.get_checkpoint_dict()
checkpoint.update({'config': config, 'epoch': epoch})
torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth')
fin_epoch = total_epoch + start_epoch
for epoch in range(start_epoch, fin_epoch):
lr = trainer.get_lr()
log.timer_start('epoch')
log.bar()
log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n')
log.bar()
trainer.run_one_epoch(
train_loader, is_train=True, error_recorder=train_recorder
)
train_err = train_recorder.epoch_forward()
trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder)
valid_err = valid_recorder.epoch_forward()
csv_values = [epoch, lr]
for metric in train_err:
csv_values.append(train_err[metric])
csv_values.append(valid_err[metric])
log.append_csv(csv_fname, csv_values)
log.write_full_table([train_err, valid_err], ['Train', 'Valid'])
val = None
for metric in valid_err:
# loose string comparison,
# e.g. "Energy" in "TotalEnergy" or "Energy_Loss"
if best_metric in metric:
val = valid_err[metric]
break
assert val is not None, f'Metric {best_metric} not found in {valid_err}'
trainer.scheduler_step(val)
log.timer_end('epoch', message=f'Epoch {epoch} elapsed')
if val < current_best:
current_best = val
write_checkpoint(epoch, is_best=True)
log.writeline('Best checkpoint written')
if epoch % per_epoch == 0:
write_checkpoint(epoch)
from typing import List, Optional
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.model_build import build_E3_equivariant_model
from sevenn.scripts.processing_continue import (
convert_modality_of_checkpoint_state_dct,
)
from sevenn.train.trainer import Trainer
def loader_from_config(config, dataset, is_train=False):
batch_size = config[KEY.BATCH_SIZE]
shuffle = is_train and config[KEY.TRAIN_SHUFFLE]
sampler = None
loader_args = {
'dataset': dataset,
'batch_size': batch_size,
'shuffle': shuffle
}
if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0:
loader_args.update({'num_workers': config[KEY.NUM_WORKERS]})
if config[KEY.IS_DDP]:
dist.barrier()
sampler = DistributedSampler(
dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle
)
loader_args.update({'sampler': sampler})
loader_args.pop('shuffle') # sampler is mutually exclusive with shuffle
return DataLoader(**loader_args)
def train_v2(config, working_dir: str):
"""
Main program flow, since v0.9.6
"""
import sevenn.train.atoms_dataset as atoms_dataset
import sevenn.train.graph_dataset as graph_dataset
import sevenn.train.modal_dataset as modal_dataset
from .processing_continue import processing_continue_v2
from .processing_epoch import processing_epoch_v2
log = Logger()
log.timer_start('total')
if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config:
log.writeline('***************************************************')
log.writeline('For train_v2, please use load_trainset_path instead')
log.writeline('I will assign load_trainset as load_dataset')
log.writeline('***************************************************')
config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET)
# config updated
start_epoch = 1
state_dicts: Optional[List[dict]] = None
if config[KEY.CONTINUE][KEY.CHECKPOINT]:
state_dicts, start_epoch = processing_continue_v2(config)
if config.get(KEY.USE_MODALITY, False):
datasets = modal_dataset.from_config(config, working_dir)
elif config[KEY.DATASET_TYPE] == 'graph':
datasets = graph_dataset.from_config(config, working_dir)
elif config[KEY.DATASET_TYPE] == 'atoms':
datasets = atoms_dataset.from_config(config, working_dir)
else:
raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}')
loaders = {
k: loader_from_config(config, v, is_train=(k == 'trainset'))
for k, v in datasets.items()
}
log.write('\nModel building...\n')
model = build_E3_equivariant_model(config)
log.print_model_info(model, config)
trainer = Trainer.from_config(model, config)
if state_dicts:
trainer.load_state_dicts(*state_dicts, strict=False)
processing_epoch_v2(
config, trainer, loaders, start_epoch, working_dir=working_dir
)
log.timer_end('total', message='Total wall time')
def train(config, working_dir: str):
"""
Main program flow, until v0.9.5
"""
from .processing_continue import processing_continue
from .processing_dataset import processing_dataset
from .processing_epoch import processing_epoch
log = Logger()
log.timer_start('total')
# config updated
state_dicts: Optional[List[dict]] = None
if config[KEY.CONTINUE][KEY.CHECKPOINT]:
state_dicts, start_epoch, init_csv = processing_continue(config)
else:
start_epoch, init_csv = 1, True
# config updated
train, valid, _ = processing_dataset(config, working_dir)
datasets = {'dataset': train, 'validset': valid}
loaders = {
k: loader_from_config(config, v, is_train=(k == 'dataset'))
for k, v in datasets.items()
}
loaders = list(loaders.values())
log.write('\nModel building...\n')
model = build_E3_equivariant_model(config)
log.write('Model building was successful\n')
trainer = Trainer.from_config(model, config)
if state_dicts:
state_dicts = convert_modality_of_checkpoint_state_dct(
config, state_dicts
)
trainer.load_state_dicts(*state_dicts, strict=False)
log.print_model_info(model, config)
log.write('Trainer initialized, ready to train\n')
log.bar()
processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir)
log.timer_end('total', message='Total wall time')
import os
import random
import warnings
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch.utils.data
from ase.atoms import Atoms
from ase.data import chemical_symbols
from ase.io import write
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide an exact avg_num_neigh
as it does not build graphs. We will build graphs for only 10000 randomly sampled
structures to approximate this value. If you want a more precise avg_num_neigh,
use SevenNetGraphDataset. If that is not viable due to memory limits, an online
algorithm is needed, which is not yet implemented in SevenNet"""
class SevenNetAtomsDataset(torch.utils.data.Dataset):
"""
Args:
cutoff: edge cutoff of given AtomGraphData
files: list of filenames or a dict describing how to parse the files:
ASE readable (with a proper extension), structure_list, .sevenn_data,
or a dict containing file_list (see dict_reader of train/dataload.py)
info_dict_copy_keys: patch these keys from KEY.INFO to the graph when accessing.
Defaults are KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may be accessed
while training.
**process_kwargs: keyword arguments that will be passed into ase.io.read
"""
def __init__(
self,
cutoff: float,
files: Union[str, List[str]],
atoms_filter: Optional[Callable] = None,
atoms_transform: Optional[Callable] = None,
transform: Optional[Callable] = None,
use_data_weight: bool = False,
**process_kwargs,
):
self.cutoff = cutoff
if isinstance(files, str):
files = [files] # user convenience
files = [os.path.abspath(file) for file in files]
self._files = files
self.atoms_filter = atoms_filter
self.atoms_transform = atoms_transform
self.transform = transform
self.use_data_weight = use_data_weight
self._scanned = False
self._avg_num_neigh_approx = None
self.statistics = {}
atoms_list = []
for file in files:
atoms_list.extend(
SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs)
)
self._atoms_list = atoms_list
super().__init__()
@staticmethod
def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]:
if isinstance(file, dict):
atoms_list = dataload.dict_reader(file)
elif 'structure_list' in file:
atoms_dct = dataload.structure_list_reader(file)
atoms_list = []
for lst in atoms_dct.values():
atoms_list.extend(lst)
else:
atoms_list = dataload.ase_reader(file, **kwargs)
return atoms_list
def save(self, path):
# Save atoms list as extxyz
write(path, self._atoms_list, format='extxyz')
def _graph_build(self, atoms):
return dataload.atoms_to_graph(
atoms, self.cutoff, transfer_info=False, y_from_calc=False
)
def __len__(self):
return len(self._atoms_list)
def __getitem__(self, index):
atoms = self._atoms_list[index]
if self.atoms_transform is not None:
atoms = self.atoms_transform(atoms)
graph = self._graph_build(atoms)
if self.transform is not None:
graph = self.transform(graph)
if self.use_data_weight:
weight = graph[KEY.INFO].pop(
KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
)
graph[KEY.DATA_WEIGHT] = weight
return AtomGraphData.from_numpy_dict(graph)
@property
def species(self):
self.run_stat()
return [z for z in self.statistics['_natoms'].keys() if z != 'total']
@property
def natoms(self):
self.run_stat()
return self.statistics['_natoms']
@property
def per_atom_energy_mean(self):
self.run_stat()
return self.statistics[KEY.PER_ATOM_ENERGY]['mean']
@property
    def elemwise_reference_energies(self):
        from sklearn.linear_model import Ridge
        self.run_stat()  # statistics must be computed before they are read
        c = self.statistics['_composition']
y = self.statistics[KEY.ENERGY]['_array']
zero_indices = np.all(c == 0, axis=0)
c_reduced = c[:, ~zero_indices]
# will not 100% reproduce, as it is sorted by Z
# train/dataset.py was sorted by alphabets of chemical species
coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
full_coeff = np.zeros(NUM_UNIV_ELEMENT)
full_coeff[~zero_indices] = coef_reduced
return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy
@property
def force_rms(self):
self.run_stat()
mean = self.statistics[KEY.FORCE]['mean']
std = self.statistics[KEY.FORCE]['std']
return float((mean**2 + std**2) ** (0.5))
@property
def per_atom_energy_std(self):
self.run_stat()
return self.statistics['per_atom_energy']['std']
@property
def avg_num_neigh(self, n_sample=10000):
if self._avg_num_neigh_approx is None:
if len(self) > n_sample:
warnings.warn(_warn_avg_num_neigh)
n_sample = min(len(self), n_sample)
indices = random.sample(range(len(self)), n_sample)
n_neigh = []
for i in indices:
graph = self[i]
_, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True)
n_neigh.append(nn)
n_neigh = np.concatenate(n_neigh)
self._avg_num_neigh_approx = np.mean(n_neigh)
return self._avg_num_neigh_approx
@property
def sqrt_avg_num_neigh(self):
self.run_stat()
return self.avg_num_neigh**0.5
def run_stat(self):
"""
Loop over dataset and init any statistics might need
Unlink SevenNetGraphDataset, neighbors count is not computed as
it requires to build graph
"""
if self._scanned is True:
return # statistics already computed
y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
natoms_counter = Counter()
composition = np.zeros((len(self), NUM_UNIV_ELEMENT))
stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys}
for i, atoms in tqdm(
enumerate(self._atoms_list), desc='run_stat', total=len(self)
):
z = atoms.get_atomic_numbers()
natoms_counter.update(z.tolist())
composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT)
for y, dct in stats.items():
if y == KEY.ENERGY:
dct['_array'].append(atoms.info['y_energy'])
elif y == KEY.PER_ATOM_ENERGY:
dct['_array'].append(atoms.info['y_energy'] / len(atoms))
elif y == KEY.FORCE:
dct['_array'].append(atoms.arrays['y_force'].reshape(-1))
elif y == KEY.STRESS:
dct['_array'].append(atoms.info['y_stress'].reshape(-1))
for y, dct in stats.items():
if y == KEY.FORCE:
array = np.concatenate(dct['_array'])
else:
array = np.array(dct['_array']).reshape(-1)
dct.update(
{
'mean': float(np.mean(array)),
'std': float(np.std(array)),
'median': float(np.quantile(array, q=0.5)),
'max': float(np.max(array)),
'min': float(np.min(array)),
'_array': array,
}
)
natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
natoms['total'] = sum(list(natoms.values()))
self.statistics.update(
{
'_composition': composition,
'_natoms': natoms,
**stats,
}
)
self._scanned = True
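# A minimal usage sketch (illustrative, not part of the library); the file path
# and cutoff below are placeholder values:
#
#     dataset = SevenNetAtomsDataset(cutoff=5.0, files='my_data.extxyz')
#     graph = dataset[0]          # graph is built on the fly at access time
#     dataset.run_stat()          # computes energy/force/stress statistics
#     print(dataset.natoms, dataset.per_atom_energy_mean)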
# script, return dict of SevenNetAtomsDataset
def from_config(
config: Dict[str, Any],
working_dir: str = os.getcwd(),
dataset_keys: Optional[List[str]] = None,
):
from sevenn.logger import Logger
log = Logger()
if dataset_keys is None:
dataset_keys = []
for k in config:
if k.startswith('load_') and k.endswith('_path'):
dataset_keys.append(k)
if KEY.LOAD_TRAINSET not in dataset_keys:
raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')
# initialize arguments for loading dataset
dataset_args = {
'cutoff': config[KEY.CUTOFF],
'use_data_weight': config.get(KEY.USE_WEIGHT, False),
**config[KEY.DATA_FORMAT_ARGS],
}
datasets = {}
for dk in dataset_keys:
if not (paths := config[dk]):
continue
if isinstance(paths, str):
paths = [paths]
name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
dataset_args.update({'files': paths})
datasets[name] = SevenNetAtomsDataset(**dataset_args)
if not config[KEY.COMPUTE_STATISTICS]:
        log.writeline(
            (
                'Computing statistics is skipped. Note that if any other '
                'configuration requires statistics (shift, scale, avg_num_neigh, '
                'chemical_species as auto), SevenNet will eventually raise an error!'
            )
        )
return datasets
train_set = datasets['trainset']
chem_species = set(train_set.species)
# print statistics of each dataset
for name, dataset in datasets.items():
dataset.run_stat()
log.bar()
log.writeline(f'{name} distribution:')
log.statistic_write(dataset.statistics)
log.format_k_v('# atoms (node)', dataset.natoms, write=True)
log.format_k_v('# structures (graph)', len(dataset), write=True)
chem_species.update(dataset.species)
log.bar()
# initialize known species from dataset if 'auto'
# sorted to alphabetical order (which is same as before)
chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py
log.writeline('Known species are obtained from the dataset')
config.update(util.chemical_species_preprocess(sorted(list(chem_species))))
    # retrieve shift, scale, conv_denominators from user input (keyword)
init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
for k in init_from_stats:
input = config[k] # statistic key or numbers
# If it is not 'str', 1: It is 'continue' training
# 2: User manually inserted numbers
if isinstance(input, str) and hasattr(train_set, input):
var = getattr(train_set, input)
config.update({k: var})
log.writeline(f'{k} is obtained from statistics')
elif isinstance(input, str) and not hasattr(train_set, input):
raise NotImplementedError(input)
return datasets
from typing import Any, List, Optional, Sequence
from ase.atoms import Atoms
from torch_geometric.loader.dataloader import Collater
from sevenn.atom_graph_data import AtomGraphData
from .dataload import atoms_to_graph
class AtomsToGraphCollater(Collater):
def __init__(
self,
dataset: Sequence[Atoms],
cutoff: float,
transfer_info: bool = False,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
y_from_calc: bool = True,
):
        # pass [] to quiet the original Collater's dataset type mismatch
super().__init__([], follow_batch, exclude_keys)
self.dataset = dataset
self.cutoff = cutoff
self.transfer_info = transfer_info
self.y_from_calc = y_from_calc
def __call__(self, batch: List[Any]) -> Any:
# build list of graph
graph_list = []
for stct in batch:
graph = atoms_to_graph(
stct,
self.cutoff,
transfer_info=self.transfer_info,
y_from_calc=self.y_from_calc,
)
graph = AtomGraphData.from_numpy_dict(graph)
graph_list.append(graph)
return super().__call__(graph_list)
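# A minimal usage sketch (illustrative, not part of the library): the collater can
# be handed to a torch DataLoader so that graphs are built batch-by-batch from a
# plain list of ase.Atoms; 'atoms_list' and the cutoff are placeholders.
#
#     from torch.utils.data import DataLoader
#     collater = AtomsToGraphCollater(atoms_list, cutoff=5.0, y_from_calc=True)
#     loader = DataLoader(atoms_list, batch_size=8, collate_fn=collater)
#     for batch in loader:
#         ...  # batch is a torch_geometric Batch of AtomGraphData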
import copy
import os.path
from functools import partial
from itertools import chain, islice
from typing import Callable, Dict, List, Optional
import ase
import ase.io
import numpy as np
import torch.multiprocessing as mp
from ase.io.vasp_parsers.vasp_outcar_parsers import (
Cell,
DefaultParsersContainer,
Energy,
OutcarChunkParser,
PositionsAndForces,
Stress,
outcarchunks,
)
from ase.neighborlist import primitive_neighbor_list
from ase.utils import string2index
from braceexpand import braceexpand
from tqdm import tqdm
import sevenn._keys as KEY
from sevenn._const import LossType
from sevenn.atom_graph_data import AtomGraphData
from .dataset import AtomGraphDataset
def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
pbc_x = pbc[0]
pbc_y = pbc[1]
pbc_z = pbc[2]
identity = np.identity(3, dtype=float)
max_positions = np.max(np.absolute(pos)) + 1
# Extend cell in non-periodic directions
# For models with more than 5 layers,
# the multiplicative constant needs to be increased.
if not pbc_x:
cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
if not pbc_y:
cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
if not pbc_z:
cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
# it does not have self-interaction
edge_src, edge_dst, edge_vec, shifts = neighbour_list(
quantities='ijDS',
pbc=pbc,
cell=cell,
positions=pos,
cutoff=cutoff,
)
# dtype issue
edge_src = edge_src.astype(np.int64)
edge_dst = edge_dst.astype(np.int64)
return edge_src, edge_dst, edge_vec, shifts
def _graph_build_ase(cutoff: float, pbc, cell, pos):
# building neighbor list
edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
'ijDS', pbc, cell, pos, cutoff, self_interaction=True
)
is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])
edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]
return edge_src, edge_dst, edge_vec, shifts
_graph_build_f = _graph_build_ase
try:
from matscipy.neighbours import neighbour_list
_graph_build_f = _graph_build_matscipy
except ImportError:
pass
def _correct_scalar(v):
if isinstance(v, np.ndarray):
v = v.squeeze()
assert v.ndim == 0, f'given {v} is not a scalar'
return v
elif isinstance(v, (int, float, np.integer, np.floating)):
return np.array(v)
else:
assert False, f'{type(v)} is not expected'
def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()
edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
edge_idx = np.array([edge_src, edge_dst])
atomic_numbers = atoms.get_atomic_numbers()
cell = np.array(cell)
vol = _correct_scalar(atoms.cell.volume)
if vol == 0:
vol = np.array(np.finfo(float).eps)
data = {
KEY.NODE_FEATURE: atomic_numbers,
KEY.ATOMIC_NUMBERS: atomic_numbers,
KEY.POS: pos,
KEY.EDGE_IDX: edge_idx,
KEY.EDGE_VEC: edge_vec,
KEY.CELL: cell,
KEY.CELL_SHIFT: shifts,
KEY.CELL_VOLUME: vol,
KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
}
data[KEY.INFO] = {}
return data
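# A short sketch (illustrative, not part of the library): unlabeled graphs are what
# inference-only code needs, since no y_* labels are required.
#
#     from ase.build import bulk
#     graph_dict = unlabeled_atoms_to_graph(bulk('Si'), cutoff=4.0)
#     data = AtomGraphData.from_numpy_dict(graph_dict)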
def atoms_to_graph(
atoms: ase.Atoms,
cutoff: float,
transfer_info: bool = True,
y_from_calc: bool = False,
allow_unlabeled: bool = False,
):
"""
From ase atoms, return AtomGraphData as graph based on cutoff radius
Except for energy, force and stress labels must be numpy array type
as other cases are not tested.
Returns 'np.nan' with consistent shape for unlabeled data
(ex. stress of non-pbc system)
Args:
atoms (Atoms): ase atoms
cutoff (float): cutoff radius
transfer_info (bool): if True, transfer ".info" from atoms to graph,
defaults to True
y_from_calc: if True, get ref values from calculator, defaults to False
Returns:
numpy dict that can be used to initialize AtomGraphData
by AtomGraphData(**atoms_to_graph(atoms, cutoff))
, for scalar, its shape is (), and types are np.ndarray
Requires grad is handled by 'dataset' not here.
"""
if not y_from_calc:
y_energy = atoms.info['y_energy']
y_force = atoms.arrays['y_force']
y_stress = atoms.info.get('y_stress', np.full((6,), np.nan))
if y_stress.shape == (3, 3):
y_stress = np.array(
[
y_stress[0][0],
y_stress[1][1],
y_stress[2][2],
y_stress[0][1],
y_stress[1][2],
y_stress[2][0],
]
)
else:
y_stress = y_stress.squeeze()
else:
from_calc = _y_from_calc(atoms)
y_energy = from_calc['energy']
y_force = from_calc['force']
y_stress = from_calc['stress']
    assert y_stress.shape == (6,), 'If you see this, please raise an issue'
if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()):
raise ValueError('Unlabeled E or F found, set allow_unlabeled True')
pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()
edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
edge_idx = np.array([edge_src, edge_dst])
atomic_numbers = atoms.get_atomic_numbers()
cell = np.array(cell)
vol = _correct_scalar(atoms.cell.volume)
if vol == 0:
vol = np.array(np.finfo(float).eps)
data = {
KEY.NODE_FEATURE: atomic_numbers,
KEY.ATOMIC_NUMBERS: atomic_numbers,
KEY.POS: pos,
KEY.EDGE_IDX: edge_idx,
KEY.EDGE_VEC: edge_vec,
KEY.ENERGY: _correct_scalar(y_energy),
KEY.FORCE: y_force,
KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6)
KEY.CELL: cell,
KEY.CELL_SHIFT: shifts,
KEY.CELL_VOLUME: vol,
KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)),
}
if transfer_info and atoms.info is not None:
info = copy.deepcopy(atoms.info)
# save only metadata
info.pop('y_energy', None)
info.pop('y_force', None)
info.pop('y_stress', None)
data[KEY.INFO] = info
else:
data[KEY.INFO] = {}
return data
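# A minimal sketch (illustrative, not part of the library): labels must already be
# present under 'y_energy' / 'y_force' (see _set_atoms_y below), or y_from_calc=True
# can be used when the Atoms object carries a calculator with results.
#
#     atoms = _set_atoms_y([some_labeled_atoms])[0]   # 'some_labeled_atoms' is a placeholder
#     graph = AtomGraphData.from_numpy_dict(atoms_to_graph(atoms, cutoff=5.0))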
def graph_build(
atoms_list: List,
cutoff: float,
num_cores: int = 1,
transfer_info: bool = True,
y_from_calc: bool = False,
allow_unlabeled: bool = False,
) -> List[AtomGraphData]:
"""
parallel version of graph_build
build graph from atoms_list and return list of AtomGraphData
Args:
atoms_list (List): list of ASE atoms
cutoff (float): cutoff radius of graph
num_cores (int): number of cores to use
transfer_info (bool): if True, copy info from atoms to graph,
defaults to True
y_from_calc (bool): Get reference y labels from calculator, defaults to False
Returns:
List[AtomGraphData]: list of AtomGraphData
"""
serial = num_cores == 1
inputs = [
(atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled)
for atoms in atoms_list
]
if not serial:
pool = mp.Pool(num_cores)
graph_list = pool.starmap(
atoms_to_graph,
tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'),
)
pool.close()
pool.join()
else:
graph_list = [
atoms_to_graph(*input_)
for input_ in tqdm(inputs, desc='graph_build (1)')
]
graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list]
return graph_list
def _y_from_calc(atoms: ase.Atoms):
ret = {
'energy': np.nan,
'force': np.full((len(atoms), 3), np.nan),
'stress': np.full((6,), np.nan),
}
if atoms.calc is None:
return ret
try:
ret['energy'] = atoms.get_potential_energy(force_consistent=True)
except NotImplementedError:
ret['energy'] = atoms.get_potential_energy()
try:
ret['force'] = atoms.get_forces(apply_constraint=False)
except NotImplementedError:
pass
try:
y_stress = -1 * atoms.get_stress() # it ensures correct shape
ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
except RuntimeError:
pass
return ret
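# Note on the index reordering above: ase.Atoms.get_stress() returns Voigt order
# (xx, yy, zz, yz, xz, xy); taking indices [0, 1, 2, 5, 3, 4] converts it to the
# (xx, yy, zz, xy, yz, zx) order used in atoms_to_graph, and the sign is flipped,
# consistent with how stress_key values are handled in _set_atoms_y below.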
def _set_atoms_y(
atoms_list: List[ase.Atoms],
energy_key: Optional[str] = None,
force_key: Optional[str] = None,
stress_key: Optional[str] = None,
) -> List[ase.Atoms]:
"""
Define how SevenNet reads ASE.atoms object for its y label
If energy_key, force_key, or stress_key is given, the corresponding
label is obtained from .info dict of Atoms object. These values should
have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress,
respectively. (stress in Voigt notation)
Args:
atoms_list (list[ase.Atoms]): target atoms to set y_labels
energy_key (str, optional): key to get energy. Defaults to None.
force_key (str, optional): key to get force. Defaults to None.
stress_key (str, optional): key to get stress. Defaults to None.
Returns:
list[ase.Atoms]: list of ase.Atoms
Raises:
RuntimeError: if ase atoms are somewhat imperfect
Use free_energy: atoms.get_potential_energy(force_consistent=True)
If it is not available, use atoms.get_potential_energy()
If stress is available, initialize stress tensor
Ignore constraints like selective dynamics
"""
for atoms in atoms_list:
from_calc = _y_from_calc(atoms)
if energy_key is not None:
atoms.info['y_energy'] = atoms.info.pop(energy_key)
else:
atoms.info['y_energy'] = from_calc['energy']
if force_key is not None:
atoms.arrays['y_force'] = atoms.arrays.pop(force_key)
else:
atoms.arrays['y_force'] = from_calc['force']
if stress_key is not None:
y_stress = -1 * atoms.info.pop(stress_key)
atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
else:
atoms.info['y_stress'] = from_calc['stress']
return atoms_list
def ase_reader(
filename: str,
energy_key: Optional[str] = None,
force_key: Optional[str] = None,
stress_key: Optional[str] = None,
index: str = ':',
**kwargs,
) -> List[ase.Atoms]:
"""
Wrapper of ase.io.read
"""
atoms_list = ase.io.read(filename, index=index, **kwargs)
if not isinstance(atoms_list, list):
atoms_list = [atoms_list]
return _set_atoms_y(atoms_list, energy_key, force_key, stress_key)
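# A minimal sketch (illustrative, not part of the library): reading an extxyz file
# whose reference labels are stored under custom keys. The file name and key names
# below are placeholders.
#
#     atoms_list = ase_reader(
#         'dft_data.extxyz',
#         energy_key='ref_energy',   # read from atoms.info['ref_energy']
#         force_key='ref_forces',    # read from atoms.arrays['ref_forces']
#     )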
# Reader
def structure_list_reader(filename: str, format_outputs: Optional[str] = None):
"""
Read from structure_list using braceexpand and ASE
Args:
fname : filename of structure_list
Returns:
dictionary of lists of ASE structures.
key is title of training data (user-define)
"""
parsers = DefaultParsersContainer(
PositionsAndForces, Stress, Energy, Cell
).make_parsers()
ocp = OutcarChunkParser(parsers=parsers)
def parse_label(line):
line = line.strip()
if line.startswith('[') is False:
return False
elif line.endswith(']') is False:
raise ValueError('wrong structure_list title format')
return line[1:-1]
def parse_fileline(line):
line = line.strip().split()
if len(line) == 1:
line.append(':')
elif len(line) != 2:
raise ValueError('wrong structure_list format')
return line[0], line[1]
structure_list_file = open(filename, 'r')
lines = structure_list_file.readlines()
raw_str_dict = {}
label = 'Default'
for line in lines:
if line.strip() == '':
continue
tmp_label = parse_label(line)
if tmp_label:
label = tmp_label
raw_str_dict[label] = []
continue
elif label in raw_str_dict:
files_expr, index_expr = parse_fileline(line)
raw_str_dict[label].append((files_expr, index_expr))
else:
raise ValueError('wrong structure_list format')
structure_list_file.close()
structures_dict = {}
info_dct = {'data_from': 'user_OUTCAR'}
for title, file_lines in raw_str_dict.items():
stct_lists = []
for file_line in file_lines:
files_expr, index_expr = file_line
index = string2index(index_expr)
            for expanded_filename in list(braceexpand(files_expr)):
                if format_outputs is None:
                    # if no format is given, the file is assumed to be a VASP OUTCAR;
                    # otherwise it is read with ase.io.read using that format
                    f_stream = open(expanded_filename, 'r')
                    # generator of all outcar ionic steps
                    gen_all = outcarchunks(f_stream, ocp)
                    try:  # TODO: index may not slice, it can be integer
                        it_atoms = islice(
                            gen_all, index.start, index.stop, index.step
                        )
                    except ValueError:
                        # TODO: support negative index
                        raise ValueError('Negative index is not supported yet')
                    info_dct_f = {
                        **info_dct,
                        'file': os.path.abspath(expanded_filename),
                    }
                    for idx, o in enumerate(it_atoms):
                        try:
                            istep = index.start + idx * index.step  # type: ignore
                            atoms = o.build()
                            atoms.info = {**info_dct_f, 'ionic_step': istep}.copy()
                        except TypeError:  # it is not slice of ionic steps
                            atoms = o.build()
                            atoms.info = info_dct_f.copy()
                        stct_lists.append(atoms)
                    f_stream.close()
                else:
                    stct_lists += ase.io.read(
                        expanded_filename,
                        index=index_expr,
                        format=format_outputs,
                        parallel=False,
                    )
structures_dict[title] = stct_lists
return {k: _set_atoms_y(v) for k, v in structures_dict.items()}
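# A sketch of the structure_list format this reader expects (illustrative paths):
#
#     [bulk_Si]
#     /path/to/OUTCAR_{1..3} ::10
#     [surface]
#     /path/to/other/OUTCAR
#
# Bracketed lines start a new user-defined label; each following line gives a file
# expression (brace-expandable) and an optional ASE-style index slice (default ':').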
def dict_reader(data_dict: Dict):
data_dict_cp = copy.deepcopy(data_dict)
ret = []
file_list = data_dict_cp.pop('file_list', None)
if file_list is None:
raise KeyError('file_list is not found')
data_weight_default = {
'energy': 1.0,
'force': 1.0,
'stress': 1.0,
}
data_weight = data_weight_default.copy()
data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))
for file_dct in file_list:
ftype = file_dct.pop('data_format', 'ase')
files = list(braceexpand(file_dct.pop('file')))
if ftype == 'ase':
ret.extend(chain(*[ase_reader(f, **file_dct) for f in files]))
elif ftype == 'graph':
continue
else:
            raise ValueError(f'data_format {ftype} is not supported yet')
for atoms in ret:
atoms.info.update(data_dict_cp)
atoms.info.update({KEY.DATA_WEIGHT: data_weight})
return _set_atoms_y(ret)
def match_reader(reader_name: str, **kwargs):
reader = None
metadata = {}
if reader_name == 'structure_list':
reader = partial(structure_list_reader, **kwargs)
metadata.update({'origin': 'structure_list'})
else:
reader = partial(ase_reader, **kwargs)
metadata.update({'origin': 'ase_reader'})
return reader, metadata
def file_to_dataset(
file: str,
cutoff: float,
cores: int = 1,
reader: Callable = ase_reader,
label: Optional[str] = None,
transfer_info: bool = True,
use_weight: bool = False,
use_modality: bool = False,
):
"""
Deprecated
Read file by reader > get list of atoms or dict of atoms
"""
# expect label: atoms_list dct or atoms or list of atoms
atoms = reader(file)
if type(atoms) is list:
if label is None:
label = KEY.LABEL_NONE
atoms_dct = {label: atoms}
elif isinstance(atoms, ase.Atoms):
if label is None:
label = KEY.LABEL_NONE
atoms_dct = {label: [atoms]}
elif isinstance(atoms, dict):
atoms_dct = atoms
else:
raise TypeError('The return of reader is not list or dict')
graph_dct = {}
for label, atoms_list in atoms_dct.items():
graph_list = graph_build(
atoms_list=atoms_list,
cutoff=cutoff,
num_cores=cores,
transfer_info=transfer_info,
y_from_calc=False,
)
label_info = label.split(':')
for graph in graph_list:
graph[KEY.USER_LABEL] = label_info[0].strip()
if use_weight:
find_weight = False
for info in label_info[1:]:
if 'w=' in info.lower():
weights = info.split('=')[1]
try:
if ',' in weights:
weight_list = list(map(float, weights.split(',')))
else:
weight_list = [float(weights)] * 3
weight_dict = {}
for idx, loss_type in enumerate(LossType):
weight_dict[loss_type.value] = (
weight_list[idx] if idx < len(weight_list) else 1
)
graph[KEY.DATA_WEIGHT] = weight_dict
find_weight = True
break
                        except ValueError:
raise ValueError(
'Weight must be a real number, but'
f' {weights} is given for {label}'
)
if not find_weight:
weight_dict = {}
for loss_type in LossType:
weight_dict[loss_type.value] = 1
graph[KEY.DATA_WEIGHT] = weight_dict
if use_modality:
find_modality = False
for info in label_info[1:]:
if 'm=' in info.lower():
graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip()
find_modality = True
break
if not find_modality:
raise ValueError(f'Modality not given for {label}')
graph_dct[label_info[0].strip()] = graph_list
db = AtomGraphDataset(graph_dct, cutoff)
return db
import itertools
import random
from collections import Counter
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from ase.data import chemical_symbols
from sklearn.linear_model import Ridge
import sevenn._keys as KEY
import sevenn.util as util
class AtomGraphDataset:
"""
Deprecated
class representing dataset of AtomGraphData
the dataset is handled as dict, {label: data}
if given data is List, it stores data as {KEY_DEFAULT: data}
cutoff is for metadata of the graphs not used for some calc
Every data expected to have one unique cutoff
No validity or check of the condition is done inside the object
attribute:
dataset (Dict[str, List]): key is data label(str), value is list of data
user_labels (List[str]): list of user labels same as dataset.keys()
meta (Dict, Optional): metadata of dataset
for now, metadata 'might' have following keys:
KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict)
"""
DATA_KEY_X = (
KEY.NODE_FEATURE
) # atomic_number > one_hot_idx > one_hot_vector
DATA_KEY_ENERGY = KEY.ENERGY
DATA_KEY_FORCE = KEY.FORCE
KEY_DEFAULT = KEY.LABEL_NONE
def __init__(
self,
dataset: Union[Dict[str, List], List],
cutoff: float,
metadata: Optional[Dict] = None,
x_is_one_hot_idx: bool = False,
):
"""
Default constructor of AtomGraphDataset
Args:
dataset (Union[Dict[str, List], List]: dataset as dict or pure list
metadata (Dict, Optional): metadata of data
cutoff (float): cutoff radius of graphs inside the dataset
x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z'
'x' (node feature) of dataset can have 3 states, atomic_numbers,
one_hot_idx, or one_hot_vector.
atomic_numbers is general but cannot directly used for input
one_hot_idx is can be input of the model but requires 'type_map'
"""
self.cutoff = cutoff
self.x_is_one_hot_idx = x_is_one_hot_idx
if metadata is None:
metadata = {KEY.CUTOFF: cutoff}
self.meta = metadata
if type(dataset) is list:
self.dataset = {self.KEY_DEFAULT: dataset}
else:
self.dataset = dataset
self.user_labels = list(self.dataset.keys())
# group_by_key here? or not?
def rewrite_labels_to_data(self):
"""
Based on self.dataset dict's keys
write data[KEY.USER_LABEL] to correspond to dict's keys
Most of times, it is already correctly written
But required to rewrite if someone rearrange dataset by their own way
"""
for label, data_list in self.dataset.items():
for data in data_list:
data[KEY.USER_LABEL] = label
def group_by_key(self, data_key: str = KEY.USER_LABEL):
"""
group dataset list by given key and save it as dict
and change in-place
Args:
data_key (str): data key to group by
original use is USER_LABEL, but it can be used for other keys
if someone established it from data[KEY.INFO]
"""
data_list = self.to_list()
self.dataset = {}
for datum in data_list:
key = datum[data_key]
if key not in self.dataset:
self.dataset[key] = []
self.dataset[key].append(datum)
self.user_labels = list(self.dataset.keys())
def separate_info(self, data_key: str = KEY.INFO):
"""
Separate info from data and save it as list of dict
to make it compatible with torch_geometric and later training
"""
data_list = self.to_list()
info_list = []
for datum in data_list:
            if data_key not in datum:
continue
info_list.append(datum[data_key])
del datum[data_key] # It does change the self.dataset
datum[data_key] = len(info_list) - 1
self.info_list = info_list
return (data_list, info_list)
def get_species(self):
"""
You can also use get_natoms and extract keys from there instead of this
(And it is more efficient)
get chemical species of dataset
return list of SORTED chemical species (as str)
"""
if hasattr(self, 'type_map'):
natoms = self.get_natoms(self.type_map)
else:
natoms = self.get_natoms()
species = set()
for natom_dct in natoms.values():
species.update(natom_dct.keys())
species = sorted(list(species))
return species
def get_modalities(self):
modalities = set()
for data_list in self.dataset.values():
datum = data_list[0].to_dict()
if KEY.DATA_MODALITY in datum.keys():
modalities.add(datum[KEY.DATA_MODALITY])
else:
return []
return list(modalities)
def write_modal_attr(
self, modal_type_mapper: dict, write_modal_type: bool = False
):
num_modalities = len(modal_type_mapper)
for data_list in self.dataset.values():
for data in data_list:
tmp_tensor = torch.zeros(num_modalities)
if data[KEY.DATA_MODALITY] != 'common':
modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]]
tmp_tensor[modal_idx] = 1.0
if write_modal_type:
data[KEY.MODAL_TYPE] = modal_idx
data[KEY.MODAL_ATTR] = tmp_tensor
def get_dict_sort_by_modality(self):
dict_sort_by_modality = {}
for data_list in self.dataset.values():
try:
modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY]
            except KeyError:  # Dataset is not modal
raise ValueError('This dataset has no modality.')
if modal_key not in dict_sort_by_modality.keys():
dict_sort_by_modality[modal_key] = []
dict_sort_by_modality[modal_key].extend(data_list)
return dict_sort_by_modality
def len(self):
if (
len(self.dataset.keys()) == 1
and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT
):
return len(self.dataset[AtomGraphDataset.KEY_DEFAULT])
else:
return {k: len(v) for k, v in self.dataset.items()}
def get(self, idx: int, key: Optional[str] = None):
if key is None:
key = self.KEY_DEFAULT
return self.dataset[key][idx]
def items(self):
return self.dataset.items()
def to_dict(self):
dct_dataset = {}
for label, data_list in self.dataset.items():
dct_dataset[label] = [datum.to_dict() for datum in data_list]
self.dataset = dct_dataset
return self
def x_to_one_hot_idx(self, type_map: Dict[int, int]):
"""
type_map is dict of {atomic_number: one_hot_idx}
after this process, the dataset has dependency on type_map
or chemical species user want to consider
"""
assert self.x_is_one_hot_idx is False
for data_list in self.dataset.values():
for datum in data_list:
datum[self.DATA_KEY_X] = torch.LongTensor(
[type_map[z.item()] for z in datum[self.DATA_KEY_X]]
)
self.type_map = type_map
self.x_is_one_hot_idx = True
def toggle_requires_grad_of_data(
self, key: str, requires_grad_value: bool
):
"""
set requires_grad of specific key of data(pos, edge_vec, ...)
"""
for data_list in self.dataset.values():
for datum in data_list:
datum[key].requires_grad_(requires_grad_value)
def divide_dataset(
self,
ratio: float,
constant_ratio_btw_labels: bool = True,
ignore_test: bool = True
):
"""
divide dataset into 1-2*ratio : ratio : ratio
return divided AtomGraphDataset
returned value lost its dict key and became {KEY_DEFAULT: datalist}
but KEY.USER_LABEL of each data is preserved
"""
def divide(ratio: float, data_list: List, ignore_test=True):
if ratio > 0.5:
raise ValueError('Ratio must not exceed 0.5')
data_len = len(data_list)
random.shuffle(data_list)
n_validation = int(data_len * ratio)
if n_validation == 0:
raise ValueError(
'# of validation set is 0, increase your dataset'
)
if ignore_test:
test_list = []
n_train = data_len - n_validation
train_list = data_list[0:n_train]
valid_list = data_list[n_train:]
else:
n_train = data_len - 2 * n_validation
train_list = data_list[0:n_train]
valid_list = data_list[n_train : n_train + n_validation]
test_list = data_list[n_train + n_validation : data_len]
return train_list, valid_list, test_list
lists = ([], [], []) # train, valid, test
if constant_ratio_btw_labels:
for data_list in self.dataset.values():
for store, divided in zip(lists, divide(ratio, data_list)):
store.extend(divided)
else:
lists = divide(ratio, self.to_list())
dbs = tuple(
AtomGraphDataset(data, self.cutoff, self.meta) for data in lists
)
for db in dbs:
db.group_by_key()
return dbs
def to_list(self):
return list(itertools.chain(*self.dataset.values()))
def get_natoms(self, type_map: Optional[Dict[int, int]] = None):
"""
if x_is_one_hot_idx, type_map is required
type_map: Z->one_hot_index(node_feature)
return Dict{label: {symbol, natom}]}
"""
assert not (self.x_is_one_hot_idx is True and type_map is None)
natoms = {}
for label, data in self.dataset.items():
natoms[label] = Counter()
for datum in data:
if self.x_is_one_hot_idx and type_map is not None:
Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map)
else:
Zs = [
chemical_symbols[z]
for z in datum[self.DATA_KEY_X].tolist()
]
cnt = Counter(Zs)
natoms[label] += cnt
natoms[label] = dict(natoms[label])
return natoms
def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS):
"""
return per_atom mean of given data key
"""
eng_list = torch.Tensor(
[x[key] / x[key_num_atoms] for x in self.to_list()]
)
return float(torch.mean(eng_list))
def get_per_atom_energy_mean(self):
"""
alias for get_per_atom_mean(KEY.ENERGY)
"""
return self.get_per_atom_mean(self.DATA_KEY_ENERGY)
def get_species_ref_energy_by_linear_comb(self, num_chem_species: int):
"""
Total energy as y, composition as c_i,
solve linear regression of y = c_i*X
sklearn LinearRegression as solver
x should be one-hot-indexed
give num_chem_species if possible
"""
assert self.x_is_one_hot_idx is True
data_list = self.to_list()
c = torch.zeros((len(data_list), num_chem_species))
for idx, datum in enumerate(data_list):
c[idx] = torch.bincount(
datum[self.DATA_KEY_X], minlength=num_chem_species
)
y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list])
c = c.numpy()
y = y.numpy()
# tweak to fine tune training from many-element to small element
zero_indices = np.all(c == 0, axis=0)
c_reduced = c[:, ~zero_indices]
full_coeff = np.zeros(num_chem_species)
coef_reduced = (
Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
)
full_coeff[~zero_indices] = coef_reduced
return full_coeff
def get_force_rms(self):
force_list = []
for x in self.to_list():
force_list.extend(
x[self.DATA_KEY_FORCE]
.reshape(
-1,
)
.tolist()
)
force_list = torch.Tensor(force_list)
return float(torch.sqrt(torch.mean(torch.pow(force_list, 2))))
def get_species_wise_force_rms(self, num_chem_species: int):
"""
Return force rms for each species
Averaged by each components (x, y, z)
"""
assert self.x_is_one_hot_idx is True
data_list = self.to_list()
atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list])
force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list])
index = atomx.repeat_interleave(3, 0).reshape(force.shape)
rms = torch.zeros(
(num_chem_species, 3),
dtype=force.dtype,
device=force.device
)
rms.scatter_reduce_(
0, index, force.square(),
reduce='mean', include_self=False
)
return torch.sqrt(rms.mean(dim=1))
def get_avg_num_neigh(self):
n_neigh = []
for _, data_list in self.dataset.items():
for data in data_list:
n_neigh.extend(
np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1]
)
avg_num_neigh = np.average(n_neigh)
return avg_num_neigh
def get_statistics(self, key: str):
"""
return dict of statistics of given key (energy, force, stress)
key of dict is its label and _total for total statistics
value of dict is dict of statistics (mean, std, median, max, min)
"""
def _get_statistic_dict(tensor_list):
data_list = torch.cat(
[
tensor.reshape(
-1,
)
for tensor in tensor_list
]
)
data_list = data_list[~torch.isnan(data_list)]
return {
'mean': float(torch.mean(data_list)),
'std': float(torch.std(data_list)),
'median': float(torch.median(data_list)),
'max': (
torch.nan
if data_list.numel() == 0
else float(torch.max(data_list))
),
'min': (
torch.nan
if data_list.numel() == 0
else float(torch.min(data_list))
),
}
res = {}
for label, values in self.dataset.items():
# flatten list of torch.Tensor (values)
tensor_list = [x[key] for x in values]
res[label] = _get_statistic_dict(tensor_list)
tensor_list = [x[key] for x in self.to_list()]
res['Total'] = _get_statistic_dict(tensor_list)
return res
def augment(self, dataset, validator: Optional[Callable] = None):
"""check meta compatibility here
dataset(AtomGraphDataset): data to augment
validator(Callable, Optional): function(self, dataset) -> bool
if validator is None, by default it checks
whether cutoff & chemical_species are same before augment
check consistent data type, float, double, long integer etc
"""
def default_validator(db1, db2):
cut_consis = db1.cutoff == db2.cutoff
# compare unordered lists
x_is_not_onehot = (not db1.x_is_one_hot_idx) and (
not db2.x_is_one_hot_idx
)
return cut_consis and x_is_not_onehot
if validator is None:
validator = default_validator
if not validator(self, dataset):
            raise ValueError('given datasets are not compatible; check cutoffs')
for key, val in dataset.items():
if key in self.dataset:
self.dataset[key].extend(val)
else:
self.dataset.update({key: val})
self.user_labels = list(self.dataset.keys())
def unify_dtypes(
self,
float_dtype: torch.dtype = torch.float32,
int_dtype: torch.dtype = torch.int64
):
data_list = self.to_list()
for datum in data_list:
for k, v in list(datum.items()):
datum[k] = util.dtype_correct(v, float_dtype, int_dtype)
def delete_data_key(self, key: str):
for data in self.to_list():
del data[key]
# TODO: this by_label is not straightforward
def save(self, path: str, by_label: bool = False):
if by_label:
for label, data in self.dataset.items():
torch.save(
AtomGraphDataset(
{label: data}, self.cutoff, metadata=self.meta
),
f'{path}/{label}.sevenn_data',
)
else:
if path.endswith('.sevenn_data') is False:
path += '.sevenn_data'
torch.save(self, path)
import os
import warnings
from collections import Counter
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.serialization
import torch.utils.data
import yaml
from ase.data import chemical_symbols
from torch_geometric.data import Data
from torch_geometric.data.in_memory_dataset import InMemoryDataset
from tqdm import tqdm
import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn import __version__
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
from sevenn.logger import Logger
if tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (2, 4):
# load graph without error
torch.serialization.add_safe_globals([AtomGraphData])
# warning from PyG, for later torch versions
warnings.filterwarnings(
'ignore',
message='You are using `torch.load` with `weights_only=False`',
)
def _tag_graphs(graph_list: List[AtomGraphData], tag: str):
"""
WIP: To be used
"""
for g in graph_list:
g[KEY.TAG] = tag
return graph_list
def pt_to_args(pt_filename: str):
"""
Return arg dict of root and processed_name from path to .pt
Usage:
dataset = SevenNetGraphDataset(
**pt_to_args({path}/sevenn_data/dataset.pt)
)
"""
processed_dir, basename = os.path.split(pt_filename)
return {
'root': os.path.dirname(processed_dir),
'processed_name': os.path.basename(basename),
}
def _run_stat(
graph_list,
y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS],
) -> Dict[str, Any]:
"""
Loop over dataset and init any statistics might need
"""
n_neigh = []
natoms_counter = Counter()
composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT))
stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys}
for i, graph in tqdm(
enumerate(graph_list), desc='run_stat', total=len(graph_list)
):
z_tensor = graph[KEY.ATOMIC_NUMBERS]
natoms_counter.update(z_tensor.tolist())
composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT)
n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1])
for y, dct in stats.items():
dct['_array'].append(
graph[y].reshape(
-1,
)
)
stats.update({'num_neighbor': {'_array': n_neigh}})
for y, dct in stats.items():
array = torch.cat(dct['_array'])
if array.dtype == torch.int64: # because of n_neigh
array = array.to(torch.float)
try:
median = torch.quantile(array, q=0.5)
except RuntimeError:
warnings.warn(f'skip median due to too large tensor size: {y}')
median = torch.nan
dct.update(
{
'mean': float(torch.mean(array)),
'std': float(torch.std(array, correction=0)),
'median': float(median),
'max': float(torch.max(array)),
'min': float(torch.min(array)),
'count': array.numel(),
'_array': array,
}
)
natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
natoms['total'] = sum(list(natoms.values()))
stats.update({'_composition': composition, 'natoms': natoms})
return stats
def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray):
from sklearn.linear_model import Ridge
c = composition
y = energies
zero_indices = np.all(c == 0, axis=0)
c_reduced = c[:, ~zero_indices]
# will not 100% reproduce, as it is sorted by Z
# train/dataset.py was sorted by alphabets of chemical species
coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
full_coeff = np.zeros(NUM_UNIV_ELEMENT)
full_coeff[~zero_indices] = coef_reduced
return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy
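# A small worked sketch (illustrative numbers, not part of the library): with two
# structures, H2O and H2, the composition rows count atoms per element and the
# ridge fit recovers approximate per-element reference energies.
#
#     comp = np.zeros((2, NUM_UNIV_ELEMENT))
#     comp[0, 1], comp[0, 8] = 2, 1          # H2O: 2 H, 1 O
#     comp[1, 1] = 2                         # H2:  2 H
#     e_ref = _elemwise_reference_energies(comp, np.array([-14.2, -6.8]))
#     # e_ref[1] ~ energy per H, e_ref[8] ~ energy per O (up to the ridge penalty)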
class SevenNetGraphDataset(InMemoryDataset):
"""
Replacement of AtomGraphDataset. (and .sevenn_data)
Extends InMemoryDataset of PyG. From given 'files', and 'cutoff',
build graphs for training SevenNet model. Preprocessed graphs are saved to
f'{root}/sevenn_data/{processed_name}.pt
TODO: Save meta info (cutoff) by overriding .save and .load
TODO: 'tag' is not used yet, but initialized
'tag' is replacement for 'label', and each datapoint has it as integer
'tag' is usually parsed from if the structure_list of load_dataset
Args:
root: path to save/load processed PyG dataset
cutoff: edge cutoff of given AtomGraphData
files: list of filenames or dict describing how to parse the file
ASE readable (with proper extension), structure_list, .sevenn_data,
dict containing file_list (see dict_reader of train/dataload.py)
process_num_cores: # of cpu cores to build graph
processed_name: save as {root}/sevenn_data/{processed_name}.pt
pre_transfrom: optional transform for each graph: def (graph) -> graph
pre_filter: optional filtering function for each graph: def (graph) -> graph
force_reload: if True, reload dataset from files even if there exist
{root}/sevenn_data/{processed_name}
**process_kwargs: keyword arguments that will be passed into ase.io.read
"""
def __init__(
self,
cutoff: float,
root: Optional[str] = None,
files: Optional[Union[str, List[Any]]] = None,
process_num_cores: int = 1,
processed_name: str = 'graph.pt',
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
use_data_weight: bool = False,
log: bool = True,
force_reload: bool = False,
drop_info: bool = True,
**process_kwargs,
):
self.cutoff = cutoff
if files is None:
files = []
elif isinstance(files, str):
files = [files] # user convenience
_files = []
for f in files:
if isinstance(f, str):
f = os.path.abspath(f)
_files.append(f)
self._files = _files
self._full_file_list = []
if not processed_name.endswith('.pt'):
processed_name += '.pt'
self._processed_names = [
processed_name, # {root}/sevenn_data/{name}.pt
processed_name.replace('.pt', '.yaml'),
]
root = root or './'
_pdir = os.path.join(root, 'sevenn_data')
_pt = os.path.join(_pdir, self._processed_names[0])
if not os.path.exists(_pt) and len(self._files) == 0:
raise ValueError(
(
f'{_pt} not found and no files to process. '
+ 'If you copied only .pt file, please copy '
+ 'whole sevenn_data dir without changing its name.'
+ ' They all work together.'
)
)
_yam = os.path.join(_pdir, self._processed_names[1])
if not os.path.exists(_yam) and len(self._files) == 0:
raise ValueError(f'{_yam} not found and no files to process')
self.process_num_cores = process_num_cores
self.process_kwargs = process_kwargs
self.use_data_weight = use_data_weight
self.drop_info = drop_info
self.tag_map = {}
self.statistics = {}
self.finalized = False
super().__init__(
root,
transform,
pre_transform,
pre_filter,
log=log,
force_reload=force_reload,
) # Internally calls 'process'
self.load(self.processed_paths[0]) # load pt, saved after process
def load(self, path: str, data_cls=Data) -> None:
super().load(path, data_cls)
if len(self) == 0:
warnings.warn(f'No graphs found {self.processed_paths[0]}')
if len(self.statistics) == 0:
# dataset is loaded from existing pt file.
self._load_meta()
def _load_meta(self) -> None:
with open(self.processed_paths[1], 'r') as f:
meta = yaml.safe_load(f)
if meta['sevennet_version'] == '0.10.0':
self._save_meta(list(self))
with open(self.processed_paths[1], 'r') as f:
meta = yaml.safe_load(f)
cutoff = float(meta['cutoff'])
if float(meta['cutoff']) != self.cutoff:
warnings.warn(
(
'Loaded dataset is built with different cutoff length: '
+ f'{cutoff} != {self.cutoff}, dataset cutoff will be'
+ f' overwritten to {cutoff}'
)
)
self.cutoff = cutoff
self._files = meta['files']
self.statistics = meta['statistics']
def __getitem__(self, idx):
graph = super().__getitem__(idx)
if self.drop_info:
graph.pop(KEY.INFO, None) # type: ignore
return graph
@property
def raw_file_names(self) -> List[Any]:
return self._files
@property
def processed_file_names(self) -> List[str]:
return self._processed_names
@property
def processed_dir(self) -> str:
return os.path.join(self.root, 'sevenn_data')
@property
def full_file_list(self) -> Union[List[str], None]:
return self._full_file_list
def process(self):
graph_list: List[AtomGraphData] = []
for file in self.raw_file_names:
tmplist = SevenNetGraphDataset.file_to_graph_list(
file=file,
cutoff=self.cutoff,
num_cores=self.process_num_cores,
**self.process_kwargs,
)
if isinstance(file, str) and self._full_file_list is not None:
self._full_file_list.extend([os.path.abspath(file)] * len(tmplist))
else:
self._full_file_list = None
graph_list.extend(tmplist)
processed_graph_list = []
for data in graph_list:
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
if self.use_data_weight:
# pop data weight from info, and assign to graph
weight = data[KEY.INFO].pop(
KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
)
data[KEY.DATA_WEIGHT] = weight
processed_graph_list.append(data)
if len(processed_graph_list) == 0:
# Can not save at all if there is no graph (error in PyG), raise an error
raise ValueError('Zero graph found after filtering')
# save graphs, handled by torch_geometrics
self.save(processed_graph_list, self.processed_paths[0])
self._save_meta(processed_graph_list)
if self.log:
Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}')
def _save_meta(self, graph_list) -> None:
stats = _run_stat(graph_list)
stats['elemwise_reference_energies'] = _elemwise_reference_energies(
stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy()
)
self.statistics = stats
stats_save = {}
for label, dct in self.statistics.items():
if label.startswith('_'):
continue
stats_save[label] = {}
if not isinstance(dct, dict):
stats_save[label] = dct
else:
for k, v in dct.items():
if k.startswith('_'):
continue
stats_save[label][k] = v
meta = {
'sevennet_version': __version__,
'cutoff': self.cutoff,
'when': datetime.now().strftime('%Y-%m-%d %H:%M'),
'files': self._files,
'statistics': stats_save,
'species': self.species,
'num_graphs': self.statistics[KEY.ENERGY]['count'],
'per_atom_energy_mean': self.per_atom_energy_mean,
'force_rms': self.force_rms,
'per_atom_energy_std': self.per_atom_energy_std,
'avg_num_neigh': self.avg_num_neigh,
'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh,
}
with open(self.processed_paths[1], 'w') as f:
yaml.dump(meta, f, default_flow_style=False)
@property
def species(self):
return [z for z in self.statistics['natoms'].keys() if z != 'total']
@property
def natoms(self):
return self.statistics['natoms']
@property
def per_atom_energy_mean(self):
return self.statistics[KEY.PER_ATOM_ENERGY]['mean']
@property
def elemwise_reference_energies(self):
return self.statistics['elemwise_reference_energies']
@property
def force_rms(self):
mean = self.statistics[KEY.FORCE]['mean']
std = self.statistics[KEY.FORCE]['std']
return float((mean**2 + std**2) ** (0.5))
@property
def per_atom_energy_std(self):
return self.statistics['per_atom_energy']['std']
@property
def avg_num_neigh(self):
return self.statistics['num_neighbor']['mean']
@property
def sqrt_avg_num_neigh(self):
return self.avg_num_neigh**0.5
@staticmethod
def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]:
# backward compatibility
from sevenn.train.dataset import AtomGraphDataset
dataset = torch.load(filename, map_location='cpu', weights_only=False)
if isinstance(dataset, AtomGraphDataset):
graph_list = []
for _, graphs in dataset.dataset.items(): # type: ignore
# TODO: transfer label to tag (who gonna need this?)
graph_list.extend(graphs)
return graph_list, dataset.cutoff
else:
raise ValueError(f'Not sevenn_data type: {type(dataset)}')
@staticmethod
def _read_structure_list(
filename: str, cutoff: float, num_cores: int = 1
) -> List[AtomGraphData]:
datadct = dataload.structure_list_reader(filename)
graph_list = []
for tag, atoms_list in datadct.items():
tmp = dataload.graph_build(atoms_list, cutoff, num_cores)
graph_list.extend(_tag_graphs(tmp, tag))
return graph_list
@staticmethod
def _read_ase_readable(
filename: str,
cutoff: float,
num_cores: int = 1,
tag: str = '',
transfer_info: bool = True,
allow_unlabeled: bool = False,
**ase_kwargs,
) -> List[AtomGraphData]:
pbc_override = ase_kwargs.pop('pbc', None)
atoms_list = dataload.ase_reader(filename, **ase_kwargs)
for atoms in atoms_list:
if pbc_override is not None:
atoms.pbc = pbc_override
graph_list = dataload.graph_build(
atoms_list,
cutoff,
num_cores,
transfer_info=transfer_info,
allow_unlabeled=allow_unlabeled,
)
if tag != '':
graph_list = _tag_graphs(graph_list, tag)
return graph_list
@staticmethod
def _read_graph_dataset(
filename: str, cutoff: float, **kwargs
) -> List[AtomGraphData]:
meta_f = filename.replace('.pt', '.yaml')
orig_cutoff = cutoff
if not os.path.exists(filename):
raise FileNotFoundError(f'No such file: {filename}')
if not os.path.exists(meta_f):
warnings.warn('No meta info found, beware of cutoff...')
else:
with open(meta_f, 'r') as f:
meta = yaml.safe_load(f)
orig_cutoff = float(meta['cutoff'])
if orig_cutoff != cutoff:
warnings.warn(
f'{filename} has different cutoff length: '
+ f'{cutoff} != {orig_cutoff}'
)
ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff})
ds_args.update(pt_to_args(filename))
ds_args.update(kwargs)
dataset = SevenNetGraphDataset(**ds_args)
# TODO: hard coded. consult with inference.py
glist = [g.fit_dimension() for g in dataset] # type: ignore
for g in glist:
if KEY.STRESS in g:
# (1, 6) is what we want
g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0)
return glist
@staticmethod
def _read_dict(
data_dict: dict,
cutoff: float,
num_cores: int = 1,
):
# logic same as the dataload dict_reader, but handles graphs
data_dict_cp = deepcopy(data_dict)
file_list = data_dict_cp.get('file_list', None)
if file_list is None:
raise KeyError('file_list is not found')
data_weight_default = {
'energy': 1.0,
'force': 1.0,
'stress': 1.0,
}
data_weight = data_weight_default.copy()
data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))
graph_list = []
for file_dct in file_list:
ftype = file_dct.pop('data_format', 'ase')
if ftype != 'graph':
continue
graph_list.extend(
SevenNetGraphDataset._read_graph_dataset(
file_dct.get('file'), cutoff=cutoff
)
)
for graph in graph_list:
if KEY.INFO not in graph:
graph[KEY.INFO] = {}
graph[KEY.INFO].update(data_dict_cp)
graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight})
atoms_list = dataload.dict_reader(data_dict)
graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores))
return graph_list
@staticmethod
def file_to_graph_list(
file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs
) -> List[AtomGraphData]:
"""
kwargs: if file is ase readable, passed to ase.io.read
"""
if isinstance(file, str) and not os.path.isfile(file):
raise ValueError(f'No such file: {file}')
graph_list: List[AtomGraphData]
if isinstance(file, dict):
graph_list = SevenNetGraphDataset._read_dict(
file, cutoff, num_cores, **kwargs
)
elif file.endswith('.pt'):
graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff)
elif file.endswith('.sevenn_data'):
graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file)
if cutoff_other != cutoff:
warnings.warn(f'Given {file} has different {cutoff_other}!')
cutoff = cutoff_other
elif 'structure_list' in file:
graph_list = SevenNetGraphDataset._read_structure_list(
file, cutoff, num_cores
)
else:
graph_list = SevenNetGraphDataset._read_ase_readable(
file, cutoff, num_cores, **kwargs
)
return graph_list
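# A minimal usage sketch (illustrative, not part of the library); paths and cutoff
# are placeholders. Graphs are processed once and cached under {root}/sevenn_data.
#
#     dataset = SevenNetGraphDataset(
#         cutoff=5.0, root='./train_run', files='dft_data.extxyz',
#         processed_name='train.pt', process_num_cores=4,
#     )
#     print(len(dataset), dataset.avg_num_neigh, dataset.per_atom_energy_mean)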
def from_single_path(
path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs
) -> Union[SevenNetGraphDataset, None]:
"""
Convenient routine for loading a single .pt dataset.
If given dict and it has data_weight, apply it using transform
"""
data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
spath = _extract_single_path(path)
if spath is None:
return None
if isinstance(spath, str):
if not spath.endswith('.pt'):
return None
dataset_kwargs.update(pt_to_args(spath))
elif isinstance(spath, dict):
file = _extract_file_from_dict(spath)
if file is None or not file.endswith('.pt'):
return None
dataset_kwargs.update(pt_to_args(file))
data_weight_user = spath.get(KEY.DATA_WEIGHT, None)
if data_weight_user is not None:
data_weight.update(data_weight_user)
else:
return None
if override_data_weight:
dataset_kwargs['transform'] = _chain_data_weight_override(
dataset_kwargs.get('transform'), data_weight
)
return SevenNetGraphDataset(**dataset_kwargs)
def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]:
"""Extracts a single path from the input,
ensuring it's either a single string or list with one item."""
if isinstance(path, list):
return path[0] if len(path) == 1 else None
return path if isinstance(path, (str, dict)) else None
def _extract_file_from_dict(path_dict: dict) -> Union[str, None]:
"""Extracts a single file path from the dictionary, ensuring it's valid."""
file_list = path_dict.get('file_list', None)
if file_list and len(file_list) == 1:
file = file_list[0].get('file', None)
return file if isinstance(file, str) else None
return None
def _chain_data_weight_override(transform_func, data_weight):
"""Creates a transform function that overrides the data weight."""
def chained_transform(graph):
graph = transform_func(graph) if transform_func is not None else graph
graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None)
graph[KEY.DATA_WEIGHT] = data_weight
return graph
return chained_transform
# script, return dict of SevenNetGraphDataset
def from_config(
config: Dict[str, Any],
working_dir: str = os.getcwd(),
dataset_keys: Optional[List[str]] = None,
):
log = Logger()
if dataset_keys is None:
dataset_keys = []
for k in config:
if k.startswith('load_') and k.endswith('_path'):
dataset_keys.append(k)
if KEY.LOAD_TRAINSET not in dataset_keys:
raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')
# initialize arguments for loading dataset
dataset_args = {
'cutoff': config[KEY.CUTOFF],
'root': working_dir,
'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1),
'use_data_weight': config.get(KEY.USE_WEIGHT, False),
**config.get(KEY.DATA_FORMAT_ARGS, {}),
}
datasets = {}
for dk in dataset_keys:
if not (paths := config[dk]):
continue
if isinstance(paths, str):
paths = [paths]
name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
if (dataset := from_single_path(paths, **dataset_args)) is not None:
datasets[name] = dataset
else:
dataset_args.update({'files': paths, 'processed_name': name})
dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt')
if os.path.exists(dataset_path) and 'force_reload' not in dataset_args:
log.writeline(
f'Dataset will be loaded from {dataset_path}, without update. '
+ 'If you have changed your files to read, put force_reload=True'
+ ' under the data_format_args key'
)
datasets[name] = SevenNetGraphDataset(**dataset_args)
train_set = datasets['trainset']
chem_species = set(train_set.species)
# print statistics of each dataset
for name, dataset in datasets.items():
log.bar()
log.writeline(f'{name} distribution:')
log.statistic_write(dataset.statistics)
log.format_k_v('# structures (graph)', len(dataset), write=True)
chem_species.update(dataset.species)
log.bar()
# initialize known species from dataset if 'auto'
# sorted to alphabetical order (which is same as before)
chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py
log.writeline('Known species are obtained from the dataset')
config.update(util.chemical_species_preprocess(sorted(list(chem_species))))
    # retrieve shift, scale, conv_denominators from user input (keyword)
init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
for k in init_from_stats:
input = config[k] # statistic key or numbers
# If it is not 'str', 1: It is 'continue' training
# 2: User manually inserted numbers
if isinstance(input, str) and hasattr(train_set, input):
var = getattr(train_set, input)
config.update({k: var})
log.writeline(f'{k} is obtained from statistics')
elif isinstance(input, str) and not hasattr(train_set, input):
raise NotImplementedError(input)
if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0:
log.writeline('Use validation set as random split from the training set')
log.writeline(
'Note that statistics, shift, scale, and conv_denominator are '
+ 'computed before random split.\n If you want these after random '
+ 'split, please preprocess dataset and set it as load_trainset_path '
+ 'and load_validset_path explicitly.'
)
ratio = float(config[KEY.RATIO])
train, valid = torch.utils.data.random_split(
datasets['trainset'], (1.0 - ratio, ratio)
)
datasets['trainset'] = train
datasets['validset'] = valid
return datasets
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import sevenn._keys as KEY
class LossDefinition:
"""
Base class for loss definition
weights are defined in outside of the class
"""
def __init__(
self,
name: str,
unit: Optional[str] = None,
criterion: Optional[Callable] = None,
ref_key: Optional[str] = None,
pred_key: Optional[str] = None,
use_weight: bool = False,
ignore_unlabeled: bool = True,
):
self.name = name
self.unit = unit
self.criterion = criterion
self.ref_key = ref_key
self.pred_key = pred_key
self.use_weight = use_weight
self.ignore_unlabeled = ignore_unlabeled
def __repr__(self):
return self.name
def assign_criteria(self, criterion: Callable):
if self.criterion is not None:
raise ValueError('Loss uses its own criterion.')
self.criterion = criterion
def _preprocess(
self, batch_data: Dict[str, Any], model: Optional[Callable] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if self.pred_key is None or self.ref_key is None:
raise NotImplementedError('LossDefinition is not implemented.')
pred = torch.reshape(batch_data[self.pred_key], (-1,))
ref = torch.reshape(batch_data[self.ref_key], (-1,))
return pred, ref, None
def _ignore_unlabeled(self, pred, ref, data_weights=None):
unlabeled = torch.isnan(ref)
pred = pred[~unlabeled]
ref = ref[~unlabeled]
if data_weights is not None:
data_weights = data_weights[~unlabeled]
return pred, ref, data_weights
def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None):
"""
        Return a scalar loss for the given batch.
"""
if self.criterion is None:
raise NotImplementedError('LossDefinition has no criterion.')
pred, ref, w_tensor = self._preprocess(batch_data, model)
if self.ignore_unlabeled:
pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor)
if len(pred) == 0:
assert self.ref_key is not None
return torch.zeros(1, device=batch_data[self.ref_key].device)
loss = self.criterion(pred, ref)
if self.use_weight:
loss = torch.mean(loss * w_tensor)
return loss
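    # Illustrative usage (not from the library; the string keys below are
    # stand-ins for the KEY.* constants used elsewhere):
    #
    #     loss_def = LossDefinition(name='Energy', ref_key='ref', pred_key='pred')
    #     loss_def.assign_criteria(torch.nn.MSELoss())
    #     batch = {'ref': torch.tensor([1.0, float('nan')]),
    #              'pred': torch.tensor([1.2, 0.0])}
    #     loss = loss_def.get_loss(batch)  # NaN-labeled entries are dropped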
class PerAtomEnergyLoss(LossDefinition):
"""
    Loss for per-atom energy
"""
def __init__(
self,
name: str = 'Energy',
unit: str = 'eV/atom',
criterion: Optional[Callable] = None,
ref_key: str = KEY.ENERGY,
pred_key: str = KEY.PRED_TOTAL_ENERGY,
**kwargs,
):
super().__init__(
name=name,
unit=unit,
criterion=criterion,
ref_key=ref_key,
pred_key=pred_key,
**kwargs,
)
def _preprocess(
self, batch_data: Dict[str, Any], model: Optional[Callable] = None
):
num_atoms = batch_data[KEY.NUM_ATOMS]
assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
pred = batch_data[self.pred_key] / num_atoms
ref = batch_data[self.ref_key] / num_atoms
w_tensor = None
if self.use_weight:
loss_type = self.name.lower()
weight = batch_data[KEY.DATA_WEIGHT][loss_type]
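            # one energy label per structure, so repeat_interleave(weight, 1)
            # keeps the per-structure weights as-is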
w_tensor = torch.repeat_interleave(weight, 1)
return pred, ref, w_tensor
class ForceLoss(LossDefinition):
"""
Loss for force
"""
def __init__(
self,
name: str = 'Force',
unit: str = 'eV/A',
criterion: Optional[Callable] = None,
ref_key: str = KEY.FORCE,
pred_key: str = KEY.PRED_FORCE,
**kwargs,
):
super().__init__(
name=name,
unit=unit,
criterion=criterion,
ref_key=ref_key,
pred_key=pred_key,
**kwargs,
)
def _preprocess(
self, batch_data: Dict[str, Any], model: Optional[Callable] = None
):
assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
pred = torch.reshape(batch_data[self.pred_key], (-1,))
ref = torch.reshape(batch_data[self.ref_key], (-1,))
w_tensor = None
if self.use_weight:
loss_type = self.name.lower()
weight = batch_data[KEY.DATA_WEIGHT][loss_type]
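            # per-structure weight -> per-atom (via batch index) -> 3 force components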
w_tensor = weight[batch_data[KEY.BATCH]]
w_tensor = torch.repeat_interleave(w_tensor, 3)
return pred, ref, w_tensor
class StressLoss(LossDefinition):
"""
    Loss for stress, in kbar
"""
def __init__(
self,
name: str = 'Stress',
unit: str = 'kbar',
criterion: Optional[Callable] = None,
ref_key: str = KEY.STRESS,
pred_key: str = KEY.PRED_STRESS,
**kwargs,
):
super().__init__(
name=name,
unit=unit,
criterion=criterion,
ref_key=ref_key,
pred_key=pred_key,
**kwargs,
)
self.TO_KB = 1602.1766208 # eV/A^3 to kbar
def _preprocess(
self, batch_data: Dict[str, Any], model: Optional[Callable] = None
):
assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,))
ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,))
w_tensor = None
if self.use_weight:
loss_type = self.name.lower()
weight = batch_data[KEY.DATA_WEIGHT][loss_type]
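            # per-structure weight repeated for the 6 stress components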
w_tensor = torch.repeat_interleave(weight, 6)
return pred, ref, w_tensor
def get_loss_functions_from_config(config: Dict[str, Any]):
from sevenn.train.optim import loss_dict
loss_functions = [] # list of tuples (loss_definition, weight)
loss = loss_dict[config[KEY.LOSS].lower()]
loss_param = config.get(KEY.LOSS_PARAM, {})
use_weight = config.get(KEY.USE_WEIGHT, False)
if use_weight:
loss_param['reduction'] = 'none'
criterion = loss(**loss_param)
commons = {'use_weight': use_weight}
loss_functions.append((PerAtomEnergyLoss(**commons), 1.0))
loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT]))
if config[KEY.IS_TRAIN_STRESS]:
loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT]))
    # assign the shared criterion to losses that do not define their own
    for loss_function, _ in loss_functions:
if loss_function.criterion is None:
loss_function.assign_criteria(criterion)
return loss_functions
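# Sketch of how the returned (definition, weight) pairs are typically combined
# into a single training loss (illustrative, not part of this module):
#
#     total_loss = sum(
#         w * loss_def.get_loss(batch_data) for loss_def, w in loss_functions
#     )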
import bisect
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
import numpy as np
from torch.utils.data import ConcatDataset, Dataset
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.logger import Logger
def _arrange_paths_by_modality(paths: List[dict]):
modal_dct = {}
for path in paths:
if isinstance(path, dict):
if KEY.DATA_MODALITY not in path:
raise ValueError(f'{KEY.DATA_MODALITY} is missing')
modal = path.pop(KEY.DATA_MODALITY)
else:
            raise TypeError(f'{path} is not a dict')
if modal not in modal_dct:
modal_dct[modal] = []
modal_dct[modal].append(path)
return modal_dct
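# Example (illustrative; assumes KEY.DATA_MODALITY == 'data_modality' and the
# other keys in each path dict are hypothetical):
#
#     paths = [
#         {'data_modality': 'pbe', 'file_list': ['pbe.extxyz']},
#         {'data_modality': 'scan', 'file_list': ['scan.extxyz']},
#     ]
#     # -> {'pbe': [{'file_list': ['pbe.extxyz']}],
#     #     'scan': [{'file_list': ['scan.extxyz']}]}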
def combined_variance(
means: np.ndarray, stds: np.ndarray, sample_sizes: np.ndarray, ddof: int = 0
) -> float:
"""
Calculate the combined variance for multiple datasets.
"""
assert len(means) == len(stds) and len(stds) == len(sample_sizes)
# Total number of samples
total_samples = np.sum(sample_sizes)
# Combined mean
combined_mean = np.sum(sample_sizes * means) / total_samples
# Combined variance calculation
variance_terms = (sample_sizes - ddof) * (stds**2)
mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2)
combined_variance = (np.sum(variance_terms) + np.sum(mean_diff_terms)) / (
total_samples - ddof
)
return combined_variance
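# Quick illustrative check (not part of the module): with ddof=0, the pooled
# result equals the variance of the concatenated samples.
#
#     a, b = np.random.randn(100), np.random.randn(50) + 2.0
#     pooled = combined_variance(
#         means=np.array([a.mean(), b.mean()]),
#         stds=np.array([a.std(), b.std()]),
#         sample_sizes=np.array([len(a), len(b)]),
#     )
#     assert np.isclose(pooled, np.concatenate([a, b]).var())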
def combined_std(
means: List[float], stds: List[float], sample_sizes: List[int]
) -> float:
"""
Calculate the combined std for multiple datasets.
"""
assert len(means) == len(stds) and len(stds) == len(sample_sizes)
means_arr = np.array(means)
stds_arr = np.array(stds)
sample_sizes_arr = np.array(sample_sizes)
cv = combined_variance(means_arr, stds_arr, sample_sizes_arr)
return np.sqrt(cv)
def combined_mean(means: List[float], sample_sizes: List[int]) -> float:
"""
Calculate the combined mean for multiple datasets.
"""
assert len(means) == len(sample_sizes)
means_arr = np.array(means)
sample_sizes_arr = np.array(sample_sizes)
return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr)
def combined_rms(
means: List[float], stds: List[float], sample_sizes: List[int]
) -> float:
"""
Calculate the combined RMS for multiple datasets.
"""
assert len(means) == len(stds) and len(stds) == len(sample_sizes)
means_arr = np.array(means)
stds_arr = np.array(stds)
sample_sizes_arr = np.array(sample_sizes)
cm = combined_mean(means, sample_sizes)
cv = combined_variance(means_arr, stds_arr, sample_sizes_arr)
# Combined RMS calculation
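    # from E[x^2] = (E[x])^2 + Var(x), so RMS = sqrt(mean^2 + variance)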
return np.sqrt(cm**2 + cv)
class SevenNetMultiModalDataset(ConcatDataset):
def __init__(
self,
modal_dataset_dict: Dict[str, Dataset],
):
datasets = []
modals = []
for modal, dataset in modal_dataset_dict.items():
modals.append(modal)
datasets.append(dataset)
self.modals = modals
super().__init__(datasets)
def __getitem__(self, idx):
graph = super().__getitem__(idx)
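        # locate the modal sub-dataset this global index falls into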
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
modality = self.modals[dataset_idx]
graph[KEY.DATA_MODALITY] = modality
return graph
def _modal_wise_property(self, attribute_name: str):
dct = {}
for modal, dataset in zip(self.modals, self.datasets):
try:
if hasattr(dataset, attribute_name):
dct[modal] = getattr(dataset, attribute_name)
except AttributeError:
dct[modal] = None
return dct
@property
def dataset_dict(self):
arr = {}
for idx, modality in enumerate(self.modals):
arr[modality] = self.datasets[idx]
return arr
@property
def species(self):
dct = self._modal_wise_property('species')
tot = set()
for sp in dct.values():
tot.update(sp)
dct['total'] = list(tot)
return dct
@property
def natoms(self):
return self._modal_wise_property('natoms')
@property
def per_atom_energy_mean(self):
dct = self._modal_wise_property('per_atom_energy_mean')
try:
means = []
sample_sizes = []
for modality, mean in dct.items():
means.append(mean)
sample_sizes.append(
self.statistics[modality][KEY.PER_ATOM_ENERGY]['count']
)
cm = combined_mean(means, sample_sizes)
dct['total'] = cm
except KeyError:
pass
return dct
@property
def elemwise_reference_energies(self):
        # 'total' is not supported (computing it is expensive and complex, and not needed)
return self._modal_wise_property('elemwise_reference_energies')
@property
def force_rms(self):
dct = self._modal_wise_property('force_rms')
try:
means = []
sample_sizes = []
stds = []
for modality in dct:
means.append(self.statistics[modality][KEY.FORCE]['mean'])
sample_sizes.append(self.statistics[modality][KEY.FORCE]['count'])
stds.append(self.statistics[modality][KEY.FORCE]['std'])
cm = combined_rms(means, stds, sample_sizes)
dct['total'] = cm
except KeyError:
pass
return dct
@property
def per_atom_energy_std(self):
dct = self._modal_wise_property('per_atom_energy_std')
try:
means = []
sample_sizes = []
stds = []
for modality in dct:
means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean'])
sample_sizes.append(
self.statistics[modality][KEY.PER_ATOM_ENERGY]['count']
)
stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std'])
cm = combined_std(means, stds, sample_sizes)
dct['total'] = cm
except KeyError:
pass
return dct
@property
def avg_num_neigh(self):
dct = self._modal_wise_property('avg_num_neigh')
try:
means = []
sample_sizes = []
for modality, mean in dct.items():
means.append(mean)
sample_sizes.append(
self.statistics[modality]['num_neighbor']['count']
)
cm = combined_mean(means, sample_sizes)
dct['total'] = cm
except KeyError:
pass
return dct
@property
def sqrt_avg_num_neigh(self):
avg_nn = self.avg_num_neigh
return {k: v**0.5 for k, v in avg_nn.items()}
@property
def statistics(self):
return self._modal_wise_property('statistics')
@staticmethod
def as_graph_dataset(
paths: List[dict],
**graph_dataset_kwargs,
):
import sevenn.train.graph_dataset as gd
modal_paths = _arrange_paths_by_modality(paths)
dataset_dct = {}
for modality, paths in modal_paths.items():
kwargs = deepcopy(graph_dataset_kwargs)
if (dataset := gd.from_single_path(paths, **kwargs)) is None:
pname = kwargs.pop('processed_name', 'graph').replace('.pt', '')
dataset = gd.SevenNetGraphDataset(
files=paths,
processed_name=f'{pname}_{modality}.pt',
**kwargs,
)
dataset_dct[modality] = dataset
return SevenNetMultiModalDataset(dataset_dct)
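# Illustrative construction (not part of the module; the literal 'data_modality'
# key and the inner path-dict keys are assumptions):
#
#     dataset = SevenNetMultiModalDataset.as_graph_dataset(
#         [{'data_modality': 'pbe', 'file_list': ['pbe.extxyz']},
#          {'data_modality': 'scan', 'file_list': ['scan.extxyz']}],
#         cutoff=5.0, root='./outputs', processed_name='train',
#     )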
def from_config(
config: Dict[str, Any],
working_dir: str = os.getcwd(),
dataset_keys: Optional[List[str]] = None,
):
log = Logger()
if dataset_keys is None:
dataset_keys = [
k for k in config if (k.startswith('load_') and k.endswith('_path'))
]
if KEY.LOAD_TRAINSET not in dataset_keys:
raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')
dataset_args = {
'cutoff': config[KEY.CUTOFF],
'root': working_dir,
'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1),
'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config.get(KEY.DATA_FORMAT_ARGS, {}),
}
datasets = {}
for dk in dataset_keys:
if not (paths := config[dk]):
continue
if isinstance(paths, str):
paths = [paths]
name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
dataset_args.update({'processed_name': name})
datasets[name] = SevenNetMultiModalDataset.as_graph_dataset(
paths, # type: ignore
**dataset_args,
)
train_set = datasets['trainset']
modals_dataset = set()
chem_species = set()
# print statistics of each dataset
for name, dataset in datasets.items():
for idx, modality in enumerate(dataset.modals):
log.bar()
log.writeline(f'{name} - {modality} distribution:')
log.statistic_write(dataset.statistics[modality])
log.format_k_v(
'# structures (graph)', len(dataset.datasets[idx]), write=True
)
modals_dataset.update([modality])
chem_species.update(dataset.species['total'])
log.bar()
if (modal_map := config.get(KEY.MODAL_MAP, None)) is None:
modals = sorted(list(modals_dataset))
modal_map = {modal: i for i, modal in enumerate(modals)}
config[KEY.MODAL_MAP] = modal_map
modals = list(modal_map.keys())
if not modals_dataset.issubset(modal_map):
raise ValueError(
            f'Modalities found in the datasets ({modals_dataset}) are not a subset'
            + f' of {modals}. Use the sevenn_cp tool to append/assign modalities'
)
log.writeline(f'Modalities of this model: {modals}')
config[KEY.NUM_MODALITIES] = len(modal_map)
# initialize known species from dataset if 'auto'
    # sorted alphabetically (same as before)
chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py
log.writeline('Known species are obtained from the dataset')
config.update(util.chemical_species_preprocess(sorted(list(chem_species))))
    # retrieve shift, scale, conv_denominator from user input (keyword)
init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
init_from_stats = [
k for k in init_from_stats_candid if isinstance(config[k], str)
]
for k in init_from_stats:
input = config[k]
if not hasattr(train_set, input):
raise NotImplementedError(input)
modal_stat = getattr(train_set, input)
try:
if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat:
# conv_denominator is not modal-wise
var = modal_stat['total']
elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]:
modal_stat.pop('total', None)
var = modal_stat
elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]:
var = modal_stat['total']
elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]:
modal_stat.pop('total', None)
var = modal_stat
elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]:
var = modal_stat['total']
else:
raise NotImplementedError(f'Failed to init {k} from statistics')
except KeyError as e:
if e.args[0] == 'total':
raise NotImplementedError(
f'{k}: {input} does not support total statistics. '
+ f'Set use_modal_wise_{k} True or specify numbers manually'
)
else:
raise e
config.update({k: var})
log.writeline(f'{k} is obtained from statistics')
return datasets
import torch.nn as nn
import torch.optim.lr_scheduler as scheduler
from torch.optim import adagrad, adam, adamw, radam, sgd
optim_dict = {
'sgd': sgd.SGD,
'adagrad': adagrad.Adagrad,
'adam': adam.Adam,
'adamw': adamw.AdamW,
'radam': radam.RAdam,
}
scheduler_dict = {
'steplr': scheduler.StepLR,
'multisteplr': scheduler.MultiStepLR,
'exponentiallr': scheduler.ExponentialLR,
'cosineannealinglr': scheduler.CosineAnnealingLR,
'reducelronplateau': scheduler.ReduceLROnPlateau,
'linearlr': scheduler.LinearLR,
}
loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss}
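# Illustrative wiring (not part of the module): the dicts above map lowercase
# config strings to torch classes, with `model` being any nn.Module, e.g.
#
#     optimizer = optim_dict['adamw'](model.parameters(), lr=1e-3)
#     lr_sched = scheduler_dict['exponentiallr'](optimizer, gamma=0.99)
#     criterion = loss_dict['huber'](delta=0.01)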