Commit 2409a22f authored by fanding2000's avatar fanding2000
Browse files

Format fix. More options in readme

parent ce29afea
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
# Reverse lookup table: atomic number -> chemical symbol
# (ase.data.atomic_numbers maps symbol -> number).
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that caches one instance per class.

    The first call to ``Cls(...)`` constructs and stores the instance;
    every later call returns the same stored object.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        cached = cls._instances.get(cls)
        if cached is None:
            cached = super().__call__(*args, **kwargs)
            cls._instances[cls] = cached
        return cached
class Logger(metaclass=Singleton):
    """Process-wide logger (one instance per process via Singleton metaclass).

    Only rank 0 writes anything; other ranks get a silent logger so the
    same logging calls can be issued from every process in distributed
    training. Intended use is as a context manager: the log file is
    opened lazily in ``__enter__`` and closed in ``__exit__``.
    """

    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # File handle is opened lazily in __enter__, not here.
            # if filename is not None:
            #     self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}  # auxiliary csv handles (deprecated API below)
            self.screen = screen
        else:
            # Non-zero ranks never write to file or screen.
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # timer name -> start datetime (timer_start/timer_end)
        self.active = True  # global on/off switch checked by write()

    def __enter__(self):
        """Open (or re-open) the configured log file; no-op off rank 0."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                # Best effort: continue as a file-less (screen-only) logger.
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the main log file and any auxiliary csv files.

        NOTE(review): the ``traceback`` parameter shadows the module-level
        ``traceback`` import inside this method only.
        """
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            # Always reset handles, even if close() raised.
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        """Retarget the logger to a new file; takes effect on next __enter__."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        """Write raw text (no newline added) to the file and/or screen."""
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        # Convenience wrapper: write() with a trailing newline.
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                # NOTE(review): opened without encoding, unlike init_csv —
                # inconsistent; confirm whether utf-8 was intended here too.
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                # Floats are rounded to `decimal` places; everything else str().
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]):
        """Log per-label atom counts plus label-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                # EAFP accumulation of per-specie totals across labels.
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        """Log nested statistics; keys starting with '_' are treated as private."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    # assumes non-int values are floats — TODO confirm
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """Write one row per species: force errors; E and S columns are dashes.

        Keys of train_loss/valid_loss are atomic numbers (looked up in
        CHEM_SYMBOLS); values are per-species error floats.
        """
        lb_pad = 21      # width of the label column
        fs = 6           # float precision / dash run length
        pad = 21 - fs
        ln = '-' * fs    # placeholder for columns with no per-species value
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]  # atomic number -> chemical symbol
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator
        # NOTE(review): the generator variables below shadow the `pad`
        # parameter; they are per-column widths, not the extra padding.
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )
        # Print header and separator
        self.writeline(header)
        self.writeline(separator)
        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible

        Renders "key                 : value" and, when the line exceeds the
        screen width, wraps the value at its ', ' separators with a trailing
        backslash and aligned continuation lines. Returns the string when
        write is False; otherwise writes it and returns ''.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # septate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    # Start a new, padded continuation line.
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            # Drop the trailing separator left by the loop.
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        """Write the program banner: name, version, timestamp, ASCII logo."""
        # Logo file ships next to this module.
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        # Full-width horizontal rule.
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception):
        """Log an exception; ValueError gets its message, others a traceback."""
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        # Record the start time under `name` for a later timer_end().
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        # [:-4] trims sub-millisecond digits from the timedelta string.
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """Log irreps of each layer and the learnable-parameter count."""
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        # Readout reads from the last convolution's gate output.
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
# Reverse lookup table: atomic number -> chemical symbol
# (ase.data.atomic_numbers maps symbol -> number).
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that caches one instance per class.

    The first call to ``Cls(...)`` constructs and stores the instance;
    every later call returns the same stored object.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        cached = cls._instances.get(cls)
        if cached is None:
            cached = super().__call__(*args, **kwargs)
            cls._instances[cls] = cached
        return cached
class Logger(metaclass=Singleton):
    """Process-wide logger (one instance per process via Singleton metaclass).

    Only rank 0 writes anything; other ranks get a silent logger so the
    same logging calls can be issued from every process in distributed
    training. Intended use is as a context manager: the log file is
    opened lazily in ``__enter__`` and closed in ``__exit__``.
    """

    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # File handle is opened lazily in __enter__, not here.
            # if filename is not None:
            #     self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}  # auxiliary csv handles (deprecated API below)
            self.screen = screen
        else:
            # Non-zero ranks never write to file or screen.
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # timer name -> start datetime (timer_start/timer_end)
        self.active = True  # global on/off switch checked by write()

    def __enter__(self):
        """Open (or re-open) the configured log file; no-op off rank 0."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                # Best effort: continue as a file-less (screen-only) logger.
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the main log file and any auxiliary csv files.

        NOTE(review): the ``traceback`` parameter shadows the module-level
        ``traceback`` import inside this method only.
        """
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            # Always reset handles, even if close() raised.
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        """Retarget the logger to a new file; takes effect on next __enter__."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        """Write raw text (no newline added) to the file and/or screen."""
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        # Convenience wrapper: write() with a trailing newline.
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                # NOTE(review): opened without encoding, unlike init_csv —
                # inconsistent; confirm whether utf-8 was intended here too.
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                # Floats are rounded to `decimal` places; everything else str().
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]):
        """Log per-label atom counts plus label-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                # EAFP accumulation of per-specie totals across labels.
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        """Log nested statistics; keys starting with '_' are treated as private."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    # assumes non-int values are floats — TODO confirm
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """Write one row per species: force errors; E and S columns are dashes.

        Keys of train_loss/valid_loss are atomic numbers (looked up in
        CHEM_SYMBOLS); values are per-species error floats.
        """
        lb_pad = 21      # width of the label column
        fs = 6           # float precision / dash run length
        pad = 21 - fs
        ln = '-' * fs    # placeholder for columns with no per-species value
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]  # atomic number -> chemical symbol
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator
        # NOTE(review): the generator variables below shadow the `pad`
        # parameter; they are per-column widths, not the extra padding.
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )
        # Print header and separator
        self.writeline(header)
        self.writeline(separator)
        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible

        Renders "key                 : value" and, when the line exceeds the
        screen width, wraps the value at its ', ' separators with a trailing
        backslash and aligned continuation lines. Returns the string when
        write is False; otherwise writes it and returns ''.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # septate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    # Start a new, padded continuation line.
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            # Drop the trailing separator left by the loop.
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        """Write the program banner: name, version, timestamp, ASCII logo."""
        # Logo file ships next to this module.
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        # Full-width horizontal rule.
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception):
        """Log an exception; ValueError gets its message, others a traceback."""
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        # Record the start time under `name` for a later timer_end().
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        # [:-4] trims sub-millisecond digits from the timedelta string.
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """Log irreps of each layer and the learnable-parameter count."""
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        # Readout reads from the last convolution's gate output.
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
import argparse
import os
import sys
import time
from sevenn import __version__
# Help strings shared by the argparse setup below.
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Sets up (optionally distributed) training from the parsed CLI args:
    reads input.yaml, merges configs into global_config, seeds RNGs and
    dispatches to train / train_v2.
    """
    # Heavy imports deferred to keep `sevenn --help` fast.
    import random
    import sys

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq
    if use_cue:
        # Fail early if cuEquivariance was requested but is unavailable.
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')
    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)
    world_size = 1
    if distributed:
        # Rank information comes from the launcher's environment variables
        # (torchrun-style for nccl, OpenMPI variables for mpi).
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1
    # Pick a log path that does not clobber an existing file.
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)
        # Record runtime/DDP facts into the train config.
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))
        if use_cue:
            # Force-enable cuEquivariance in the model config.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})
        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)
        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)
        # Seed python and torch RNGs for reproducibility.
        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)
        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach all arguments of the 'train' command to *parser*."""
    ag = parser
    ag.add_argument('input_yaml', type=str, help=input_yaml_help)
    ag.add_argument(
        '-m',
        '--mode',
        type=str,
        default='train_v2',
        choices=['train_v1', 'train_v2'],
        help=mode_help,
    )
    ag.add_argument(
        '-cueq',
        '--enable_cueq',
        action='store_true',
        help='(Not stable!) use cuEquivariance for training',
    )
    ag.add_argument(
        '-w',
        '--working_dir',
        type=str,
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
    )
    ag.add_argument(
        '-l',
        '--log',
        type=str,
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
    )
    ag.add_argument('-s', '--screen', action='store_true', help=screen_help)
    ag.add_argument('-d', '--distributed', action='store_true', help=distributed_help)
    ag.add_argument(
        '--distributed_backend',
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
        help=distributed_backend_help,
    )
def add_parser(subparsers):
    """Register the 'train' sub-command and its arguments."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()

    Hack copied from stack overflow
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # for-else: runs only when no help flag was found; check whether
        # any registered subparser name already appears on the command line.
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
    if not subparser_found:
        # insert default in last position before global positional
        # arguments, this implies no global options are specified after
        # first positional argument
        if args is None:
            sys.argv.insert(len(sys.argv) - positional_args, name)
        else:
            args.insert(len(args) - positional_args, name)


# Monkey-patch ArgumentParser so `ag.set_default_subparser(...)` is available.
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """CLI entry point: assemble all sub-commands and dispatch."""
    # Local imports keep module import light and avoid import cycles.
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')
    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)
    # Without an explicit sub-command, behave as `sevenn train ...`.
    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'
    # NOTE(review): only 'train' and 'preset' are dispatched here —
    # presumably other sub-commands run via their own entry points; confirm.
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os
import sys
import time
from sevenn import __version__
# Help strings shared by the argparse setup below.
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Sets up (optionally distributed) training from the parsed CLI args:
    reads input.yaml, merges configs into global_config, seeds RNGs and
    dispatches to train / train_v2.
    """
    # Heavy imports deferred to keep `sevenn --help` fast.
    import random
    import sys

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq
    if use_cue:
        # Fail early if cuEquivariance was requested but is unavailable.
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')
    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)
    world_size = 1
    if distributed:
        # Rank information comes from the launcher's environment variables
        # (torchrun-style for nccl, OpenMPI variables for mpi).
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1
    # Pick a log path that does not clobber an existing file.
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)
        # Record runtime/DDP facts into the train config.
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))
        if use_cue:
            # Force-enable cuEquivariance in the model config.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})
        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)
        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)
        # Seed python and torch RNGs for reproducibility.
        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)
        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach all arguments of the 'train' command to *parser*."""
    ag = parser
    ag.add_argument('input_yaml', type=str, help=input_yaml_help)
    ag.add_argument(
        '-m',
        '--mode',
        type=str,
        default='train_v2',
        choices=['train_v1', 'train_v2'],
        help=mode_help,
    )
    ag.add_argument(
        '-cueq',
        '--enable_cueq',
        action='store_true',
        help='(Not stable!) use cuEquivariance for training',
    )
    ag.add_argument(
        '-w',
        '--working_dir',
        type=str,
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
    )
    ag.add_argument(
        '-l',
        '--log',
        type=str,
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
    )
    ag.add_argument('-s', '--screen', action='store_true', help=screen_help)
    ag.add_argument('-d', '--distributed', action='store_true', help=distributed_help)
    ag.add_argument(
        '--distributed_backend',
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
        help=distributed_backend_help,
    )
def add_parser(subparsers):
    """Register the 'train' sub-command and its arguments."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()

    Hack copied from stack overflow
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # for-else: runs only when no help flag was found; check whether
        # any registered subparser name already appears on the command line.
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
    if not subparser_found:
        # insert default in last position before global positional
        # arguments, this implies no global options are specified after
        # first positional argument
        if args is None:
            sys.argv.insert(len(sys.argv) - positional_args, name)
        else:
            args.insert(len(args) - positional_args, name)


# Monkey-patch ArgumentParser so `ag.set_default_subparser(...)` is available.
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """CLI entry point: assemble all sub-commands and dispatch."""
    # Local imports keep module import light and avoid import cycles.
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')
    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)
    # Without an explicit sub-command, behave as `sevenn train ...`.
    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'
    # NOTE(review): only 'train' and 'preset' are dispatched here —
    # presumably other sub-commands run via their own entry points; confirm.
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os.path as osp
from sevenn import __version__
# One-line summary shown in the CLI help for this sub-command.
description = (
    'tool box for sevennet checkpoints'
)
def add_parser(subparsers):
    """Register the 'checkpoint' sub-command (alias 'cp')."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Define the checkpoint-tool arguments on *parser*.

    --get_yaml and --append_modal_yaml are mutually exclusive actions;
    --original_modal_name only matters together with --append_modal_yaml.
    """
    parser.add_argument('checkpoint', type=str, help='checkpoint or pretrained')
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        type=str,
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        type=str,
        help='append modality with given yaml.',
    )
    parser.add_argument(
        '--original_modal_name',
        type=str,
        default='origin',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            'used to name previously trained modality. defaults to "origin"'
        ),
    )
def run(args):
    """Execute the checkpoint tool on parsed *args*.

    Exactly one action runs:
    - --get_yaml: print an input.yaml reconstructed from the checkpoint
    - --append_modal_yaml: append a modality and save a new checkpoint file
    - neither: print a summary of the checkpoint
    """
    # Heavy imports deferred so argument parsing stays fast.
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # Fixed typo in user-facing hint: 'blow' -> 'below'.
        print(f'update continue of {dst_yaml} as below (recommend) to continue')
        # Recommended `continue` section for the user's yaml.
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import os.path as osp
from sevenn import __version__
# One-line summary shown in CLI help for the `checkpoint` sub-command.
description = 'tool box for sevennet checkpoints'
def add_parser(subparsers):
    """Register the `checkpoint` sub-command (alias: `cp`) on *subparsers*."""
    sub = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(sub)
def add_args(parser):
    """Attach `checkpoint` sub-command options to *parser*."""
    parser.add_argument('checkpoint', type=str, help='checkpoint or pretrained')
    # --get_yaml and --append_modal_yaml cannot be combined
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        type=str,
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        type=str,
        help='append modality with given yaml.',
    )
    parser.add_argument(
        '--original_modal_name',
        type=str,
        default='origin',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            'used to name previously trained modality. defaults to "origin"'
        ),
    )
def run(args):
    """Entry point of the `checkpoint` sub-command.

    Depending on the flags: dump a yaml that reproduces/continues the
    checkpoint, append a new modality to the checkpoint, or print a
    human-readable summary of the checkpoint.
    """
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        # dump an input.yaml for the requested mode
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        # write a new checkpoint whose config/state include the added modality
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # fixed typo in user-facing message: 'as blow' -> 'as below'
        print(f'update continue of {dst_yaml} as below (recommend) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        # no option given: just print the checkpoint summary
        print(checkpoint)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import os
from sevenn import __version__
# CLI help strings for the `get_model` (deploy) sub-command.
description_get_model = 'deploy LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the `get_model` sub-command (alias: `deploy`)."""
    sub = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(sub)
def add_args(parser):
    """Attach deploy options to *parser*."""
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', type=str, help=output_name_help
    )
    parser.add_argument(
        '-p', '--get_parallel', action='store_true', help=get_parallel_help
    )
    parser.add_argument(
        '-m', '--modal', type=str, help='Modality of multi-modal model'
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from *args.checkpoint*."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    parallel = args.get_parallel
    prefix = args.output_prefix
    if prefix is None:
        prefix = 'deployed_parallel' if parallel else 'deployed_serial'
    # accept either a path to a checkpoint file or a pretrained-model name
    if os.path.isfile(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        ckpt_path = sevenn.util.pretrained_name_to_path(args.checkpoint)
    deploy_fn = deploy_parallel if parallel else deploy
    deploy_fn(ckpt_path, prefix, args.modal)
# legacy way (sevenn_get_model entry point)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:].

    The optional *args* parameter makes the signature consistent with the
    other sub-command `main` functions and allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args(args))
import argparse
import os
from sevenn import __version__
# CLI help strings for the `get_model` (deploy) sub-command.
description_get_model = 'deploy LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the `get_model` sub-command (alias: `deploy`)."""
    sub = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(sub)
def add_args(parser):
    """Attach deploy options to *parser*."""
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', type=str, help=output_name_help
    )
    parser.add_argument(
        '-p', '--get_parallel', action='store_true', help=get_parallel_help
    )
    parser.add_argument(
        '-m', '--modal', type=str, help='Modality of multi-modal model'
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from *args.checkpoint*."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    parallel = args.get_parallel
    prefix = args.output_prefix
    if prefix is None:
        prefix = 'deployed_parallel' if parallel else 'deployed_serial'
    # accept either a path to a checkpoint file or a pretrained-model name
    if os.path.isfile(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        ckpt_path = sevenn.util.pretrained_name_to_path(args.checkpoint)
    deploy_fn = deploy_parallel if parallel else deploy
    deploy_fn(ckpt_path, prefix, args.modal)
# legacy way (sevenn_get_model entry point)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:].

    The optional *args* parameter makes the signature consistent with the
    other sub-command `main` functions and allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args(args))
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# CLI help strings for the `graph_build` sub-command.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    'The dataset will be written under "sevenn_data", '
    'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the `graph_build` sub-command."""
    sub = subparsers.add_parser('graph_build', help=description)
    add_args(sub)
def add_args(parser):
    """Attach graph-build options to *parser*."""
    parser.add_argument('source', type=str, help=source_help)
    parser.add_argument('cutoff', type=float, help=cutoff_help)
    parser.add_argument(
        '-n', '--num_cores', type=int, default=1,
        help='number of cores to build graph in parallel',
    )
    parser.add_argument(
        '-o', '--out', type=str, default='./',
        help='Existing path to write outputs.',
    )
    parser.add_argument(
        '-f', '--filename', type=str, default='graph.pt', help=filename_help
    )
    parser.add_argument('--legacy', action='store_true', help=legacy_help)
    parser.add_argument(
        '-s', '--screen', action='store_true', help='print log to the screen'
    )
    parser.add_argument(
        '--kwargs', nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable sources and write it to disk.

    Raises:
        NotADirectoryError: if ``args.out`` does not exist.
        FileExistsError: if the output dataset file already exists.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split only on the first '=' so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)
    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # legacy path writes a single .sevenn_data file
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# CLI help strings for the `graph_build` sub-command.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    'The dataset will be written under "sevenn_data", '
    'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the `graph_build` sub-command."""
    sub = subparsers.add_parser('graph_build', help=description)
    add_args(sub)
def add_args(parser):
    """Attach graph-build options to *parser*."""
    parser.add_argument('source', type=str, help=source_help)
    parser.add_argument('cutoff', type=float, help=cutoff_help)
    parser.add_argument(
        '-n', '--num_cores', type=int, default=1,
        help='number of cores to build graph in parallel',
    )
    parser.add_argument(
        '-o', '--out', type=str, default='./',
        help='Existing path to write outputs.',
    )
    parser.add_argument(
        '-f', '--filename', type=str, default='graph.pt', help=filename_help
    )
    parser.add_argument('--legacy', action='store_true', help=legacy_help)
    parser.add_argument(
        '-s', '--screen', action='store_true', help='print log to the screen'
    )
    parser.add_argument(
        '--kwargs', nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable sources and write it to disk.

    Raises:
        NotADirectoryError: if ``args.out`` does not exist.
        FileExistsError: if the output dataset file already exists.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split only on the first '=' so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)
    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # legacy path writes a single .sevenn_data file
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import glob
import os
import sys
# CLI help strings for the `inference` sub-command.
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the `inference` sub-command (alias: `inf`)."""
    sub = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(sub)
def add_args(parser):
    """Attach inference options to *parser*."""
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument('targets', type=str, nargs='+', help=target_help)
    parser.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    parser.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    parser.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    # default was the string '4'; use a real int for this int-typed option
    parser.add_argument(
        '-b',
        '--batch',
        type=int,
        default=4,
        help='batch size, useful for GPU',
    )
    parser.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    parser.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    parser.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference with a checkpoint over the target structures.

    Raises:
        FileExistsError: if the output directory already exists.
        ValueError: if --save_graph and --allow_unlabeled are both given.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        # prefer GPU when available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split only on the first '=' so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')
    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import glob
import os
import sys
# CLI help strings for the `inference` sub-command.
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the `inference` sub-command (alias: `inf`)."""
    sub = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(sub)
def add_args(parser):
    """Attach inference options to *parser*."""
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument('targets', type=str, nargs='+', help=target_help)
    parser.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    parser.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    parser.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    # default was the string '4'; use a real int for this int-typed option
    parser.add_argument(
        '-b',
        '--batch',
        type=int,
        default=4,
        help='batch size, useful for GPU',
    )
    parser.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    parser.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    parser.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference with a checkpoint over the target structures.

    Raises:
        FileExistsError: if the output directory already exists.
        ValueError: if --save_graph and --allow_unlabeled are both given.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        # prefer GPU when available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split only on the first '=' so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')
    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import os
import subprocess
from sevenn import __version__
# Python wrapper of the patch_lammps.sh script.
# importlib.resources would be the correct way to locate package data,
# but its API changes too frequently to rely on.
# Directory shipping the pair-style sources and the patch script.
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the `patch_lammps` sub-command."""
    sub = subparsers.add_parser('patch_lammps', help=description)
    add_args(sub)
def add_args(parser):
    """Attach patch-lammps options to *parser*."""
    parser.add_argument('lammps_dir', type=str, help='Path to LAMMPS source')
    parser.add_argument('--d3', action='store_true', help='Enable D3 support')
    # cxx_standard is detected automatically
def run(args):
    """Invoke patch_lammps.sh on the given LAMMPS source tree.

    Returns the shell script's exit code.
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # pass an argv list (not a split string) so paths containing spaces survive
    cmd = [script, lammps_dir, cxx_standard, d3_support]
    res = subprocess.run(cmd)
    return res.returncode  # is it meaningless?
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))


if __name__ == '__main__':
    main()
import argparse
import os
import subprocess
from sevenn import __version__
# Python wrapper of the patch_lammps.sh script.
# importlib.resources would be the correct way to locate package data,
# but its API changes too frequently to rely on.
# Directory shipping the pair-style sources and the patch script.
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the `patch_lammps` sub-command."""
    sub = subparsers.add_parser('patch_lammps', help=description)
    add_args(sub)
def add_args(parser):
    """Attach patch-lammps options to *parser*."""
    parser.add_argument('lammps_dir', type=str, help='Path to LAMMPS source')
    parser.add_argument('--d3', action='store_true', help='Enable D3 support')
    # cxx_standard is detected automatically
def run(args):
    """Invoke patch_lammps.sh on the given LAMMPS source tree.

    Returns the shell script's exit code.
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # pass an argv list (not a split string) so paths containing spaces survive
    cmd = [script, lammps_dir, cxx_standard, d3_support]
    res = subprocess.run(cmd)
    return res.returncode  # is it meaningless?
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))


if __name__ == '__main__':
    main()
import argparse
import os
from sevenn import __version__
# CLI help strings for the `preset` sub-command.
description = (
    'print the selected preset for training. '
    'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the `preset` sub-command."""
    sub = subparsers.add_parser('preset', help=description)
    add_args(sub)
def add_args(parser):
    """Attach the preset-name positional argument to *parser*."""
    # names must match the yaml files shipped under sevenn/presets
    choices = [
        'fine_tune',
        'fine_tune_le',
        'sevennet-0',
        'sevennet-l3i5',
        'base',
        'multi_modal',
    ]
    parser.add_argument('preset', choices=choices, help=preset_help)
def run(args):
    """Print the chosen preset yaml to stdout."""
    preset_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    with open(os.path.join(preset_dir, f'{args.preset}.yaml'), 'r') as f:
        print(f.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import argparse
import os
from sevenn import __version__
# CLI help strings for the `preset` sub-command.
description = (
    'print the selected preset for training. '
    'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the `preset` sub-command."""
    sub = subparsers.add_parser('preset', help=description)
    add_args(sub)
def add_args(parser):
    """Attach the preset-name positional argument to *parser*."""
    # names must match the yaml files shipped under sevenn/presets
    choices = [
        'fine_tune',
        'fine_tune_le',
        'sevennet-0',
        'sevennet-l3i5',
        'base',
        'multi_modal',
    ]
    parser.add_argument('preset', choices=choices, help=preset_help)
def run(args):
    """Print the chosen preset yaml to stdout."""
    preset_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    with open(os.path.join(preset_dir, f'{args.preset}.yaml'), 'r') as f:
        print(f.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """Stand-alone entry point; *args* defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # forward *args* so the CLI is usable programmatically (it was ignored before)
    run(parser.parse_args(args))
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
# (the two adjacent string literals below concatenate into the full message)
warnings.filterwarnings(
    'ignore',
    message=(
        "The TorchScript type system doesn't " 'support instance-level annotations'
    ),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
def init_self_connection(config):
    """Resolve the self-connection (intro, outro) class pair for each layer.

    Returns a list, one entry per convolution layer, whose entries are either
    None ('none') or an (intro_cls, outro_cls) tuple.
    """
    sc_types = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    if isinstance(sc_types, str):
        # a single string applies to every convolution layer
        sc_types = [sc_types] * num_conv
    known = {
        'none': None,
        'nequip': (SelfConnectionIntro, SelfConnectionOutro),
        'linear': (SelfConnectionLinearIntro, SelfConnectionOutro),
    }
    io_pair_list = []
    for sc_type in sc_types:
        if sc_type not in known:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        io_pair_list.append(known[sc_type])
    return io_pair_list
def init_edge_embedding(config):
    """Build the EdgeEmbedding module (radial basis + cutoff + spherical).

    The cutoff length is injected into both the radial-basis and the
    cutoff-function sub-configs before instantiating them.
    """
    _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}
    rbf, env, sph = None, None, None
    # radial basis function; 'bessel' is the only name handled here
    rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS])
    rbf_dct.update(_cutoff_param)
    rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME)
    if rbf_name == 'bessel':
        rbf = BesselBasis(**rbf_dct)
    # envelope (cutoff) function: polynomial or XPLOR-style
    envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelop_dct.update(_cutoff_param)
    envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME)
    if envelop_name == 'poly_cut':
        env = PolynomialCutoff(**envelop_dct)
    elif envelop_name == 'XPLOR':
        env = XPLORCutoff(**envelop_dct)
    # edge lmax falls back to the node lmax unless explicitly overridden (> 0)
    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    _normalize_sph = config[KEY._NORMALIZE_SPH]
    sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph)
    return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph)
def init_feature_reduce(config, irreps_x):
    """Build the readout layers reducing node features to a scalar energy.

    Either two IrrepsLinear layers (feature -> hidden -> energy) or a single
    equivariant FCN, depending on READOUT_AS_FCN.
    """
    # features per node to scalar per node
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        # linear readout: half-width scalar hidden irreps, then 1x0e energy
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers.update(
            {
                'reduce_input_to_hidden': IrrepsLinear(
                    irreps_x,
                    hidden_irreps,
                    data_key_in=KEY.NODE_FEATURE,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
                'reduce_hidden_to_energy': IrrepsLinear(
                    hidden_irreps,
                    Irreps([(1, (0, 1))]),
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
            }
        )
    else:
        # FCN readout with the configured activation and hidden widths
        act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]]
        hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS]
        layers.update(
            {
                'readout_FCN': FCN_e3nn(
                    dim_out=1,
                    hidden_neurons=hidden_neurons,
                    activation=act,
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    irreps_in=irreps_x,
                )
            }
        )
    return layers
def init_shift_scale(config):
    """Build the rescale module for atomic energies from config shift/scale.

    Returns a ModalWiseRescale, Rescale or SpeciesWiseRescale module depending
    on whether modality is used and whether shift/scale are scalars or lists.
    """
    # for mm, ex, shift: modal_idx -> shifts
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):  # numpy or torch
            s = s.tolist()
        if isinstance(s, dict):
            s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()}
        if isinstance(s, list) and len(s) == 1:
            # a singleton list collapses to its scalar value
            s = s[0]
        shift_scale.append(s)
    shift, scale = shift_scale
    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale)
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError('shift, scale should be list of float or float')
    return rescale_module
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        # single-modal model: nothing to patch
        return layers
    _layers = list(layers.items())
    # (1) one-hot encode the modality right after the species one-hot layer
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)
    num_modal = config[KEY.NUM_MODALITIES]
    # (2) make the configured subset of IrrepsLinear layers modality-aware
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        if (
            (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x'))
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden')
        ):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Swap e3nn-based modules for cuEquivariance-backed ones when possible.

    Returns *layers* unchanged when cuEquivariance is disabled in the config,
    not installed, or not applicable to this model.
    """
    import sevenn.nn.cue_helper as cue_helper
    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Fallback to original code.'
            )
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers
    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    # remaining cue config entries override the default layout
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)
    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    layers.update(updates)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply post-build patches (multimodal first, then cuEquivariance)."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split the serial layer stack into per-stage stacks for parallel (LAMMPS) use.

    Extra layers embedding ghost-atom features are inserted, then the stack is
    cut after each `{i}_self_interaction_1` boundary (presumably the points
    where inter-rank communication happens — confirm against pair_e3gnn).
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    def slice_until_this(module_name, layers):
        # split *layers* right after the entry named *module_name*
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain
    # one-hot embedding for ghost atoms, mirroring the real-atom path
    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# typing overloads: parallel=False -> one model, parallel=True -> list of models
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """Assemble the E3-equivariant model (or its parallel stages) from *config*.

    output shapes (w/o batch)
        PRED_TOTAL_ENERGY: (),
        ATOMIC_ENERGY: (natoms, 1),  # intended
        PRED_FORCE: (natoms, 3),
        PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]
    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)
    # optional explicit per-layer irreps; must list num_conv + 1 entries
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')
    # a scalar denominator is broadcast to every convolution layer
    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})
    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    # species one-hot followed by the first linear embedding
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )
    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden
    # parameters shared by every interaction block
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }
    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')
    for t in range(num_convolution_layer):
        # per-layer parameters: input irreps and self-connection pair
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            # the last layer emits even scalars only (energy readout follows)
            if t == num_convolution_layer - 1:
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out
    # readout, rescale and total-energy reduction
    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})
    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }
    if parallel:
        # one AtomGraphSequential per communication stage
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
'ignore',
message=(
"The TorchScript type system doesn't " 'support instance-level annotations'
),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
def init_self_connection(config):
    """Resolve the self-connection (intro, outro) class pair per conv layer.

    Returns a list of length NUM_CONVOLUTION; each entry is either None
    ('none') or a (intro_cls, outro_cls) tuple for that layer.
    """
    sc_types = config[KEY.SELF_CONNECTION_TYPE]
    n_layers = config[KEY.NUM_CONVOLUTION]
    # a single string applies to every convolution layer
    if isinstance(sc_types, str):
        sc_types = [sc_types] * n_layers
    pair_by_name = {
        'none': None,
        'nequip': (SelfConnectionIntro, SelfConnectionOutro),
        'linear': (SelfConnectionLinearIntro, SelfConnectionOutro),
    }
    pairs = []
    for name in sc_types:
        if name not in pair_by_name:
            raise ValueError(f'Unknown self_connection_type found: {name}')
        pairs.append(pair_by_name[name])
    return pairs
def init_edge_embedding(config):
    """Build the EdgeEmbedding (radial basis + cutoff envelope + spherical).

    Unrecognized basis / cutoff-function names leave the corresponding
    module as None, which is passed through to EdgeEmbedding as-is.
    """
    cutoff_kwargs = {'cutoff_length': config[KEY.CUTOFF]}

    rbf = None
    basis_params = copy.deepcopy(config[KEY.RADIAL_BASIS])
    basis_params.update(cutoff_kwargs)
    if basis_params.pop(KEY.RADIAL_BASIS_NAME) == 'bessel':
        rbf = BesselBasis(**basis_params)

    env = None
    envelope_params = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelope_params.update(cutoff_kwargs)
    envelope_name = envelope_params.pop(KEY.CUTOFF_FUNCTION_NAME)
    if envelope_name == 'poly_cut':
        env = PolynomialCutoff(**envelope_params)
    elif envelope_name == 'XPLOR':
        env = XPLORCutoff(**envelope_params)

    # LMAX_EDGE > 0 explicitly overrides the shared LMAX for edges
    lmax_edge = (
        config[KEY.LMAX_EDGE] if config[KEY.LMAX_EDGE] > 0 else config[KEY.LMAX]
    )
    parity = -1 if config[KEY.IS_PARITY] else 1
    sph = SphericalEncoding(
        lmax_edge, parity, normalize=config[KEY._NORMALIZE_SPH]
    )
    return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph)
def init_feature_reduce(config, irreps_x):
    """Build readout layers mapping per-node features to one scalar per node.

    Writes the result to KEY.SCALED_ATOMIC_ENERGY. Depending on
    KEY.READOUT_AS_FCN this is either a two-step linear reduction or a
    single equivariant FCN.
    """
    # strict identity check: only an explicit False selects the linear path
    if config[KEY.READOUT_AS_FCN] is False:
        bias = config[KEY.USE_BIAS_IN_LINEAR]
        # half-width scalar bottleneck between features and energy
        hidden = Irreps([(irreps_x.dim // 2, (0, 1))])
        to_hidden = IrrepsLinear(
            irreps_x,
            hidden,
            data_key_in=KEY.NODE_FEATURE,
            biases=bias,
        )
        to_energy = IrrepsLinear(
            hidden,
            Irreps([(1, (0, 1))]),
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            biases=bias,
        )
        return OrderedDict(
            reduce_input_to_hidden=to_hidden,
            reduce_hidden_to_energy=to_energy,
        )
    readout = FCN_e3nn(
        dim_out=1,
        hidden_neurons=config[KEY.READOUT_FCN_HIDDEN_NEURONS],
        activation=_const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]],
        data_key_in=KEY.NODE_FEATURE,
        data_key_out=KEY.SCALED_ATOMIC_ENERGY,
        irreps_in=irreps_x,
    )
    return OrderedDict(readout_FCN=readout)
def init_shift_scale(config):
    """Build the atomic-energy rescale module from config SHIFT/SCALE.

    Returns ModalWiseRescale (multimodal), Rescale (both scalars) or
    SpeciesWiseRescale (either is a list); raises ValueError otherwise.
    """

    def _as_plain(value):
        # numpy / torch values expose .tolist(); convert to plain python
        if hasattr(value, 'tolist'):
            value = value.tolist()
        if isinstance(value, dict):
            value = {
                k: v.tolist() if hasattr(v, 'tolist') else v
                for k, v in value.items()
            }
        # unwrap a single-element list to its scalar
        if isinstance(value, list) and len(value) == 1:
            value = value[0]
        return value

    shift = _as_plain(config[KEY.SHIFT])
    scale = _as_plain(config[KEY.SCALE])
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]

    if config.get(KEY.USE_MODALITY, False):
        return ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    if isinstance(shift, float) and isinstance(scale, float):
        return Rescale(shift, scale, train_shift_scale=train_shift_scale)
    if isinstance(shift, list) or isinstance(scale, list):
        return SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    raise ValueError('shift, scale should be list of float or float')
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    if not config.get(KEY.USE_MODALITY, False):
        return layers

    num_modal = config[KEY.NUM_MODALITIES]
    modal_onehot = OnehotEmbedding(
        num_classes=num_modal,
        data_key_x=KEY.MODAL_TYPE,
        data_key_out=KEY.MODAL_ATTR,
        data_key_save=None,
        data_key_additional=None,
    )
    layers = OrderedDict(
        _insert_after(
            'onehot_idx_to_onehot',
            ('one_hot_modality', modal_onehot),
            list(layers.items()),
        )
    )

    def _wants_modality(name):
        # each group of layers is opted in by its own config flag
        if config[KEY.USE_MODAL_NODE_EMBEDDING] and name.endswith(
            'onehot_to_feature_x'
        ):
            return True
        if config[KEY.USE_MODAL_SELF_INTER_INTRO] and name.endswith(
            'self_interaction_1'
        ):
            return True
        if config[KEY.USE_MODAL_SELF_INTER_OUTRO] and name.endswith(
            'self_interaction_2'
        ):
            return True
        return (
            config[KEY.USE_MODAL_OUTPUT_BLOCK] and name == 'reduce_input_to_hidden'
        )

    for name, module in layers.items():
        if isinstance(module, IrrepsLinear) and _wants_modality(name):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Replace supported modules with cuEquivariance-backed versions.

    No-op (returns ``layers`` unchanged) when cuEquivariance is disabled,
    not installed, or not usable for this model on CUDA.
    """
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            'cuEquivariance is requested, but the package is not installed. '
            'Fallback to original code.'
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers

    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    # user-provided cue config may override the default layout
    module_params = {'layout': 'mul_ir', **cue_cfg}
    patched = {}
    for name, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if name == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            patched[name] = cue_helper.patch_linear(
                module, group, **module_params
            )
        elif isinstance(module, SelfConnectionIntro):
            patched[name] = cue_helper.patch_fully_connected(
                module, group, **module_params
            )
        elif isinstance(module, IrrepsConvolution):
            patched[name] = cue_helper.patch_convolution(
                module, group, **module_params
            )
    layers.update(patched)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply post-build patches in order: modality first, then cuEquivariance."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split the serial layer dict into sub-models for parallel runs.

    Ghost-atom embedding layers are inserted, then the sequence is cut just
    after each ``{i}_self_interaction_1`` for i >= 1 — presumably the points
    where ghost features are exchanged between ranks (TODO confirm).
    Returns a list of OrderedDicts, one per stage; the final stage has
    'force_output' removed because forces are computed in LAMMPS.
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]

    def slice_until_this(module_name, layers):
        # Split `layers` just after the entry named `module_name`.
        # If absent (idx stays -1), first_to is empty and remain is everything.
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain

    # one-hot embedding for ghost atoms, mirroring 'onehot_idx_to_onehot'
    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    # ghost counterpart of 'onehot_to_feature_x' (same irreps, ghost key)
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # ghost counterpart of the first layer's self_interaction_1
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
        # each later stage re-embeds edges to retain position gradients
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# Typing-only overloads: parallel=False returns a single model, while
# parallel=True returns a list of stage models (see _to_parallel_model).
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """
    Build the E3-equivariant model described by the flat ``config`` dict.

    Returns one AtomGraphSequential, or a list of them (one per parallel
    stage, see _to_parallel_model) when ``parallel`` is True.

    output shapes (w/o batch)
    PRED_TOTAL_ENERGY: (),
    ATOMIC_ENERGY: (natoms, 1),  # intended
    PRED_FORCE: (natoms, 3),
    PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    # LMAX_NODE > 0 explicitly overrides the shared LMAX for node features
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]
    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)
    # optional explicit irreps per layer: needs one entry per conv layer
    # plus one for the input features
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')
    # broadcast a scalar denominator to every convolution layer
    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    # NOTE(review): KEY.TRAIN_DENOMINTAOR key-name typo comes from sevenn._keys
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})
    one_hot_irreps = Irreps(f'{num_species}x0e')
    # initial node features are scalars only, unless given manually
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )
    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    # weight NN input width equals the number of radial basis functions
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden
    # arguments shared by every interaction block; per-layer entries are
    # updated inside the loop below
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }
    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        # NOTE(review): KEY.ACTIVATION_SCARLAR key-name typo is in sevenn._keys
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
        if interaction_type == 'nequip':
            interaction_builder = NequIP_interaction_block
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            # last layer keeps only even scalars (lmax 0) for the readout
            if t == num_convolution_layer - 1:
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out
    layers.update(init_feature_reduce(config, irreps_x))
    # NOTE(review): 'reduce_total_enegy' (sic) is a persisted layer name —
    # renaming it would likely break loading of existing checkpoints; confirm
    # before fixing the typo.
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})
    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        # type/modal map evaluation is skipped in the parallel path
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }
    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment