"dockerfile/cuda12.4.dockerfile" did not exist on "115cd2e6ae9d7fc93bb03fb5bbdb27a731a628d2"
Unverified Commit ca86f720 authored by zcxzcx1's avatar zcxzcx1 Committed by GitHub
Browse files

Add files via upload

parent b75ed73c
import argparse
import os
import sys
import time
from sevenn import __version__
# CLI help strings for the `train` sub-command parser.
description = 'train a model given the input.yaml'
input_yaml_help = 'input.yaml for training'
mode_help = 'main training script to run. Default is train.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'
# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),
    '_model_type': 'E3_equivariant_model',
}
def run(args):
    """
    main function of sevenn

    Reads the training input yaml, sets up (optionally distributed)
    training state, seeds the RNGs and dispatches to the selected
    training script.

    Args:
        args: argparse.Namespace produced by the `train` sub-command parser.

    Raises:
        ImportError: if --enable_cueq is set but cuEquivariance is missing.
        ValueError: on an unknown distributed backend.
        SystemExit: when the input yaml cannot be parsed (exit code 1).
    """
    import random
    import sys

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq

    if use_cue:
        # Fail fast before any expensive setup if cuEquivariance is absent.
        import sevenn.nn.cue_helper
        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')

    if working_dir is None:
        working_dir = os.getcwd()
    elif not os.path.isdir(working_dir):
        os.makedirs(working_dir, exist_ok=True)

    world_size = 1
    if distributed:
        # Rank/world-size env variables differ between torchrun-style
        # launchers (nccl) and OpenMPI (mpi).
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')
        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1

    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()
        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )
        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            # Fixed grammar of the user-facing message
            # (was: 'Failed to parsing input.yaml').
            logger.writeline('Failed to parse input.yaml')
            logger.error(e)
            sys.exit(1)

        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size
        if distributed:
            torch.cuda.set_device(torch.device('cuda', local_rank))

        if use_cue:
            # Force-enable cuEquivariance in the model config.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})

        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)

        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise Exception('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)

        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)

        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)
def cmd_parser_train(parser):
    """Attach every `train` sub-command argument to *parser*."""
    parser.add_argument('input_yaml', type=str, help=input_yaml_help)
    parser.add_argument(
        '-m', '--mode',
        type=str,
        default='train_v2',
        choices=['train_v1', 'train_v2'],
        help=mode_help,
    )
    parser.add_argument(
        '-cueq', '--enable_cueq',
        action='store_true',
        help='(Not stable!) use cuEquivariance for training',
    )
    parser.add_argument(
        '-w', '--working_dir',
        type=str,
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
    )
    parser.add_argument(
        '-l', '--log',
        type=str,
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
    )
    parser.add_argument('-s', '--screen', action='store_true', help=screen_help)
    parser.add_argument(
        '-d', '--distributed', action='store_true', help=distributed_help
    )
    parser.add_argument(
        '--distributed_backend',
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
        help=distributed_backend_help,
    )
def add_parser(subparsers):
    """Register the `train` sub-command on *subparsers*."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)
def set_default_subparser(self, name, args=None, positional_args=0):
    """default subparser selection. Call after setup, just before parse_args()
    name: is the name of the subparser to call by default
    args: if set is the argument list handed to parse_args()
    Hack copied from stack overflow
    """
    subparser_found = False
    # If a global -h/--help flag appears, do nothing so argparse prints
    # the top-level help rather than the default subparser's.
    for arg in sys.argv[1:]:
        if arg in ['-h', '--help']:  # global help if no subparser
            break
    else:
        # NOTE(review): walks argparse private internals
        # (_subparsers._actions, _name_parser_map); may break across
        # Python/argparse versions — confirm on interpreter upgrades.
        for x in self._subparsers._actions:
            if not isinstance(x, argparse._SubParsersAction):
                continue
            for sp_name in x._name_parser_map.keys():
                if sp_name in sys.argv[1:]:
                    subparser_found = True
    if not subparser_found:
        # insert default in last position before global positional
        # arguments, this implies no global options are specified after
        # first positional argument
        if args is None:
            sys.argv.insert(len(sys.argv) - positional_args, name)
        else:
            args.insert(len(args) - positional_args, name)


# Monkey-patch ArgumentParser so any instance can set a default subparser.
argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore
def main():
    """Top-level CLI: register all sub-commands, parse, and dispatch."""
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')
    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')

    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    # Fall back to 'train' when no sub-command is given (legacy CLI shape).
    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()
    if args.command is None:  # backward compatibility
        args.command = 'train'

    # NOTE(review): only 'train' and 'preset' are dispatched here; the
    # other registered sub-commands (checkpoint, inference, graph_build,
    # get_model, patch_lammps) fall through silently — confirm whether
    # they are dispatched elsewhere or this is intentional.
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)
# Allow direct execution of this module (legacy invocation path).
if __name__ == '__main__':
    main()
import argparse
import os.path as osp
from sevenn import __version__
# Help text for the `checkpoint` (alias `cp`) sub-command.
description = (
    'tool box for sevennet checkpoints'
)
def add_parser(subparsers):
    """Register the `checkpoint` sub-command (alias `cp`) on *subparsers*."""
    cp_parser = subparsers.add_parser('checkpoint', help=description, aliases=['cp'])
    add_args(cp_parser)
def add_args(parser):
    """Attach checkpoint-toolbox arguments to *parser*."""
    parser.add_argument('checkpoint', type=str, help='checkpoint or pretrained')
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        '--get_yaml',
        type=str,
        choices=['reproduce', 'continue', 'continue_modal'],
        help='create input.yaml based on the given checkpoint',
    )
    exclusive.add_argument(
        '--append_modal_yaml',
        type=str,
        help='append modality with given yaml.',
    )
    parser.add_argument(
        '--original_modal_name',
        type=str,
        default='origin',
        help=(
            'when the append_modal is used and checkpoint is not multi-modal, '
            'used to name previously trained modality. defaults to "origin"'
        ),
    )
def run(args):
    """Checkpoint tool box: dump yaml, append a modality, or print summary.

    With --get_yaml, prints an input.yaml derived from the checkpoint.
    With --append_modal_yaml, appends a modality and saves a new checkpoint
    file `checkpoint_modal_appended.pth` in the cwd. With neither, prints
    the checkpoint summary.

    Raises:
        FileNotFoundError: if the yaml given to --append_modal_yaml is missing.
    """
    import torch
    import yaml

    from sevenn.parse_input import read_config_yaml
    from sevenn.util import load_checkpoint

    checkpoint = load_checkpoint(args.checkpoint)
    if args.get_yaml:
        mode = args.get_yaml
        cfg = checkpoint.yaml_dict(mode)
        print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False))
    elif args.append_modal_yaml:
        dst_yaml = args.append_modal_yaml
        if not osp.exists(dst_yaml):
            raise FileNotFoundError(f'No yaml file {dst_yaml}')
        dst_config = read_config_yaml(dst_yaml, return_separately=False)
        model_state_dict = checkpoint.append_modal(
            dst_config, args.original_modal_name
        )
        to_save = checkpoint.get_checkpoint_dict()
        to_save.update({'config': dst_config, 'model_state_dict': model_state_dict})
        torch.save(to_save, 'checkpoint_modal_appended.pth')
        print('checkpoint_modal_appended.pth is successfully saved.')
        # Fixed typos in the user-facing hint
        # (was: 'as blow (recommend)').
        print(f'update continue of {dst_yaml} as below (recommended) to continue')
        cont_dct = {
            'continue': {
                'checkpoint': 'checkpoint_modal_appended.pth',
                'reset_epoch': True,
                'reset_optimizer': True,
                'reset_scheduler': True,
            }
        }
        print(
            yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False)
        )
    else:
        print(checkpoint)
def main(args=None):
    """Entry point for the legacy `sevenn_cp` console script."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
from sevenn import __version__
# Help strings for the `get_model` (alias `deploy`) sub-command.
description_get_model = (
    'deploy LAMMPS model from the checkpoint'
)
checkpoint_help = (
    'path to the checkpoint | SevenNet-0 | 7net-0 |'
    ' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'
def add_parser(subparsers):
    """Register the `get_model` sub-command (alias `deploy`) on *subparsers*."""
    deploy_parser = subparsers.add_parser(
        'get_model', help=description_get_model, aliases=['deploy']
    )
    add_args(deploy_parser)
def add_args(parser):
    """Attach model-deployment arguments to *parser*."""
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument(
        '-o', '--output_prefix', type=str, nargs='?', help=output_name_help
    )
    parser.add_argument(
        '-p', '--get_parallel', action='store_true', help=get_parallel_help
    )
    parser.add_argument(
        '-m', '--modal', type=str, help='Modality of multi-modal model'
    )
def run(args):
    """Deploy a LAMMPS-ready model (serial or parallel) from a checkpoint."""
    import sevenn.util
    from sevenn.scripts.deploy import deploy, deploy_parallel

    want_parallel = args.get_parallel
    prefix = args.output_prefix
    if prefix is None:
        prefix = 'deployed_parallel' if want_parallel else 'deployed_serial'

    # Treat the argument as a file path first, then as a pretrained name.
    if os.path.isfile(args.checkpoint):
        checkpoint_path = args.checkpoint
    else:
        checkpoint_path = sevenn.util.pretrained_name_to_path(args.checkpoint)

    if want_parallel:
        deploy_parallel(checkpoint_path, prefix, args.modal)
    else:
        deploy(checkpoint_path, prefix, args.modal)
# legacy way
def main():
    """Entry point for the legacy `sevenn_get_model` console script."""
    parser = argparse.ArgumentParser(description=description_get_model)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
from datetime import datetime
from sevenn import __version__
# Help strings for the `graph_build` sub-command.
description = 'create `sevenn_data/dataset.pt` from ase readable'
source_help = 'source data to build graph, knows *'
cutoff_help = 'cutoff radius of edges in Angstrom'
filename_help = (
    'Name of the dataset, default is graph.pt. '
    + 'The dataset will be written under "sevenn_data", '
    + 'for example, {out}/sevenn_data/graph.pt.'
)
legacy_help = 'build legacy .sevenn_data'
def add_parser(subparsers):
    """Register the `graph_build` sub-command on *subparsers*."""
    gb_parser = subparsers.add_parser('graph_build', help=description)
    add_args(gb_parser)
def add_args(parser):
    """Attach graph-building arguments to *parser*."""
    parser.add_argument('source', type=str, help=source_help)
    parser.add_argument('cutoff', type=float, help=cutoff_help)
    parser.add_argument(
        '-n', '--num_cores',
        type=int,
        default=1,
        help='number of cores to build graph in parallel',
    )
    parser.add_argument(
        '-o', '--out',
        type=str,
        default='./',
        help='Existing path to write outputs.',
    )
    parser.add_argument(
        '-f', '--filename',
        type=str,
        default='graph.pt',
        help=filename_help,
    )
    parser.add_argument('--legacy', action='store_true', help=legacy_help)
    parser.add_argument(
        '-s', '--screen',
        action='store_true',
        help='print log to the screen',
    )
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to ase.io.read, or can be used to specify EFS key',
    )
def run(args):
    """Build a graph dataset from ase-readable source files.

    Globs args.source, validates the output location, then builds either
    the modern `sevenn_data/<filename>` dataset or (with --legacy) the
    old `.sevenn_data` format.

    Raises:
        NotADirectoryError: if args.out is not an existing directory.
        FileExistsError: if the dataset file already exists.
        SystemExit: when the source glob matches nothing (exit code 0).
    """
    import sevenn.scripts.graph_build as graph_build
    from sevenn.logger import Logger

    source = glob.glob(args.source)
    cutoff = args.cutoff
    num_cores = args.num_cores
    filename = args.filename
    out = args.out
    legacy = args.legacy

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split on the first '=' only so values that themselves
            # contain '=' survive (plain split('=') raised ValueError).
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if len(source) == 0:
        print('Source has zero len, nothing to read')
        sys.exit(0)

    if not os.path.isdir(out):
        raise NotADirectoryError(f'No such directory: {out}')
    to_be_written = os.path.join(out, 'sevenn_data', filename)
    if os.path.isfile(to_be_written):
        raise FileExistsError(f'File already exist: {to_be_written}')

    metadata = {
        'sevenn_version': __version__,
        'when': datetime.now().strftime('%Y-%m-%d'),
        'cutoff': cutoff,
    }
    with Logger(filename=None, screen=args.screen) as logger:
        logger.writeline(description)
        if not legacy:
            graph_build.build_sevennet_graph_dataset(
                source,
                cutoff,
                num_cores,
                out,
                filename,
                metadata,
                **fmt_kwargs,
            )
        else:
            # Legacy format writes <out>/<filename-without-extension>.
            out = os.path.join(out, filename.split('.')[0])
            graph_build.build_script(  # build .sevenn_data
                source,
                cutoff,
                num_cores,
                out,
                metadata,
                **fmt_kwargs,
            )
def main(args=None):
    """Entry point for the legacy `sevenn_graph_build` console script."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import glob
import os
import sys
# Help strings for the `inference` (alias `inf`) sub-command.
description = (
    'evaluate sevenn_data/ase readable with a model (checkpoint).'
)
checkpoint_help = 'Checkpoint or pre-trained model name'
target_help = 'Target files to evaluate'
def add_parser(subparsers):
    """Register the `inference` sub-command (alias `inf`) on *subparsers*."""
    inf_parser = subparsers.add_parser('inference', help=description, aliases=['inf'])
    add_args(inf_parser)
def add_args(parser):
    """Attach inference arguments to *parser*.

    Fix: --batch previously used default='4' (a string) with type=int;
    argparse happens to coerce string defaults through `type`, but a
    plain int default is the correct, unambiguous form.
    """
    ag = parser
    ag.add_argument('checkpoint', type=str, help=checkpoint_help)
    ag.add_argument('targets', type=str, nargs='+', help=target_help)
    ag.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x',
    )
    ag.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    ag.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    ag.add_argument(
        '-b',
        '--batch',
        type=int,
        default=4,  # was '4' (str); int avoids relying on argparse coercion
        help='batch size, useful for GPU',
    )
    ag.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    ag.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    ag.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    ag.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help='will be passed to reader, or can be used to specify EFS key',
    )
def run(args):
    """Run model inference over glob-expanded targets.

    Validates the output directory and checkpoint, resolves the device,
    collects reader kwargs, then delegates to
    sevenn.scripts.inference.inference.

    Raises:
        FileExistsError: if the output directory already exists.
        ValueError: if both --save_graph and --allow_unlabeled are set.
        SystemExit: when no targets match (exit code 0).
    """
    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    device = args.device
    if device == 'auto':
        # Prefer GPU when one is visible.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    targets = []
    for target in args.targets:
        targets.extend(glob.glob(target))
    if len(targets) == 0:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        cp = pretrained_name_to_path(cp)  # raises value error

    fmt_kwargs = {}
    if args.kwargs:
        for kwarg in args.kwargs:
            # Split on the first '=' only so values that themselves
            # contain '=' survive (plain split('=') raised ValueError).
            k, v = kwarg.split('=', 1)
            fmt_kwargs[k] = v

    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )
def main(args=None):
    """Entry point for the legacy `sevenn_inference` console script."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
import argparse
import os
import subprocess
from sevenn import __version__
# python wrapper of patch_lammps.sh script
# importlib.resources is correct way to do these things
# but it changes so frequently to use
# Absolute path of the bundled pair_e3gnn sources (holds patch_lammps.sh).
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')
description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile'
def add_parser(subparsers):
    """Register the `patch_lammps` sub-command on *subparsers*."""
    patch_parser = subparsers.add_parser('patch_lammps', help=description)
    add_args(patch_parser)
def add_args(parser):
    """Attach `patch_lammps` CLI arguments to *parser*."""
    parser.add_argument('lammps_dir', type=str, help='Path to LAMMPS source')
    parser.add_argument('--d3', action='store_true', help='Enable D3 support')
    # cxx_standard is detected automatically
def run(args):
    """Patch a LAMMPS source tree with the e3gnn pair-styles.

    Invokes the bundled patch_lammps.sh with the resolved LAMMPS path,
    the C++ standard, and the D3 flag.

    Returns:
        int: return code of the patch script.
    """
    lammps_dir = os.path.abspath(args.lammps_dir)
    print('Patching LAMMPS with the following settings:')
    print(' - LAMMPS source directory:', lammps_dir)
    cxx_standard = '17'  # always 17
    if args.d3:
        d3_support = '1'
        print(' - D3 support enabled')
    else:
        d3_support = '0'
        print(' - D3 support disabled')
    script = f'{pair_e3gnn_dir}/patch_lammps.sh'
    # Pass argv as a list so paths containing spaces survive intact
    # (the old f-string + str.split() broke such paths into pieces).
    res = subprocess.run([script, lammps_dir, cxx_standard, d3_support])
    return res.returncode  # is it meaningless?
def main(args=None):
    """Entry point for the legacy `sevenn_patch_lammps` console script."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
# Allow direct execution: `python sevenn_patch_lammps.py <lammps_dir>`.
if __name__ == '__main__':
    main()
import argparse
import os
from sevenn import __version__
# Help strings for the `preset` sub-command.
description = (
    'print the selected preset for training. '
    + 'ex) sevennet_preset fine_tune > my_input.yaml'
)
preset_help = 'Name of preset'
def add_parser(subparsers):
    """Register the `preset` sub-command on *subparsers*."""
    preset_parser = subparsers.add_parser('preset', help=description)
    add_args(preset_parser)
def add_args(parser):
    """Attach the preset-name positional argument to *parser*."""
    parser.add_argument(
        'preset',
        choices=[
            'fine_tune',
            'fine_tune_le',
            'sevennet-0',
            'sevennet-l3i5',
            'base',
            'multi_modal',
        ],
        help=preset_help,
    )
def run(args):
    """Print the chosen preset input.yaml to stdout."""
    preset_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')
    with open(f'{preset_dir}/{args.preset}.yaml', 'r') as file:
        print(file.read())
# When executed as sevenn_preset (legacy way)
def main(args=None):
    """Entry point for the legacy `sevenn_preset` console script."""
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    run(parser.parse_args())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment