Unverified Commit ca86f720 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent b75ed73c
import argparse
import os
import sys
import time
from sevenn import __version__
# CLI help strings reused by cmd_parser_train below.
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),  # evaluated once, at import time
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Reads input.yaml, sets up (optionally distributed) runtime state,
    seeds RNGs and dispatches to the selected training script.
    """
    import random
    import sys
    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq

    if use_cue:
        # Fail fast if cuEquivariance was requested but is not installed.
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')

    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)

    world_size = 1
    if distributed:
        # Rank / world size come from the launcher's environment variables
        # (torchrun for nccl, OpenMPI for mpi).
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1

    # unique_filepath avoids clobbering an existing log file.
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parsing input.yaml')
            logger.error(e)
            sys.exit(1)

        # Record the runtime (DDP) state in the train config; it ends up in
        # the checkpoint via global_config below.
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))

        if use_cue:
            # Force-enable cuEquivariance in the model config.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})

        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)

        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)

        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)

        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Register every CLI option of the 'train' sub-command on *parser*."""
    parser.add_argument('input_yaml', help=input_yaml_help, type=str)
    parser.add_argument(
        '-m', '--mode',
        choices=['train_v1', 'train_v2'],
        default='train_v2',
        help=mode_help,
        type=str,
    )
    parser.add_argument(
        '-cueq', '--enable_cueq',
        help='(Not stable!) use cuEquivariance for training',
        action='store_true',
    )
    parser.add_argument(
        '-w', '--working_dir',
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
        type=str,
    )
    parser.add_argument(
        '-l', '--log',
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
        type=str,
    )
    parser.add_argument('-s', '--screen', help=screen_help, action='store_true')
    parser.add_argument(
        '-d', '--distributed', help=distributed_help, action='store_true'
    )
    parser.add_argument(
        '--distributed_backend',
        help=distributed_backend_help,
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
    )
def add_parser(subparsers):
    """Attach the 'train' sub-command to the top-level parser."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()
    Hack copied from stack overflow
    """
    subparser_found = False
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # for-else: only runs when no help flag was seen on the command line.
        # NOTE(review): relies on argparse private attributes
        # (_subparsers, _SubParsersAction, _name_parser_map).
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
        if not subparser_found:
            # insert default in last position before global positional
            # arguments, this implies no global options are specified after
            # first positional argument
            if args is None:
                sys.argv.insert(len(sys.argv) - positional_args, name)
            else:
                args.insert(len(args) - positional_args, name)


# Monkey-patch ArgumentParser so main() can call ag.set_default_subparser().
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """Top-level `sevenn` entry point: build the CLI and dispatch.

    Sub-command modules are imported lazily here to keep CLI start-up light.
    """
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')

    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    # `sevenn input.yaml` (no sub-command) falls back to 'train'.
    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()

    if args.command is None:  # backward compatibility
        args.command = 'train'

    # Only 'train' and 'preset' are dispatched here; the other sub-commands
    # provide their own run() used via their legacy executables.
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()
import argparse
import os.path as osp
from sevenn import __version__
# Shown as the sub-command help text for 'checkpoint' / 'cp'.
description = (
    'tool box for sevennet checkpoints'
)
def add_parser(subparsers):
    """Register the 'checkpoint' (alias 'cp') sub-command."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Define the checkpoint tool-box arguments.

    `--get_yaml` and `--append_modal_yaml` are mutually exclusive.
    """
    parser.add_argument('checkpoint', help='checkpoint or pretrained', type=str)
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
        type=str,
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        help='append modality with given yaml.',
        type=str,
    )
    parser.add_argument(
        '--original_modal_name',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            + 'used to name previously trained modality. defaults to "origin"'
        ),
        default='origin',
        type=str,
    )
def run(args):
    """Checkpoint tool-box entry point.

    Depending on the flags: dump an input.yaml derived from the checkpoint
    (`--get_yaml`), append a new modality and save a new checkpoint file
    (`--append_modal_yaml`), or just print the checkpoint summary.
    """
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # Fixed typo in the user-facing hint: 'blow' -> 'below'.
        print(f'update continue of {dst_yaml} as below (recommend) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_cp` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
from sevenn import __version__
# Help strings for the 'get_model' (deploy) sub-command.
description_get_model = (
    'deploy LAMMPS model from the checkpoint'
)
# Accepts a path or a pretrained model name/tag.
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the 'get_model' (alias 'deploy') sub-command."""
    deploy_parser = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(deploy_parser)
def add_args(parser):
    """Define arguments for deploying a LAMMPS-ready model."""
    parser.add_argument('checkpoint', help=checkpoint_help, type=str)
    parser.add_argument(
        '-o', '--output_prefix', nargs='?', help=output_name_help, type=str
    )
    parser.add_argument(
        '-p', '--get_parallel', help=get_parallel_help, action='store_true'
    )
    parser.add_argument(
        '-m', '--modal',
        help='Modality of multi-modal model',
        type=str,
    )
def run(args):
    """Deploy a serial or parallel LAMMPS model from a checkpoint.

    `checkpoint` is resolved as a file path first; otherwise it is treated
    as a pretrained-model name and resolved via sevenn.util.
    """
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    checkpoint = args.checkpoint
    output_prefix = args.output_prefix
    get_parallel = args.get_parallel
    get_serial = not get_parallel
    modal = args.modal

    if output_prefix is None:
        output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial'

    checkpoint_path = None
    if os.path.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        # Raises if the name is not a known pretrained model.
        checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)

    if get_serial:
        deploy(checkpoint_path, output_prefix, modal)
    else:
        deploy_parallel(checkpoint_path, output_prefix, modal)
# legacy way
def main():
    """Stand-alone entry point (legacy `sevenn_get_model` executable)."""
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# Help strings for the 'graph_build' sub-command.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    + 'The dataset will be written under "sevenn_data", '
    + 'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the 'graph_build' sub-command."""
    gb_parser = subparsers.add_parser('graph_build', help=description)
    add_args(gb_parser)
def add_args(parser):
    """Define arguments for building a graph dataset from structure files."""
    parser.add_argument('source', help=source_help, type=str)
    parser.add_argument('cutoff', help=cutoff_help, type=float)
    parser.add_argument(
        '-n', '--num_cores',
        help='number of cores to build graph in parallel',
        default=1,
        type=int,
    )
    parser.add_argument(
        '-o', '--out',
        help='Existing path to write outputs.',
        type=str,
        default='./',
    )
    parser.add_argument(
        '-f', '--filename',
        help=filename_help,
        type=str,
        default='graph.pt',
    )
    parser.add_argument(
        '--legacy',
        help=legacy_help,
        action='store_true',
    )
    parser.add_argument(
        '-s', '--screen',
        help='print log to the screen',
        action='store_true',
    )
    # Everything after --kwargs is captured verbatim (key=value pairs).
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable files.

    Validates paths up front, then delegates to sevenn.scripts.graph_build.
    `--kwargs key=value` pairs are forwarded to the underlying reader.
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split only on the first '=' so values may contain '='
            # themselves (e.g. index=0:10:2 style reader options).
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)

    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # Legacy output path drops the extension of `filename`.
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_graph_build` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
# Help strings for the 'inference' sub-command.
description = (
    'evaluate sevenn_data/ase readable with a model (checkpoint).'
)
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the 'inference' (alias 'inf') sub-command."""
    inf_parser = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(inf_parser)
def add_args(parser):
    """Define arguments for running inference with a checkpoint."""
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b',
        '--batch',
        type=int,
        # Was the string '4', which only worked because argparse re-parses
        # string defaults through `type`; use a real int default.
        default=4,
        help='batch size, useful for GPU',
    )
    ag.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    ag.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    ag.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run inference on all matched targets and write results to `--output`.

    Raises FileExistsError if the output directory already exists, and
    ValueError for incompatible flag combinations.
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Expand each target pattern; targets may contain shell-style globs.
    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split only on the first '=' so values may contain '='.
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_inference` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
import subprocess
from sevenn import __version__
# python wrapper of patch_lammps.sh script
# importlib.resources is correct way to do these things
# but it changes so frequently to use
# Directory bundling the pair-style sources and the patch script.
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the 'patch_lammps' sub-command."""
    patch_parser = subparsers.add_parser('patch_lammps', help=description)
    add_args(patch_parser)
def add_args(parser):
    """Define arguments for patching a LAMMPS source tree."""
    parser.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
    parser.add_argument('--d3', help='Enable D3 support', action='store_true')
    # cxx_standard is detected automatically
def run(args):
    """Invoke patch_lammps.sh with the resolved settings.

    Returns the script's return code (propagated from subprocess).
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # Pass an argument list instead of `' '.join(...).split()` so paths
    # containing whitespace reach the script intact.
    cmd = [script, lammps_dir, cxx_standard, d3_support]
    res = subprocess.run(cmd)
    return res.returncode  # is it meaningless?
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_patch_lammps` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())


if __name__ == '__main__':
    main()
import argparse
import os
from sevenn import __version__
# Help strings for the 'preset' sub-command.
description = (
    'print the selected preset for training. '
    + 'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the 'preset' sub-command."""
    preset_parser = subparsers.add_parser('preset', help=description)
    add_args(preset_parser)
def add_args(parser):
    """Define the single positional argument: the preset name."""
    parser.add_argument(
        'preset',
        choices=[
            'fine_tune',
            'fine_tune_le',
            'sevennet-0',
            'sevennet-l3i5',
            'base',
            'multi_modal',
        ],
        help=preset_help,
    )
def run(args):
    """Print the chosen preset yaml (bundled with the package) to stdout."""
    preset_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    yaml_path = f'{preset_dir}/{args.preset}.yaml'
    with open(yaml_path, 'r') as fp:
        print(fp.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """Stand-alone entry point (legacy `sevenn_preset` executable)."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment