Unverified Commit 273418ee authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Merge pull request #1 from hjhk258/main

Fix
parents 7fb73825 9d7b4f63
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
# Map atomic number -> chemical symbol (inverse of ase.data.atomic_numbers).
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that caches one instance per class.

    The first instantiation builds and stores the object; every later call
    returns that same stored object.
    """

    _instances = {}  # class -> its unique instance

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
class Logger(metaclass=Singleton):
    """Process-wide logger (singleton) for SevenNet training runs.

    Only rank 0 writes anything; other ranks keep the same API but discard
    all output, so distributed code may log unconditionally. Use as a
    context manager: the log file is opened in ``__enter__`` and closed in
    ``__exit__``.
    """

    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # The log file is opened lazily in __enter__, not here.
            self.logfile = None
            self.files = {}  # filename -> handle, for the deprecated csv API
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # timer name -> start datetime
        self.active = True  # when False, write() becomes a no-op

    def __enter__(self):
        """Open the log file for appending (rank 0 only) and return self."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                # Logging must never abort a run; fall back to screen only.
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close every open file handle.

        Returns None (falsy) so exceptions raised inside the ``with`` block
        always propagate. BUGFIX: the previous implementation returned
        ``self`` for non-zero ranks, which is truthy and therefore silently
        suppressed any in-flight exception on those ranks.
        """
        if self.rank != 0:
            return None
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        """Point the logger at a new file; takes effect at the next __enter__."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        """Write content verbatim (no newline appended) to file and/or screen."""
        if self.rank != 0:
            return
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        """Write content followed by a newline."""
        self.write(content + '\n')

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(
                filename, 'w', buffering=1, encoding='utf-8'
            )
            self.files[filename].write(','.join(header) + '\n')

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank != 0:
            return
        if filename not in self.files:
            # encoding added for consistency with init_csv
            self.files[filename] = open(
                filename, 'a', buffering=1, encoding='utf-8'
            )
        str_content = []
        for c in content:
            if isinstance(c, float):
                str_content.append(f'{c:.{decimal}f}')
            else:
                str_content.append(str(c))
        self.files[filename].write(','.join(str_content) + '\n')

    def natoms_write(self, natoms: Dict[str, Dict]):
        """Log per-label atom counts plus label-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                total_natom[specie] = total_natom.get(specie, 0) + num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        """Log statistics dicts; entries whose key starts with '_' are skipped."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_') or not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                # ints print as-is, everything else with three decimals
                dct_new[k] = v if isinstance(v, int) else f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """Log per-species train/valid force errors, one row per species.

        The E and S columns are filled with dashes (presumably energy and
        stress, which are not defined per species — TODO confirm).
        """
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs
        content = ''
        for at in train_loss.keys():
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator (loop variables renamed to `width`
        # to avoid shadowing the `pad` parameter)
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(width)
            for col_name, width in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * width for width in max_col_lengths) + '-' * (
            label_len + pad
        )
        self.writeline(header)
        self.writeline(separator)
        # Data rows, right-justified under each column, prefixed by row label
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(width) for value, width in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """Format ``key: val`` with the key padded; wrap long values.

        key and val should be str convertible. Returns the formatted string
        when write is False, otherwise writes it and returns ''.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            # Too long for one line: rebuild the value piecewise, breaking
            # at the separator and continuing with '\' + aligned padding.
            content = f'{key:<{MAX_KEY_SIZE}}: '
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        """Write the version banner followed by the ASCII logo."""
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        """Write a horizontal rule spanning the screen width."""
        self.write('-' * Logger.SCREEN_WIDTH + '\n')

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception):
        """Log an exception: message only for exactly ValueError, otherwise
        the full traceback."""
        content = ''
        # NOTE: exact-type check; ValueError subclasses get the traceback path
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        """Start (or restart) a named wall-clock timer."""
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        if remove:
            del self.timer_dct[name]
        # [:-4] trims the 6-digit microseconds field down to two digits
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """Log irreps of model features and the learnable-parameter count."""
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional
from ase.data import atomic_numbers
import sevenn._keys as KEY
from sevenn import __version__
# Map atomic number -> chemical symbol (inverse of ase.data.atomic_numbers).
CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}
class Singleton(type):
    """Metaclass that caches one instance per class.

    The first instantiation builds and stores the object; every later call
    returns that same stored object.
    """

    _instances = {}  # class -> its unique instance

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
class Logger(metaclass=Singleton):
    """Process-wide logger (singleton) for SevenNet training runs.

    Only rank 0 writes anything; other ranks keep the same API but discard
    all output, so distributed code may log unconditionally. Use as a
    context manager: the log file is opened in ``__enter__`` and closed in
    ``__exit__``.
    """

    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # The log file is opened lazily in __enter__, not here.
            self.logfile = None
            self.files = {}  # filename -> handle, for the deprecated csv API
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}  # timer name -> start datetime
        self.active = True  # when False, write() becomes a no-op

    def __enter__(self):
        """Open the log file for appending (rank 0 only) and return self."""
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                # Logging must never abort a run; fall back to screen only.
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close every open file handle.

        Returns None (falsy) so exceptions raised inside the ``with`` block
        always propagate. BUGFIX: the previous implementation returned
        ``self`` for non-zero ranks, which is truthy and therefore silently
        suppressed any in-flight exception on those ranks.
        """
        if self.rank != 0:
            return None
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        """Point the logger at a new file; takes effect at the next __enter__."""
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        """Write content verbatim (no newline appended) to file and/or screen."""
        if self.rank != 0:
            return
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        """Write content followed by a newline."""
        self.write(content + '\n')

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(
                filename, 'w', buffering=1, encoding='utf-8'
            )
            self.files[filename].write(','.join(header) + '\n')

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank != 0:
            return
        if filename not in self.files:
            # encoding added for consistency with init_csv
            self.files[filename] = open(
                filename, 'a', buffering=1, encoding='utf-8'
            )
        str_content = []
        for c in content:
            if isinstance(c, float):
                str_content.append(f'{c:.{decimal}f}')
            else:
                str_content.append(str(c))
        self.files[filename].write(','.join(str_content) + '\n')

    def natoms_write(self, natoms: Dict[str, Dict]):
        """Log per-label atom counts plus label-wise and grand totals."""
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                total_natom[specie] = total_natom.get(specie, 0) + num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        """Log statistics dicts; entries whose key starts with '_' are skipped."""
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_') or not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                # ints print as-is, everything else with three decimals
                dct_new[k] = v if isinstance(v, int) else f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """Log per-species train/valid force errors, one row per species.

        The E and S columns are filled with dashes (presumably energy and
        stress, which are not defined per species — TODO confirm).
        """
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs
        content = ''
        for at in train_loss.keys():
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator (loop variables renamed to `width`
        # to avoid shadowing the `pad` parameter)
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(width)
            for col_name, width in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * width for width in max_col_lengths) + '-' * (
            label_len + pad
        )
        self.writeline(header)
        self.writeline(separator)
        # Data rows, right-justified under each column, prefixed by row label
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(width) for value, width in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """Format ``key: val`` with the key padded; wrap long values.

        key and val should be str convertible. Returns the formatted string
        when write is False, otherwise writes it and returns ''.
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            # Too long for one line: rebuild the value piecewise, breaking
            # at the separator and continuing with '\' + aligned padding.
            content = f'{key:<{MAX_KEY_SIZE}}: '
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        """Write the version banner followed by the ASCII logo."""
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        """Write a horizontal rule spanning the screen width."""
        self.write('-' * Logger.SCREEN_WIDTH + '\n')

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception):
        """Log an exception: message only for exactly ValueError, otherwise
        the full traceback."""
        content = ''
        # NOTE: exact-type check; ValueError subclasses get the traceback path
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        """Start (or restart) a named wall-clock timer."""
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        if remove:
            del self.timer_dct[name]
        # [:-4] trims the 6-digit microseconds field down to two digits
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """Log irreps of model features and the learnable-parameter count."""
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
import argparse
import os
import sys
import time
from sevenn import __version__
# CLI help strings for the `train` sub-command
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Reads the yaml config, sets up the (optionally distributed) runtime,
    seeds RNGs, and dispatches to the selected training entry point.
    """
    # Heavy imports kept local so they only run when training actually starts
    import random
    import sys

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq
    if use_cue:
        # Fail early if cuEquivariance was requested but is not importable
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')
    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)
    world_size = 1
    if distributed:
        # Rank/world-size come from the launcher's environment variables
        # (torchrun-style for nccl, Open MPI variables for mpi)
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1
    # unique_filepath avoids clobbering logs from previous runs
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)
        # Record the runtime/distributed context into the train config
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))
        if use_cue:
            # Force-enable cuEquivariance in the model config
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})
        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)
        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)
        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)
        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach the `train` sub-command arguments to the given parser."""
    parser.add_argument('input_yaml', help=input_yaml_help, type=str)
    parser.add_argument(
        '-m',
        '--mode',
        choices=['train_v1', 'train_v2'],
        default='train_v2',
        help=mode_help,
        type=str,
    )
    parser.add_argument(
        '-cueq',
        '--enable_cueq',
        help='(Not stable!) use cuEquivariance for training',
        action='store_true',
    )
    parser.add_argument(
        '-w',
        '--working_dir',
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
        type=str,
    )
    parser.add_argument(
        '-l',
        '--log',
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
        type=str,
    )
    parser.add_argument('-s', '--screen', help=screen_help, action='store_true')
    parser.add_argument(
        '-d', '--distributed', help=distributed_help, action='store_true'
    )
    parser.add_argument(
        '--distributed_backend',
        help=distributed_backend_help,
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
    )
def add_parser(subparsers):
    """Register the `train` sub-command on the given subparsers object."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()
    Hack copied from stack overflow
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # No help flag on the command line: scan argv for any registered
        # subparser name. NOTE(review): relies on argparse private internals
        # (_subparsers, _SubParsersAction, _name_parser_map).
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
        if not subparser_found:
            # insert default in last position before global positional
            # arguments, this implies no global options are specified after
            # first positional argument
            if args is None:
                sys.argv.insert(len(sys.argv) - positional_args, name)
            else:
                args.insert(len(args) - positional_args, name)
# Monkey-patch ArgumentParser so a bare `sevenn input.yaml` (no sub-command)
# still routes to the 'train' subparser.
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """Entry point of the `sevenn` command line interface."""
    # Sub-command modules are imported here, when the CLI actually runs
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')

    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os
import sys
import time
from sevenn import __version__
# CLI help strings for the `train` sub-command
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Reads the yaml config, sets up the (optionally distributed) runtime,
    seeds RNGs, and dispatches to the selected training entry point.
    """
    # Heavy imports kept local so they only run when training actually starts
    import random
    import sys

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq
    if use_cue:
        # Fail early if cuEquivariance was requested but is not importable
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')
    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)
    world_size = 1
    if distributed:
        # Rank/world-size come from the launcher's environment variables
        # (torchrun-style for nccl, Open MPI variables for mpi)
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1
    # unique_filepath avoids clobbering logs from previous runs
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)
        # Record the runtime/distributed context into the train config
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))
        if use_cue:
            # Force-enable cuEquivariance in the model config
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})
        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)
        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)
        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)
        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach the `train` sub-command arguments to the given parser."""
    parser.add_argument('input_yaml', help=input_yaml_help, type=str)
    parser.add_argument(
        '-m',
        '--mode',
        choices=['train_v1', 'train_v2'],
        default='train_v2',
        help=mode_help,
        type=str,
    )
    parser.add_argument(
        '-cueq',
        '--enable_cueq',
        help='(Not stable!) use cuEquivariance for training',
        action='store_true',
    )
    parser.add_argument(
        '-w',
        '--working_dir',
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
        type=str,
    )
    parser.add_argument(
        '-l',
        '--log',
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
        type=str,
    )
    parser.add_argument('-s', '--screen', help=screen_help, action='store_true')
    parser.add_argument(
        '-d', '--distributed', help=distributed_help, action='store_true'
    )
    parser.add_argument(
        '--distributed_backend',
        help=distributed_backend_help,
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
    )
def add_parser(subparsers):
    """Register the `train` sub-command on the given subparsers object."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()
    Hack copied from stack overflow
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # No help flag on the command line: scan argv for any registered
        # subparser name. NOTE(review): relies on argparse private internals
        # (_subparsers, _SubParsersAction, _name_parser_map).
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
        if not subparser_found:
            # insert default in last position before global positional
            # arguments, this implies no global options are specified after
            # first positional argument
            if args is None:
                sys.argv.insert(len(sys.argv) - positional_args, name)
            else:
                args.insert(len(args) - positional_args, name)
# Monkey-patch ArgumentParser so a bare `sevenn input.yaml` (no sub-command)
# still routes to the 'train' subparser.
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """Entry point of the `sevenn` command line interface."""
    # Sub-command modules are imported here, when the CLI actually runs
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')

    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os.path as osp
from sevenn import __version__
# Help text shown for the `checkpoint` sub-command
description = (
    'tool box for sevennet checkpoints'
)
def add_parser(subparsers):
    """Register the `checkpoint` sub-command (alias `cp`)."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Attach checkpoint tool-box arguments to the given parser.

    --get_yaml and --append_modal_yaml are mutually exclusive.
    """
    parser.add_argument('checkpoint', help='checkpoint or pretrained', type=str)
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
        type=str,
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        help='append modality with given yaml.',
        type=str,
    )
    parser.add_argument(
        '--original_modal_name',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            + 'used to name previously trained modality. defaults to "origin"'
        ),
        default='origin',
        type=str,
    )
def run(args):
    """Execute the checkpoint tool box according to parsed arguments."""
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        # Dump an input.yaml reconstructed from the checkpoint
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        # Append a new modality described by the given yaml file, then save
        # an updated checkpoint next to the cwd.
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # NOTE(review): 'blow' in the printed message is a typo for 'below'
        print(f'update continue of {dst_yaml} as blow (recommend) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        # No action flag given: just print the checkpoint summary
        print(checkpoint)
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import os.path as osp
from sevenn import __version__
# Help text for the 'checkpoint' subcommand.
description = 'tool box for sevennet checkpoints'
def add_parser(subparsers):
    """Register the 'checkpoint' (alias 'cp') subcommand on *subparsers*."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Attach checkpoint-tool arguments to *parser* (mutated in place)."""
    parser.add_argument('checkpoint', help='checkpoint or pretrained', type=str)
    # --get_yaml and --append_modal_yaml are alternatives, never combined
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
        type=str,
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        help='append modality with given yaml.',
        type=str,
    )
    parser.add_argument(
        '--original_modal_name',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            'used to name previously trained modality. defaults to "origin"'
        ),
        default='origin',
        type=str,
    )
def run(args):
    """Execute the checkpoint tool.

    Depending on the flags: print a yaml reproducing/continuing the
    checkpoint (--get_yaml), append a new modality and save a patched
    checkpoint (--append_modal_yaml), or print the checkpoint summary.
    """
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # fixed typos in the user-facing hint: 'blow' -> 'below',
        # 'recommend' -> 'recommended'
        print(f'update continue of {dst_yaml} as below (recommended) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import os
from sevenn import __version__
# Help strings for the 'get_model' (deploy) subcommand.
description_get_model = 'deploy LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the 'get_model' (alias 'deploy') subcommand on *subparsers*."""
    deploy_parser = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(deploy_parser)
def add_args(parser):
    """Attach deploy CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('checkpoint', help=checkpoint_help, type=str)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', help=output_name_help, type=str
    )
    parser.add_argument(
        '-p', '--get_parallel', help=get_parallel_help, action='store_true'
    )
    parser.add_argument(
        '-m',
        '--modal',
        help='Modality of multi-modal model',
        type=str,
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from the checkpoint."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    checkpoint = args.checkpoint
    output_prefix = args.output_prefix
    if output_prefix is None:
        output_prefix = (
            'deployed_parallel' if args.get_parallel else 'deployed_serial'
        )
    # resolve a pretrained-model alias to an on-disk path when needed
    if os.path.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)
    if args.get_parallel:
        deploy_parallel(checkpoint_path, output_prefix, args.modal)
    else:
        deploy(checkpoint_path, output_prefix, args.modal)
# legacy way (standalone `sevenn_get_model` console script)
def main():
    """Entry point for the legacy console script."""
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
from sevenn import __version__
# Help strings for the 'get_model' (deploy) subcommand.
description_get_model = 'deploy LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the 'get_model' (alias 'deploy') subcommand on *subparsers*."""
    deploy_parser = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(deploy_parser)
def add_args(parser):
    """Attach deploy CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('checkpoint', help=checkpoint_help, type=str)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', help=output_name_help, type=str
    )
    parser.add_argument(
        '-p', '--get_parallel', help=get_parallel_help, action='store_true'
    )
    parser.add_argument(
        '-m',
        '--modal',
        help='Modality of multi-modal model',
        type=str,
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from the checkpoint."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    checkpoint = args.checkpoint
    output_prefix = args.output_prefix
    if output_prefix is None:
        output_prefix = (
            'deployed_parallel' if args.get_parallel else 'deployed_serial'
        )
    # resolve a pretrained-model alias to an on-disk path when needed
    if os.path.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)
    if args.get_parallel:
        deploy_parallel(checkpoint_path, output_prefix, args.modal)
    else:
        deploy(checkpoint_path, output_prefix, args.modal)
# legacy way (standalone `sevenn_get_model` console script)
def main():
    """Entry point for the legacy console script."""
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# Help strings for the 'graph_build' subcommand.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    'The dataset will be written under "sevenn_data", '
    'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the 'graph_build' subcommand on *subparsers*."""
    gb_parser = subparsers.add_parser('graph_build', help=description)
    add_args(gb_parser)
def add_args(parser):
    """Attach graph-build CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('source', help=source_help, type=str)
    parser.add_argument('cutoff', help=cutoff_help, type=float)
    parser.add_argument(
        '-n',
        '--num_cores',
        help='number of cores to build graph in parallel',
        default=1,
        type=int,
    )
    parser.add_argument(
        '-o',
        '--out',
        help='Existing path to write outputs.',
        type=str,
        default='./',
    )
    parser.add_argument(
        '-f',
        '--filename',
        help=filename_help,
        type=str,
        default='graph.pt',
    )
    parser.add_argument('--legacy', help=legacy_help, action='store_true')
    parser.add_argument(
        '-s',
        '--screen',
        help='print log to the screen',
        action='store_true',
    )
    # everything after --kwargs is forwarded verbatim to the reader
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable files.

    Globs args.source, builds graphs with the given cutoff, and writes
    either the new sevenn_data/*.pt format or the legacy .sevenn_data.

    Raises NotADirectoryError if args.out does not exist and
    FileExistsError if the target dataset file already exists.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split on the first '=' only, so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)

    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # legacy output path drops the extension of the requested filename
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# Help strings for the 'graph_build' subcommand.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    'The dataset will be written under "sevenn_data", '
    'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the 'graph_build' subcommand on *subparsers*."""
    gb_parser = subparsers.add_parser('graph_build', help=description)
    add_args(gb_parser)
def add_args(parser):
    """Attach graph-build CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('source', help=source_help, type=str)
    parser.add_argument('cutoff', help=cutoff_help, type=float)
    parser.add_argument(
        '-n',
        '--num_cores',
        help='number of cores to build graph in parallel',
        default=1,
        type=int,
    )
    parser.add_argument(
        '-o',
        '--out',
        help='Existing path to write outputs.',
        type=str,
        default='./',
    )
    parser.add_argument(
        '-f',
        '--filename',
        help=filename_help,
        type=str,
        default='graph.pt',
    )
    parser.add_argument('--legacy', help=legacy_help, action='store_true')
    parser.add_argument(
        '-s',
        '--screen',
        help='print log to the screen',
        action='store_true',
    )
    # everything after --kwargs is forwarded verbatim to the reader
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable files.

    Globs args.source, builds graphs with the given cutoff, and writes
    either the new sevenn_data/*.pt format or the legacy .sevenn_data.

    Raises NotADirectoryError if args.out does not exist and
    FileExistsError if the target dataset file already exists.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split on the first '=' only, so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)

    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # legacy output path drops the extension of the requested filename
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import glob
import os
import sys
# Help strings for the 'inference' subcommand.
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the 'inference' (alias 'inf') subcommand on *subparsers*."""
    inf_parser = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(inf_parser)
def add_args(parser):
    """Attach inference CLI arguments to *parser* (mutated in place)."""
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b',
        '--batch',
        type=int,
        # was the string '4'; that only worked via argparse's string-default
        # coercion — an int default is the idiomatic, explicit form
        default=4,
        help='batch size, useful for GPU'
    )
    ag.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data'
    )
    ag.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data'
    )
    ag.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    # everything after --kwargs is forwarded verbatim to the reader
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference with a checkpoint over the matched target files.

    Raises FileExistsError if the output directory already exists and
    ValueError if --save_graph and --allow_unlabeled are both given.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        # prefer GPU when available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split on the first '=' only, so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import glob
import os
import sys
# Help strings for the 'inference' subcommand.
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the 'inference' (alias 'inf') subcommand on *subparsers*."""
    inf_parser = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(inf_parser)
def add_args(parser):
    """Attach inference CLI arguments to *parser* (mutated in place)."""
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b',
        '--batch',
        type=int,
        # was the string '4'; that only worked via argparse's string-default
        # coercion — an int default is the idiomatic, explicit form
        default=4,
        help='batch size, useful for GPU'
    )
    ag.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data'
    )
    ag.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data'
    )
    ag.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    # everything after --kwargs is forwarded verbatim to the reader
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference with a checkpoint over the matched target files.

    Raises FileExistsError if the output directory already exists and
    ValueError if --save_graph and --allow_unlabeled are both given.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        # prefer GPU when available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # split on the first '=' only, so values may themselves contain '='
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import os
import subprocess
from sevenn import __version__
# Thin Python wrapper around the patch_lammps.sh shell script.
# importlib.resources would be the canonical way to locate packaged files,
# but its API changes too frequently; resolve relative to this file instead.
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the 'patch_lammps' subcommand on *subparsers*."""
    patch_parser = subparsers.add_parser('patch_lammps', help=description)
    add_args(patch_parser)
def add_args(parser):
    """Attach patch-lammps CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
    parser.add_argument('--d3', help='Enable D3 support', action='store_true')
    # cxx_standard is detected automatically by the patch script
def run(args):
    """Invoke patch_lammps.sh on the given LAMMPS source tree.

    Returns the shell script's return code (0 on success).
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # pass argv as a list (not a whitespace-split string) so paths
    # containing spaces survive intact
    res = subprocess.run([script, lammps_dir, cxx_standard, d3_support])
    return res.returncode  # is it meaningless?
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
if __name__ == '__main__':
    main()
import argparse
import os
import subprocess
from sevenn import __version__
# Thin Python wrapper around the patch_lammps.sh shell script.
# importlib.resources would be the canonical way to locate packaged files,
# but its API changes too frequently; resolve relative to this file instead.
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the 'patch_lammps' subcommand on *subparsers*."""
    patch_parser = subparsers.add_parser('patch_lammps', help=description)
    add_args(patch_parser)
def add_args(parser):
    """Attach patch-lammps CLI arguments to *parser* (mutated in place)."""
    parser.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
    parser.add_argument('--d3', help='Enable D3 support', action='store_true')
    # cxx_standard is detected automatically by the patch script
def run(args):
    """Invoke patch_lammps.sh on the given LAMMPS source tree.

    Returns the shell script's return code (0 on success).
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # pass argv as a list (not a whitespace-split string) so paths
    # containing spaces survive intact
    res = subprocess.run([script, lammps_dir, cxx_standard, d3_support])
    return res.returncode  # is it meaningless?
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
if __name__ == '__main__':
    main()
import argparse
import os
from sevenn import __version__
# Help strings for the 'preset' subcommand.
description = (
    'print the selected preset for training. '
    'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the 'preset' subcommand on *subparsers*."""
    preset_parser = subparsers.add_parser('preset', help=description)
    add_args(preset_parser)
def add_args(parser):
    """Attach the preset-name positional argument to *parser*."""
    available_presets = [
        'fine_tune',
        'fine_tune_le',
        'sevennet-0',
        'sevennet-l3i5',
        'base',
        'multi_modal',
    ]
    parser.add_argument('preset', choices=available_presets, help=preset_help)
def run(args):
    """Print the chosen preset yaml (shipped with the package) to stdout."""
    preset = args.preset
    prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    # explicit utf-8 avoids platform-default encoding surprises
    # (consistent with the package's log-file handling)
    with open(f'{prefix}/{preset}.yaml', 'r', encoding='utf-8') as f:
        print(f.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import argparse
import os
from sevenn import __version__
# Help strings for the 'preset' subcommand.
description = (
    'print the selected preset for training. '
    'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the 'preset' subcommand on *subparsers*."""
    preset_parser = subparsers.add_parser('preset', help=description)
    add_args(preset_parser)
def add_args(parser):
    """Attach the preset-name positional argument to *parser*."""
    available_presets = [
        'fine_tune',
        'fine_tune_le',
        'sevennet-0',
        'sevennet-l3i5',
        'base',
        'multi_modal',
    ]
    parser.add_argument('preset', choices=available_presets, help=preset_help)
def run(args):
    """Print the chosen preset yaml (shipped with the package) to stdout."""
    preset = args.preset
    prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    # explicit utf-8 avoids platform-default encoding surprises
    # (consistent with the package's log-file handling)
    with open(f'{prefix}/{preset}.yaml', 'r', encoding='utf-8') as f:
        print(f.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """CLI entry point; *args* (list of str) overrides sys.argv when given."""
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    # forward *args* so programmatic callers are honored (was silently ignored)
    run(ag.parse_args(args))
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# Silence a PyTorch TorchScript warning triggered by e3nn's
# instance-level type annotations.
warnings.filterwarnings(
    'ignore',
    message=(
        "The TorchScript type system doesn't "
        'support instance-level annotations'
    ),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
def init_self_connection(config):
    """Resolve per-layer self-connection module pairs from *config*.

    Returns a list of length NUM_CONVOLUTION; each entry is None ('none')
    or an (intro, outro) class pair. Raises ValueError on unknown types.
    """
    sc_types = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    if isinstance(sc_types, str):
        # a single string applies to every convolution layer
        sc_types = [sc_types] * num_conv
    pair_of = {
        'none': None,
        'nequip': (SelfConnectionIntro, SelfConnectionOutro),
        'linear': (SelfConnectionLinearIntro, SelfConnectionOutro),
    }
    io_pair_list = []
    for sc_type in sc_types:
        if sc_type not in pair_of:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        io_pair_list.append(pair_of[sc_type])
    return io_pair_list
def init_edge_embedding(config):
    """Build the EdgeEmbedding module from *config*.

    Combines a radial basis ('bessel'), a cutoff envelope
    ('poly_cut' or 'XPLOR'), and spherical harmonics encoding.
    """
    cutoff_kwargs = {'cutoff_length': config[KEY.CUTOFF]}

    basis_params = copy.deepcopy(config[KEY.RADIAL_BASIS])
    basis_params.update(cutoff_kwargs)
    basis = None
    if basis_params.pop(KEY.RADIAL_BASIS_NAME) == 'bessel':
        basis = BesselBasis(**basis_params)

    envelope_params = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelope_params.update(cutoff_kwargs)
    envelope_name = envelope_params.pop(KEY.CUTOFF_FUNCTION_NAME)
    envelope = None
    if envelope_name == 'poly_cut':
        envelope = PolynomialCutoff(**envelope_params)
    elif envelope_name == 'XPLOR':
        envelope = XPLORCutoff(**envelope_params)

    # LMAX_EDGE overrides LMAX when explicitly positive
    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    sph = SphericalEncoding(
        lmax_edge, parity, normalize=config[KEY._NORMALIZE_SPH]
    )
    return EdgeEmbedding(
        basis_module=basis, cutoff_module=envelope, spherical_module=sph
    )
def init_feature_reduce(config, irreps_x):
    """Map per-node features to one scalar (scaled atomic energy) per node.

    Uses either two IrrepsLinear layers (default) or a single e3nn FCN
    readout, depending on READOUT_AS_FCN.
    """
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        use_bias = config[KEY.USE_BIAS_IN_LINEAR]
        # halve the feature dimension, then project down to one scalar
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers['reduce_input_to_hidden'] = IrrepsLinear(
            irreps_x,
            hidden_irreps,
            data_key_in=KEY.NODE_FEATURE,
            biases=use_bias,
        )
        layers['reduce_hidden_to_energy'] = IrrepsLinear(
            hidden_irreps,
            Irreps([(1, (0, 1))]),
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            biases=use_bias,
        )
    else:
        layers['readout_FCN'] = FCN_e3nn(
            dim_out=1,
            hidden_neurons=config[KEY.READOUT_FCN_HIDDEN_NEURONS],
            activation=_const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]],
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            irreps_in=irreps_x,
        )
    return layers
def init_shift_scale(config):
    """Build the rescale module that converts scaled atomic energies back
    to physical units.

    Normalizes SHIFT/SCALE config values (tensor/array, dict, or
    single-element list) to plain Python scalars/lists, then picks the
    module: ModalWiseRescale when modality is enabled, Rescale for two
    plain floats, SpeciesWiseRescale when either value is per-species.
    Raises ValueError otherwise.
    """
    # for mm, ex, shift: modal_idx -> shifts
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):  # numpy or torch tensor -> python list
            s = s.tolist()
        if isinstance(s, dict):
            s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()}
        # unwrap a single-element list to a bare scalar
        if isinstance(s, list) and len(s) == 1:
            s = s[0]
        shift_scale.append(s)
    shift, scale = shift_scale
    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale)
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError('shift, scale should be list of float or float')
    return rescale_module
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess a 7net model into a multi-modal model. No-op when
    USE_MODALITY is falsy.

    1. prepend a modality one-hot embedding layer (right after the
       species one-hot layer)
    2. enable modalities on selected IrrepsLinear layers, gated by the
       USE_MODAL_* config flags

    Modality-aware shift/scale is handled by init_shift_scale, not here.
    Returns a (possibly new) OrderedDict of layers.
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        return layers
    _layers = list(layers.items())
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)
    num_modal = config[KEY.NUM_MODALITIES]
    # endswith() matches per-convolution keys like '0_self_interaction_1'
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        if (
            (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x'))
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden')
        ):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Replace supported layers with cuEquivariance-accelerated versions.

    No-op (returns *layers* unchanged) when cuEquivariance is not
    requested in the config, not installed, or not usable for this model.
    Patches IrrepsLinear / SelfConnectionLinearIntro, SelfConnectionIntro,
    and IrrepsConvolution modules in place; mutates and returns *layers*.
    """
    import sevenn.nn.cue_helper as cue_helper
    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        # requested but not installed: warn and keep the original modules
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Fallback to original code.'
            )
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers
    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    # remaining cue_cfg entries override the default layout
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)
    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    layers.update(updates)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply all post-build layer patches: modality first, then cuEquivariance."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split a serial layer dict into per-segment dicts for parallel (LAMMPS)
    execution.

    Inserts ghost-atom embedding layers, then cuts the pipeline at each
    convolution's first self-interaction so segments align with the
    communication points. Each later segment gets its own edge_embedding
    so position gradients are retained; the final 'force_output' layer is
    dropped (LAMMPS computes forces itself). Returns a list of
    OrderedDicts, one per segment.
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    # split helper: (entries up to and including *module_name*, the rest)
    def slice_until_this(module_name, layers):
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain
    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# typing overloads: parallel=False -> one model; parallel=True -> list of
# per-segment models (see _to_parallel_model)
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """Assemble the full E3-equivariant model from *config*.

    Pipeline: edge embedding -> species one-hot -> linear embedding ->
    NUM_CONVOLUTION interaction blocks -> feature readout -> shift/scale
    -> per-structure energy reduction -> force/stress from gradients.

    When *parallel* is True, returns a list of AtomGraphSequential
    segments for LAMMPS parallel execution (see _to_parallel_model);
    otherwise a single AtomGraphSequential.

    Output shapes (without batch dimension):
        PRED_TOTAL_ENERGY: (),
        ATOMIC_ENERGY: (natoms, 1),  # intended
        PRED_FORCE: (natoms, 3),
        PRED_STRESS: (6,),
    For data without cell volume, pred_stress holds garbage values.

    Raises RuntimeError for invalid IRREPS_MANUAL and ValueError for an
    unknown INTERACTION_TYPE.
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]
    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)
    # IRREPS_MANUAL, when given, must provide num_convolution_layer + 1
    # irreps (one per feature stage, embedding included)
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')
    # scalar denominator broadcasts to every convolution layer
    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})
    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )
    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden
    # parameters shared across all interaction blocks; per-layer entries
    # are updated inside the loop below
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }
    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')
    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            # the last layer keeps only even scalars (energy readout)
            if t == num_convolution_layer - 1:
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out
    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})
    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }
    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload
from e3nn.o3 import Irreps
import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util
from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
BesselBasis,
EdgeEmbedding,
PolynomialCutoff,
SphericalEncoding,
XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
SelfConnectionIntro,
SelfConnectionLinearIntro,
SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential
# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
'ignore',
message=(
"The TorchScript type system doesn't " 'support instance-level annotations'
),
)
def _insert_after(module_name_after, key_module_pair, layers):
idx = -1
for i, (key, _) in enumerate(layers):
if key == module_name_after:
idx = i
break
if idx == -1:
return layers # do nothing if not found
layers.insert(idx + 1, key_module_pair)
return layers
def init_self_connection(config):
    """Resolve the per-layer self-connection module pairs from config.

    A single string is broadcast to all convolution layers. Returns a list
    with one entry per layer: ``None`` for 'none', otherwise a tuple of
    (intro, outro) module classes.
    """
    sc_types = config[KEY.SELF_CONNECTION_TYPE]
    n_layers = config[KEY.NUM_CONVOLUTION]
    if isinstance(sc_types, str):
        sc_types = [sc_types] * n_layers
    known_pairs = {
        'none': None,
        'nequip': (SelfConnectionIntro, SelfConnectionOutro),
        'linear': (SelfConnectionLinearIntro, SelfConnectionOutro),
    }
    pairs = []
    for sc_type in sc_types:
        if sc_type not in known_pairs:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        pairs.append(known_pairs[sc_type])
    return pairs
def init_edge_embedding(config):
    """Build the EdgeEmbedding (radial basis + cutoff envelope + spherical encoding).

    Unrecognized basis/envelope names leave the corresponding sub-module as
    ``None`` (matching upstream behavior). LMAX_EDGE > 0 overrides LMAX for
    the spherical harmonics order.
    """
    cutoff_kw = {'cutoff_length': config[KEY.CUTOFF]}

    basis_cfg = copy.deepcopy(config[KEY.RADIAL_BASIS])
    basis_cfg.update(cutoff_kw)
    basis = None
    if basis_cfg.pop(KEY.RADIAL_BASIS_NAME) == 'bessel':
        basis = BesselBasis(**basis_cfg)

    env_cfg = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    env_cfg.update(cutoff_kw)
    env_name = env_cfg.pop(KEY.CUTOFF_FUNCTION_NAME)
    if env_name == 'poly_cut':
        envelope = PolynomialCutoff(**env_cfg)
    elif env_name == 'XPLOR':
        envelope = XPLORCutoff(**env_cfg)
    else:
        envelope = None

    lmax_edge = (
        config[KEY.LMAX_EDGE] if config[KEY.LMAX_EDGE] > 0 else config[KEY.LMAX]
    )
    parity = -1 if config[KEY.IS_PARITY] else 1
    sph = SphericalEncoding(
        lmax_edge, parity, normalize=config[KEY._NORMALIZE_SPH]
    )
    return EdgeEmbedding(
        basis_module=basis, cutoff_module=envelope, spherical_module=sph
    )
def init_feature_reduce(config, irreps_x):
    """Reduce per-node features to one scalar (scaled atomic energy) per node.

    Depending on READOUT_AS_FCN, either two IrrepsLinear layers
    (features -> hidden scalars -> 1 scalar) or a single e3nn-aware FCN.
    Returns an OrderedDict of named layers.
    """
    use_bias = config[KEY.USE_BIAS_IN_LINEAR]
    # NOTE: identity check kept on purpose — mirrors the original strict test
    if config[KEY.READOUT_AS_FCN] is False:
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        to_hidden = IrrepsLinear(
            irreps_x,
            hidden_irreps,
            data_key_in=KEY.NODE_FEATURE,
            biases=use_bias,
        )
        to_energy = IrrepsLinear(
            hidden_irreps,
            Irreps([(1, (0, 1))]),
            data_key_in=KEY.NODE_FEATURE,
            data_key_out=KEY.SCALED_ATOMIC_ENERGY,
            biases=use_bias,
        )
        return OrderedDict(
            reduce_input_to_hidden=to_hidden,
            reduce_hidden_to_energy=to_energy,
        )
    readout = FCN_e3nn(
        dim_out=1,
        hidden_neurons=config[KEY.READOUT_FCN_HIDDEN_NEURONS],
        activation=_const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]],
        data_key_in=KEY.NODE_FEATURE,
        data_key_out=KEY.SCALED_ATOMIC_ENERGY,
        irreps_in=irreps_x,
    )
    return OrderedDict(readout_FCN=readout)
def init_shift_scale(config):
    """Build the atomic-energy shift/scale (rescale) module from config.

    SHIFT/SCALE entries may be numpy/torch arrays, dicts of arrays, lists,
    or floats; they are normalized to plain Python values first
    (1-element lists collapse to a scalar). Then the module is chosen:
    modal-wise (multimodal models), plain Rescale (both floats), or
    species-wise (either is a list).
    """
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]

    def _as_plain(v):
        # numpy / torch -> python; dict values likewise; [x] -> x
        if hasattr(v, 'tolist'):
            v = v.tolist()
        if isinstance(v, dict):
            v = {
                k: w.tolist() if hasattr(w, 'tolist') else w
                for k, w in v.items()
            }
        if isinstance(v, list) and len(v) == 1:
            v = v[0]
        return v

    shift = _as_plain(config[KEY.SHIFT])
    scale = _as_plain(config[KEY.SCALE])

    if config.get(KEY.USE_MODALITY, False):
        return ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    if isinstance(shift, float) and isinstance(scale, float):
        return Rescale(shift, scale, train_shift_scale=train_shift_scale)
    if isinstance(shift, list) or isinstance(scale, list):
        return SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    raise ValueError('shift, scale should be list of float or float')
def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality aware shift scale is handled by init_shift_scale, not here
    """
    if not config.get(KEY.USE_MODALITY, False):
        return layers

    num_modal = config[KEY.NUM_MODALITIES]
    modal_onehot = OnehotEmbedding(
        num_classes=num_modal,
        data_key_x=KEY.MODAL_TYPE,
        data_key_out=KEY.MODAL_ATTR,
        data_key_save=None,
        data_key_additional=None,
    )
    layers = OrderedDict(
        _insert_after(
            'onehot_idx_to_onehot',
            ('one_hot_modality', modal_onehot),
            list(layers.items()),
        )
    )

    def _wants_modality(name):
        # which IrrepsLinear layers become modality-aware, per config switches
        if config[KEY.USE_MODAL_NODE_EMBEDDING] and name.endswith(
            'onehot_to_feature_x'
        ):
            return True
        if config[KEY.USE_MODAL_SELF_INTER_INTRO] and name.endswith(
            'self_interaction_1'
        ):
            return True
        if config[KEY.USE_MODAL_SELF_INTER_OUTRO] and name.endswith(
            'self_interaction_2'
        ):
            return True
        return bool(
            config[KEY.USE_MODAL_OUTPUT_BLOCK] and name == 'reduce_input_to_hidden'
        )

    for name, module in layers.items():
        if isinstance(module, IrrepsLinear) and _wants_modality(name):
            module.set_num_modalities(num_modal)
    return layers
def patch_cue(layers: OrderedDict, config):
    """Swap e3nn-based modules for cuEquivariance-accelerated ones when requested.

    Returns ``layers`` unchanged unless the config enables cuEquivariance,
    the package is importable, and the model/config passes the cue CUDA
    availability check. Linear, self-connection and convolution modules are
    replaced in place by their patched counterparts.
    """
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers
    if not cue_helper.is_cue_available():
        warnings.warn(
            'cuEquivariance is requested, but the package is not installed. '
            'Fallback to original code.'
        )
        return layers
    if not cue_helper.is_cue_cuda_available_model(config):
        return layers

    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    params = {'layout': 'mul_ir', **cue_cfg}

    patched = {}
    for name, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if name == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            patched[name] = cue_helper.patch_linear(module, group, **params)
        elif isinstance(module, SelfConnectionIntro):
            patched[name] = cue_helper.patch_fully_connected(
                module, group, **params
            )
        elif isinstance(module, IrrepsConvolution):
            patched[name] = cue_helper.patch_convolution(module, group, **params)
    layers.update(patched)
    return layers
def patch_modules(layers: OrderedDict, config):
    """Apply all post-build patches: modality first, then cuEquivariance."""
    return patch_cue(patch_modality(layers, config), config)
def _to_parallel_model(layers: OrderedDict, config):
    """Split the sequential model into per-communication-stage chunks.

    Inserts ghost-atom analogues of the one-hot / linear embedding layers,
    then cuts the layer sequence before each convolution layer's
    ``self_interaction_1`` so every chunk can run between neighbor
    communications. The last chunk gets a fresh edge_embedding (so position
    gradients are retained) and drops 'force_output' (done in LAMMPS).
    Returns a list of OrderedDicts, one per stage.
    """
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    def slice_until_this(module_name, layers):
        # Split into (prefix up to and including module_name, remainder).
        # NOTE(review): if the name is absent, idx stays -1 and the prefix
        # is empty — callers assume the name always exists; confirm.
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain
    # ghost-atom one-hot embedding, mirroring 'onehot_idx_to_onehot'
    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    # ghost-atom counterpart of 'onehot_to_feature_x'
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # ghost-atom counterpart of the first layer's self_interaction_1
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules (before first communications)
    # initialize edge related to retain position gradients
    for i in range(1, num_convolution_layer):
        # one stage per convolution layer boundary
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
    _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
# Typing-only overloads: parallel=False yields one model, parallel=True a
# list of per-stage models (see _to_parallel_model). Bodies are stubs.
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...
@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...
def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """Build the E3-equivariant model ('nequip' interaction type only).

    output shapes (w/o batch)
    PRED_TOTAL_ENERGY: (),
    ATOMIC_ENERGY: (natoms, 1), # intended
    PRED_FORCE: (natoms, 3),
    PRED_STRESS: (6,),
    for data w/o cell volume, pred_stress has garbage values

    With parallel=True returns a list of per-communication-stage models
    (via _to_parallel_model); otherwise a single AtomGraphSequential.
    Raises ValueError for unknown interaction types and RuntimeError for
    an invalid irreps_manual input.
    """
    layers = OrderedDict()  # name -> module, in execution order
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]
    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:  # explicit LMAX_NODE overrides LMAX
        lmax_node = config[KEY.LMAX_NODE]
    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)
    # irreps_manual: optional explicit node irreps, num_conv + 1 entries
    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception:
            raise RuntimeError('invalid irreps_manual input given')
    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        # scalar denominator is shared by every convolution layer
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]
    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})
    # species one-hot -> initial node features
    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )
    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    # layer widths of the radial weight NN: basis size -> hidden neurons
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden
    # shared kwargs for every interaction block; per-layer keys added in loop
    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }
    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')
    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            if t == num_convolution_layer - 1:
                # last layer keeps only even scalars (lmax 0, even parity)
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out  # next layer consumes this layer's output irreps
    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            # NOTE(review): 'enegy' typo kept on purpose — layer names are
            # runtime strings, presumably referenced by saved checkpoints;
            # confirm before renaming.
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})
    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        # in parallel (LAMMPS) mode type/modal mapping is done externally
        'eval_type_map': False if parallel else True,
        'eval_modal_map': False
        if not config.get(KEY.USE_MODALITY, False) or parallel
        else True,
        'data_key_grad': grad_key,
    }
    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment