Commit 2409a22f authored by fanding2000

Format fix. More options in readme

parent ce29afea
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional

from ase.data import atomic_numbers

import sevenn._keys as KEY
from sevenn import __version__

CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()}


class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
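
# Illustrative sketch (not part of the original commit): any class using this
# metaclass is constructed at most once; later calls return the cached
# instance. This is why Logger() below can be constructed anywhere and still
# refer to the same log file.
#
#   class Config(metaclass=Singleton):
#       pass
#
#   assert Config() is Config()  # the second call returns the first instance
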
class Logger(metaclass=Singleton):
    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # if filename is not None:
            #     self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}
        self.active = True

    def __enter__(self):
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
        self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]):
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO: refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        lb_pad = 21
        fs = 6
        pad = 21 - fs
        ln = '-' * fs
        total_atom_type = train_loss.keys()
        content = ''
        for at in total_atom_type:
            t_F = train_loss[at]
            v_F = valid_loss[at]
            at_sym = CHEM_SYMBOLS[at]
            content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format(
                label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs
            ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format(
                t_F=t_F, v_F=v_F, pad=pad, fs=fs
            )
            content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format(
                t_S=ln, v_S=ln, pad=pad, fs=fs
            )
            content += '\n'
        self.write(content)

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume dict_list is a list of dicts with the same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())
        values = [list(d.values()) for d in dict_list]
        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]
        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]
        # Create header row and separator
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )
        # Print header and separator
        self.writeline(header)
        self.writeline(separator)
        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # separate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'
            if content.endswith(f'{SEPARATOR}'):
                content = content[: -len(SEPARATOR)]
        content += '\n'
        if write is False:
            return content
        else:
            self.write(content)
            return ''
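
    # Worked example (illustrative, not in the original commit), assuming the
    # default SCREEN_WIDTH of 120: a short value stays on one line,
    #
    #   Logger().format_k_v('cutoff', 5.0)   # -> 'cutoff              : 5.0\n'
    #
    # while a long comma-separated value is wrapped with a trailing backslash
    # and padded so continuation lines align under the first value component.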
    def greeting(self):
        LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(LOGO_ASCII_FILE, 'r') as logo_f:
            logo_ascii = logo_f.read()
        content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
        content += f'version {__version__}, {time.ctime()}\n'
        self.write(content)
        self.write(logo_ascii)

    def bar(self):
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        Print some important information from the config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: this is not good, make our own exception
    def error(self, e: Exception):
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        Print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        from functools import partial

        kv_write = partial(self.format_k_v, write=True)
        self.writeline('Irreps of features')
        kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out'))
        for i in range(config[KEY.NUM_CONVOLUTION]):
            kv_write(
                f'{i}th node',
                model.get_irreps_in(f'{i}_self_interaction_1'),
            )
        i = config[KEY.NUM_CONVOLUTION] - 1
        kv_write(
            'readout irreps',
            model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'),
        )
        num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.writeline(f'# learnable parameters: {num_weights}\n')
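
# Usage sketch (illustrative, not part of the original commit). Logger is a
# rank-aware singleton: rank 0 opens the file on __enter__, all other ranks
# silently no-op, so the same calls are safe on every DDP process.
#
#   from sevenn.logger import Logger
#
#   with Logger(filename='log.sevenn', screen=True, rank=0) as log:
#       log.greeting()
#       log.timer_start('epoch')
#       ...  # training work
#       log.timer_end('epoch', message='Epoch elapsed')
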
import argparse
import os
import sys
import time

from sevenn import __version__

description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}


def run(args):
    """
    Main function of sevenn
    """
    import random

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq

    if use_cue:
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')

    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)

    world_size = 1
    if distributed:
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1

    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parse input.yaml')
            logger.error(e)
            sys.exit(1)

        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))

        if use_cue:
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})

        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside the program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)

        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)

        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)

        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)


def cmd_parser_train(parser):
    ag = parser
    ag.add_argument('input_yaml', help=input_yaml_help, type=str)
    ag.add_argument(
        '-m',
        '--mode',
        choices=['train_v1', 'train_v2'],
        default='train_v2',
        help=mode_help,
        type=str,
    )
    ag.add_argument(
        '-cueq',
        '--enable_cueq',
        help='(Not stable!) use cuEquivariance for training',
        action='store_true',
    )
    ag.add_argument(
        '-w',
        '--working_dir',
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
        type=str,
    )
    ag.add_argument(
        '-l',
        '--log',
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
        type=str,
    )
    ag.add_argument('-s', '--screen', help=screen_help, action='store_true')
    ag.add_argument(
        '-d', '--distributed', help=distributed_help, action='store_true'
    )
    ag.add_argument(
        '--distributed_backend',
        help=distributed_backend_help,
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
    )


def add_parser(subparsers):
    ag = subparsers.add_parser('train', help=description)
    cmd_parser_train(ag)


def set_default_subparser(self, name, args=None, positional_args=0):
    """Default subparser selection. Call after setup, just before parse_args().
    name: the name of the subparser to call by default
    args: if set, the argument list handed to parse_args()

    Hack copied from Stack Overflow.
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
    if not subparser_found:
        # insert default in last position before global positional
        # arguments, this implies no global options are specified after
        # first positional argument
        if args is None:
            sys.argv.insert(len(sys.argv) - positional_args, name)
        else:
            args.insert(len(args) - positional_args, name)


argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
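
# Behavior sketch (illustrative, not in the original commit): with 'train'
# registered as the default subparser via the monkey-patch above, the missing
# subcommand is spliced into sys.argv before parse_args(), so these two
# invocations are equivalent:
#
#   sevenn input.yaml -s
#   sevenn train input.yaml -s
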
def main():
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')

    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'

    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
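
# Illustrative shell usage (file names are assumptions). The nccl branch of
# run() reads LOCAL_RANK/RANK/WORLD_SIZE from the environment, which a
# launcher such as torchrun provides:
#
#   sevenn input.yaml -s                                          # single process
#   torchrun --nproc_per_node 4 --no_python sevenn input.yaml -d  # assumed DDP launch
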
import argparse
import os.path as osp

from sevenn import __version__

description = 'toolbox for sevennet checkpoints'


def add_parser(subparsers):
    ag = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument('checkpoint', help='checkpoint or pretrained', type=str)
    group = ag.add_mutually_exclusive_group(required=False)
    group.add_argument(
        '--get_yaml',
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
        type=str,
    )
    group.add_argument(
        '--append_modal_yaml',
        help='append modality with the given yaml',
        type=str,
    )
    ag.add_argument(
        '--original_modal_name',
        help=(
            'when append_modal is used and the checkpoint is not multi-modal, '
            + 'used to name the previously trained modality. Defaults to "origin"'
        ),
        default='origin',
        type=str,
    )


def run(args):
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        print(f'Update the continue section of {dst_yaml} as below (recommended) to continue:')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)


def main(args=None):
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    run(ag.parse_args())
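
# Illustrative usage (checkpoint file name is an assumption):
#
#   sevenn cp checkpoint_best.pth                      # print a checkpoint summary
#   sevenn cp checkpoint_best.pth --get_yaml continue  # emit an input.yaml to continue from
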
import argparse
import os

from sevenn import __version__

description_get_model = 'deploy a LAMMPS model from the checkpoint'
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'


def add_parser(subparsers):
    ag = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument('checkpoint', help=checkpoint_help, type=str)
    ag.add_argument(
        '-o', '--output_prefix', nargs='?', help=output_name_help, type=str
    )
    ag.add_argument(
        '-p', '--get_parallel', help=get_parallel_help, action='store_true'
    )
    ag.add_argument(
        '-m',
        '--modal',
        help='Modality of multi-modal model',
        type=str,
    )


def run(args):
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    checkpoint = args.checkpoint
    output_prefix = args.output_prefix
    get_parallel = args.get_parallel
    get_serial = not get_parallel
    modal = args.modal

    if output_prefix is None:
        output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial'

    checkpoint_path = None
    if os.path.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)

    if get_serial:
        deploy(checkpoint_path, output_prefix, modal)
    else:
        deploy_parallel(checkpoint_path, output_prefix, modal)


# legacy way
def main():
    ag = argparse.ArgumentParser(description=description_get_model)
    add_args(ag)
    run(ag.parse_args())
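
# Illustrative usage: pretrained names follow checkpoint_help above; outputs
# take the 'deployed_serial'/'deployed_parallel' prefix unless -o is given.
#
#   sevenn get_model 7net-0        # serial model for single-GPU LAMMPS
#   sevenn get_model 7net-0 -p     # parallel model for multi-GPU LAMMPS
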
import argparse
import glob
import os
import sys
from datetime import datetime

from sevenn import __version__

description = 'create `sevenn_data/dataset.pt` from ase readable files'
source_help = 'source data to build graph, supports glob patterns (*)'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    + 'The dataset will be written under "sevenn_data", '
    + 'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'


def add_parser(subparsers):
    ag = subparsers.add_parser('graph_build', help=description)
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument('source', help=source_help, type=str)
    ag.add_argument('cutoff', help=cutoff_help, type=float)
    ag.add_argument(
        '-n',
        '--num_cores',
        help='number of cores to build graph in parallel',
        default=1,
        type=int,
    )
    ag.add_argument(
        '-o',
        '--out',
        help='Existing path to write outputs.',
        type=str,
        default='./',
    )
    ag.add_argument(
        '-f',
        '--filename',
        help=filename_help,
        type=str,
        default='graph.pt',
    )
    ag.add_argument(
        '--legacy',
        help=legacy_help,
        action='store_true',
    )
    ag.add_argument(
        '-s',
        '--screen',
        help='print log to the screen',
        action='store_true',
    )
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )


def run(args):
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            k, v = kwarg.split('=')
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source matched no files, nothing to read')
        sys.exit(0)
    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exists: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }

    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )


def main(args=None):
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    run(ag.parse_args())
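
# Illustrative usage (data path and kwargs key are assumptions). Quote the
# glob so the shell does not expand it, since run() globs the pattern itself:
#
#   sevenn graph_build 'data/*.extxyz' 5.0 -n 4 --kwargs index=:
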
import argparse
import glob
import os
import sys

description = 'evaluate sevenn_data / ase readable files with a model (checkpoint)'
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'


def add_parser(subparsers):
    ag = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b',
        '--batch',
        type=int,
        default=4,
        help='batch size, useful for GPU',
    )
    ag.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save the preprocessed graph as sevenn_data',
    )
    ag.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    ag.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )


def run(args):
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to run inference on) were found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises ValueError

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            k, v = kwarg.split('=')
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )


def main(args=None):
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    run(ag.parse_args())
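
# Illustrative usage (file names are assumptions):
#
#   sevenn inference checkpoint_best.pth structures/*.extxyz -d cuda -b 16
#   sevenn inf 7net-0 test.extxyz -o ./results
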
import argparse
import os
import subprocess

from sevenn import __version__

# python wrapper of the patch_lammps.sh script
# importlib.resources is the correct way to do these things,
# but it changes too frequently to rely on
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'


def add_parser(subparsers):
    ag = subparsers.add_parser('patch_lammps', help=description)
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
    ag.add_argument('--d3', help='Enable D3 support', action='store_true')
    # cxx_standard is detected automatically


def run(args):
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}'
    res = subprocess.run(cmd.split())
    return res.returncode  # is it meaningless?


def main(args=None):
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    run(ag.parse_args())


if __name__ == '__main__':
    main()
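
# Illustrative usage (LAMMPS source location is an assumption):
#
#   sevenn patch_lammps ./lammps          # patch pair-style sources before building
#   sevenn patch_lammps ./lammps --d3     # additionally enable D3 support
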
import argparse
import os

from sevenn import __version__

description = (
    'print the selected preset for training. '
    + 'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'


def add_parser(subparsers):
    ag = subparsers.add_parser('preset', help=description)
    add_args(ag)


def add_args(parser):
    ag = parser
    ag.add_argument(
        'preset',
        choices=[
            'fine_tune',
            'fine_tune_le',
            'sevennet-0',
            'sevennet-l3i5',
            'base',
            'multi_modal',
        ],
        help=preset_help,
    )


def run(args):
    preset = args.preset
    prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    with open(f'{prefix}/{preset}.yaml', 'r') as f:
        print(f.read())


# When executed as sevenn_preset (legacy way)
def main(args=None):
    ag = argparse.ArgumentParser(description=description)
    add_args(ag)
    run(ag.parse_args())
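
# Illustrative usage via the unified CLI (equivalent to the legacy
# sevennet_preset entry point shown in the description):
#
#   sevenn preset base > input.yaml
#   sevenn preset fine_tune > my_input.yaml
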
import copy
import warnings
from collections import OrderedDict
from typing import List, Literal, Union, overload

from e3nn.o3 import Irreps

import sevenn._const as _const
import sevenn._keys as KEY
import sevenn.util as util

from .nn.convolution import IrrepsConvolution
from .nn.edge_embedding import (
    BesselBasis,
    EdgeEmbedding,
    PolynomialCutoff,
    SphericalEncoding,
    XPLORCutoff,
)
from .nn.force_output import ForceStressOutputFromEdge
from .nn.interaction_blocks import NequIP_interaction_block
from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear
from .nn.node_embedding import OnehotEmbedding
from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale
from .nn.self_connection import (
    SelfConnectionIntro,
    SelfConnectionLinearIntro,
    SelfConnectionOutro,
)
from .nn.sequential import AtomGraphSequential

# warning from PyTorch, about e3nn type annotations
warnings.filterwarnings(
    'ignore',
    message=(
        "The TorchScript type system doesn't support instance-level annotations"
    ),
)


def _insert_after(module_name_after, key_module_pair, layers):
    idx = -1
    for i, (key, _) in enumerate(layers):
        if key == module_name_after:
            idx = i
            break
    if idx == -1:
        return layers  # do nothing if not found
    layers.insert(idx + 1, key_module_pair)
    return layers


def init_self_connection(config):
    self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE]
    num_conv = config[KEY.NUM_CONVOLUTION]
    if isinstance(self_connection_type_list, str):
        self_connection_type_list = [self_connection_type_list] * num_conv
    io_pair_list = []
    for sc_type in self_connection_type_list:
        if sc_type == 'none':
            io_pair = None
        elif sc_type == 'nequip':
            io_pair = SelfConnectionIntro, SelfConnectionOutro
        elif sc_type == 'linear':
            io_pair = SelfConnectionLinearIntro, SelfConnectionOutro
        else:
            raise ValueError(f'Unknown self_connection_type found: {sc_type}')
        io_pair_list.append(io_pair)
    return io_pair_list
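
# Shape sketch (illustrative, not in the original commit): with a
# self_connection_type of 'nequip' and three convolution layers, this returns
# [(SelfConnectionIntro, SelfConnectionOutro)] * 3; each entry is a class pair
# (or None for 'none') to be instantiated around an interaction block.
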


def init_edge_embedding(config):
    _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]}
    rbf, env, sph = None, None, None

    rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS])
    rbf_dct.update(_cutoff_param)
    rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME)
    if rbf_name == 'bessel':
        rbf = BesselBasis(**rbf_dct)

    envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION])
    envelop_dct.update(_cutoff_param)
    envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME)
    if envelop_name == 'poly_cut':
        env = PolynomialCutoff(**envelop_dct)
    elif envelop_name == 'XPLOR':
        env = XPLORCutoff(**envelop_dct)

    lmax_edge = config[KEY.LMAX]
    if config[KEY.LMAX_EDGE] > 0:
        lmax_edge = config[KEY.LMAX_EDGE]
    parity = -1 if config[KEY.IS_PARITY] else 1
    _normalize_sph = config[KEY._NORMALIZE_SPH]
    sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph)

    return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph)
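
# A hedged sketch of the config fragment consumed above; exact keys and
# defaults live in sevenn's input handling, so the values are illustrative:
#   config[KEY.CUTOFF] = 5.0
#   config[KEY.RADIAL_BASIS] = {KEY.RADIAL_BASIS_NAME: 'bessel', ...}
#   config[KEY.CUTOFF_FUNCTION] = {KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', ...}
# 'cutoff_length' is injected into both sub-dicts before the basis and
# envelope modules are built; an unrecognized basis or envelope name silently
# leaves rbf/env as None.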


def init_feature_reduce(config, irreps_x):
    # features per node to scalar per node
    layers = OrderedDict()
    if config[KEY.READOUT_AS_FCN] is False:
        hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))])
        layers.update(
            {
                'reduce_input_to_hidden': IrrepsLinear(
                    irreps_x,
                    hidden_irreps,
                    data_key_in=KEY.NODE_FEATURE,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
                'reduce_hidden_to_energy': IrrepsLinear(
                    hidden_irreps,
                    Irreps([(1, (0, 1))]),
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    biases=config[KEY.USE_BIAS_IN_LINEAR],
                ),
            }
        )
    else:
        act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]]
        hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS]
        layers.update(
            {
                'readout_FCN': FCN_e3nn(
                    dim_out=1,
                    hidden_neurons=hidden_neurons,
                    activation=act,
                    data_key_in=KEY.NODE_FEATURE,
                    data_key_out=KEY.SCALED_ATOMIC_ENERGY,
                    irreps_in=irreps_x,
                )
            }
        )
    return layers
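
# The two readout paths above differ only in how node features collapse to a
# scalar: the linear path first halves the feature dimension (hidden_irreps)
# and then projects to a single even scalar (1x0e), while the FCN path maps
# directly to dim_out=1. Both write KEY.SCALED_ATOMIC_ENERGY, so the
# downstream rescale/reduce stages are identical.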


def init_shift_scale(config):
    # for mm, ex, shift: modal_idx -> shifts
    shift_scale = []
    train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE]
    type_map = config[KEY.TYPE_MAP]
    # in case of modal, shift or scale has more dims [][]
    # correct typing (I really want static python)
    for s in (config[KEY.SHIFT], config[KEY.SCALE]):
        if hasattr(s, 'tolist'):  # numpy or torch
            s = s.tolist()
        if isinstance(s, dict):
            s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()}
        if isinstance(s, list) and len(s) == 1:
            s = s[0]
        shift_scale.append(s)
    shift, scale = shift_scale

    rescale_module = None
    if config.get(KEY.USE_MODALITY, False):
        rescale_module = ModalWiseRescale.from_mappers(  # type: ignore
            shift,
            scale,
            config[KEY.USE_MODAL_WISE_SHIFT],
            config[KEY.USE_MODAL_WISE_SCALE],
            type_map=type_map,
            modal_map=config[KEY.MODAL_MAP],
            train_shift_scale=train_shift_scale,
        )
    elif all([isinstance(s, float) for s in shift_scale]):
        rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale)
    elif any([isinstance(s, list) for s in shift_scale]):
        rescale_module = SpeciesWiseRescale.from_mappers(  # type: ignore
            shift, scale, type_map=type_map, train_shift_scale=train_shift_scale
        )
    else:
        raise ValueError(
            'shift and scale should each be a float or a list of floats'
        )
    return rescale_module
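
# Accepted shift/scale forms, per the branches above (values illustrative):
#   float                -> Rescale: one global shift/scale
#   list of floats       -> SpeciesWiseRescale: one value per species
#                           (mapped through type_map)
#   dict or nested lists -> ModalWiseRescale when KEY.USE_MODALITY is set
# numpy/torch inputs are converted via .tolist() first, and a one-element
# list collapses to its single float.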


def patch_modality(layers: OrderedDict, config):
    """
    Postprocess 7net-model to multimodal model.
    1. prepend modality one-hot embedding layer
    2. patch modalities of IrrepsLinear layers
    Modality-aware shift/scale is handled by init_shift_scale, not here
    """
    cfg = config
    if not cfg.get(KEY.USE_MODALITY, False):
        return layers

    _layers = list(layers.items())
    _layers = _insert_after(
        'onehot_idx_to_onehot',
        (
            'one_hot_modality',
            OnehotEmbedding(
                num_classes=config[KEY.NUM_MODALITIES],
                data_key_x=KEY.MODAL_TYPE,
                data_key_out=KEY.MODAL_ATTR,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    layers = OrderedDict(_layers)

    num_modal = config[KEY.NUM_MODALITIES]
    for k, module in layers.items():
        if not isinstance(module, IrrepsLinear):
            continue
        if (
            (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x'))
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_INTRO]
                and k.endswith('self_interaction_1')
            )
            or (
                cfg[KEY.USE_MODAL_SELF_INTER_OUTRO]
                and k.endswith('self_interaction_2')
            )
            or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden')
        ):
            module.set_num_modalities(num_modal)
    return layers
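
# Which IrrepsLinear layers become modality-aware is gated per block:
# node embedding ('onehot_to_feature_x'), self-interaction intro/outro
# ('*_self_interaction_1' / '*_self_interaction_2'), and the output block
# ('reduce_input_to_hidden'), each behind its own USE_MODAL_* config flag.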


def patch_cue(layers: OrderedDict, config):
    import sevenn.nn.cue_helper as cue_helper

    cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {}))
    if not cue_cfg.pop('use', False):
        return layers

    if not cue_helper.is_cue_available():
        warnings.warn(
            (
                'cuEquivariance is requested, but the package is not installed. '
                + 'Falling back to the original implementation.'
            )
        )
        return layers

    if not cue_helper.is_cue_cuda_available_model(config):
        return layers

    group = 'O3' if config[KEY.IS_PARITY] else 'SO3'
    cueq_module_params = dict(layout='mul_ir')
    cueq_module_params.update(cue_cfg)

    updates = {}
    for k, module in layers.items():
        if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)):
            if k == 'reduce_hidden_to_energy':  # TODO: has bug with 0 shape
                continue
            module_patched = cue_helper.patch_linear(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, SelfConnectionIntro):
            module_patched = cue_helper.patch_fully_connected(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
        elif isinstance(module, IrrepsConvolution):
            module_patched = cue_helper.patch_convolution(
                module, group, **cueq_module_params
            )
            updates[k] = module_patched
    layers.update(updates)
    return layers
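
# patch_cue swaps modules in place by key, so layer names (and hence
# state_dict prefixes) are unchanged; 'reduce_hidden_to_energy' is skipped
# because of the zero-shape issue noted in the TODO above.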


def patch_modules(layers: OrderedDict, config):
    layers = patch_modality(layers, config)
    layers = patch_cue(layers, config)
    return layers


def _to_parallel_model(layers: OrderedDict, config):
    num_classes = layers['onehot_idx_to_onehot'].num_classes
    one_hot_irreps = Irreps(f'{num_classes}x0e')
    irreps_node_zero = layers['onehot_to_feature_x'].irreps_out
    _layers = list(layers.items())
    layers_list = []
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]

    def slice_until_this(module_name, layers):
        idx = -1
        for i, (key, _) in enumerate(layers):
            if key == module_name:
                idx = i
                break
        first_to = layers[: idx + 1]
        remain = layers[idx + 1 :]
        return first_to, remain

    _layers = _insert_after(
        'onehot_to_feature_x',
        (
            'one_hot_ghost',
            OnehotEmbedding(
                data_key_x=KEY.NODE_FEATURE_GHOST,
                num_classes=num_classes,
                data_key_save=None,
                data_key_additional=None,
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        'one_hot_ghost',
        (
            'ghost_onehot_to_feature_x',
            IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    _layers = _insert_after(
        '0_self_interaction_1',
        (
            'ghost_0_self_interaction_1',
            IrrepsLinear(
                irreps_node_zero,
                irreps_node_zero,
                data_key_in=KEY.NODE_FEATURE_GHOST,
                biases=config[KEY.USE_BIAS_IN_LINEAR],
            ),
        ),
        _layers,
    )
    # assign modules to segments, splitting before each communication point
    # re-initialize edge embedding per segment to retain position gradients
    for i in range(1, num_convolution_layer):
        sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers)
        layers_list.append(OrderedDict(sliced))
        _layers.insert(0, ('edge_embedding', init_edge_embedding(config)))
    layers_list.append(OrderedDict(_layers))
    del layers_list[-1]['force_output']  # done in LAMMPS
    return layers_list
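
# Net effect: the model is split into num_convolution_layer segments, each
# ending right after an '{i}_self_interaction_1' layer (the point before a
# parallel communication step). Every later segment re-runs edge embedding so
# position gradients are retained, and the final segment drops 'force_output'
# since LAMMPS evaluates forces itself.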


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[False] = False
) -> AtomGraphSequential:  # noqa
    ...


@overload
def build_E3_equivariant_model(
    config: dict, parallel: Literal[True]
) -> List[AtomGraphSequential]:  # noqa
    ...


def build_E3_equivariant_model(
    config: dict, parallel: bool = False
) -> Union[AtomGraphSequential, List[AtomGraphSequential]]:
    """
    output shapes (w/o batch)
        PRED_TOTAL_ENERGY: (),
        ATOMIC_ENERGY: (natoms, 1),  # intended
        PRED_FORCE: (natoms, 3),
        PRED_STRESS: (6,),

    for data w/o cell volume, pred_stress has garbage values
    """
    layers = OrderedDict()
    cutoff = config[KEY.CUTOFF]
    num_species = config[KEY.NUM_SPECIES]
    feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY]
    num_convolution_layer = config[KEY.NUM_CONVOLUTION]
    interaction_type = config[KEY.INTERACTION_TYPE]
    use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR]

    lmax_node = config[KEY.LMAX]  # ignore second (lmax_edge)
    # if config[KEY.LMAX_EDGE] > 0:  # not yet used
    #     _ = config[KEY.LMAX_EDGE]
    if config[KEY.LMAX_NODE] > 0:
        lmax_node = config[KEY.LMAX_NODE]

    act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]]
    self_connection_pair_list = init_self_connection(config)

    irreps_manual = None
    if config[KEY.IRREPS_MANUAL] is not False:
        irreps_manual = config[KEY.IRREPS_MANUAL]
        try:
            irreps_manual = [Irreps(irr) for irr in irreps_manual]
            assert len(irreps_manual) == num_convolution_layer + 1
        except Exception as e:
            raise RuntimeError('invalid irreps_manual input given') from e

    conv_denominator = config[KEY.CONV_DENOMINATOR]
    if not isinstance(conv_denominator, list):
        conv_denominator = [conv_denominator] * num_convolution_layer
    train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR]

    edge_embedding = init_edge_embedding(config)
    irreps_filter = edge_embedding.spherical.irreps_out
    radial_basis_num = edge_embedding.basis_function.num_basis
    layers.update({'edge_embedding': edge_embedding})

    one_hot_irreps = Irreps(f'{num_species}x0e')
    irreps_x = (
        Irreps(f'{feature_multiplicity}x0e')
        if irreps_manual is None
        else irreps_manual[0]
    )
    layers.update(
        {
            'onehot_idx_to_onehot': OnehotEmbedding(
                num_classes=num_species,
                data_key_x=KEY.NODE_FEATURE,
                data_key_out=KEY.NODE_FEATURE,
                data_key_save=KEY.ATOM_TYPE,  # atomic numbers
                data_key_additional=KEY.NODE_ATTR,  # one-hot embeddings
            ),
            'onehot_to_feature_x': IrrepsLinear(
                irreps_in=one_hot_irreps,
                irreps_out=irreps_x,
                data_key_in=KEY.NODE_FEATURE,
                biases=use_bias_in_linear,
            ),
        }
    )

    weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS]
    weight_nn_layers = [radial_basis_num] + weight_nn_hidden

    param_interaction_block = {
        'irreps_filter': irreps_filter,
        'weight_nn_layers': weight_nn_layers,
        'train_conv_denominator': train_conv_denominator,
        'act_radial': act_radial,
        'bias_in_linear': use_bias_in_linear,
        'num_species': num_species,
        'parallel': parallel,
    }

    interaction_builder = None
    if interaction_type in ['nequip']:
        act_scalar = {}
        act_gate = {}
        for k, v in config[KEY.ACTIVATION_SCARLAR].items():
            act_scalar[k] = _const.ACTIVATION_DICT[k][v]
        for k, v in config[KEY.ACTIVATION_GATE].items():
            act_gate[k] = _const.ACTIVATION_DICT[k][v]
        param_interaction_block.update(
            {
                'act_scalar': act_scalar,
                'act_gate': act_gate,
            }
        )
    if interaction_type == 'nequip':
        interaction_builder = NequIP_interaction_block
    else:
        raise ValueError(f'Unknown interaction type: {interaction_type}')

    for t in range(num_convolution_layer):
        param_interaction_block.update(
            {
                'irreps_x': irreps_x,
                't': t,
                'conv_denominator': conv_denominator[t],
                'self_connection_pair': self_connection_pair_list[t],
            }
        )
        if interaction_type == 'nequip':
            parity_mode = 'full'
            fix_multiplicity = False
            if t == num_convolution_layer - 1:
                lmax_node = 0
                parity_mode = 'even'
            # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
            irreps_out = (
                util.infer_irreps_out(
                    irreps_x,  # type: ignore
                    irreps_filter,
                    lmax_node,  # type: ignore
                    parity_mode,
                    fix_multiplicity=feature_multiplicity,
                )
                if irreps_manual is None
                else irreps_manual[t + 1]
            )
            irreps_out_tp = util.infer_irreps_out(
                irreps_x,  # type: ignore
                irreps_filter,
                irreps_out.lmax,  # type: ignore
                parity_mode,
                fix_multiplicity,
            )
        else:
            raise ValueError(f'Unknown interaction type: {interaction_type}')
        param_interaction_block.update(
            {
                'irreps_out_tp': irreps_out_tp,
                'irreps_out': irreps_out,
            }
        )
        layers.update(interaction_builder(**param_interaction_block))
        irreps_x = irreps_out

    layers.update(init_feature_reduce(config, irreps_x))
    layers.update(
        {
            'rescale_atomic_energy': init_shift_scale(config),
            'reduce_total_enegy': AtomReduce(
                data_key_in=KEY.ATOMIC_ENERGY,
                data_key_out=KEY.PRED_TOTAL_ENERGY,
            ),
        }
    )
    gradient_module = ForceStressOutputFromEdge()
    grad_key = gradient_module.get_grad_key()
    layers.update({'force_output': gradient_module})

    common_args = {
        'cutoff': cutoff,
        'type_map': config[KEY.TYPE_MAP],
        'modal_map': config.get(KEY.MODAL_MAP, None),
        'eval_type_map': not parallel,
        'eval_modal_map': (
            bool(config.get(KEY.USE_MODALITY, False)) and not parallel
        ),
        'data_key_grad': grad_key,
    }

    if parallel:
        layers_list = _to_parallel_model(layers, config)
        return [
            AtomGraphSequential(patch_modules(layers, config), **common_args)
            for layers in layers_list
        ]
    else:
        return AtomGraphSequential(patch_modules(layers, config), **common_args)
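
# A minimal usage sketch (hypothetical import path and config; real configs
# come from sevenn's YAML input and carry many more keys than shown here):
#   from sevenn.model_build import build_E3_equivariant_model
#   model = build_E3_equivariant_model(config)           # AtomGraphSequential
#   segments = build_E3_equivariant_model(config, True)  # list, for parallel
# As the overloads above encode, parallel=False returns a single
# AtomGraphSequential and parallel=True a List[AtomGraphSequential].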