import argparse import itertools import os import re import yaml SBATCH_TEMPLATE = ''' srun --container-image nvcr.io/nvidia/pytorch:24.01-py3 \\ --container-mounts "{}:{},{}:/workspace/megatron-lm" \\ bash -c \" \n{} \" ''' def eval_name(**globals): name_template = globals['name'] to_eval = re.findall("{.*?}", name_template) to_eval = [x.strip('{}') for x in to_eval] str_to_format = re.sub("{.*?}", '{}', name_template) format_contents = [eval(x, globals) for x in to_eval] return str_to_format.format(*format_contents) def save_script(save_dir, format, sbatch_dataset_path, sbatch_mlm_path, **globals): script = globals['script'] globals['name'] = eval_name(**globals) globals['key'] = "basic/" + globals['name'].lower().replace('_', '-') globals['assets_dir'] = f"/assets/{globals['key']}" if format == 'sbatch' and globals['extra_args'] is not None: globals['extra_args'] = globals['extra_args'].replace('"', "'") # gather and evaluate all substitutions marked by braces in script in order of ocurrence to_eval = re.findall("{.*}", script) to_eval = [x.strip('{}') for x in to_eval] str_to_format = re.sub("{.*}", '{}', script) format_contents = [eval(x, globals) for x in to_eval] file_content = str_to_format.format(*format_contents) if not os.path.exists(save_dir): os.mkdir(save_dir) with open(os.path.join(save_dir, globals['name']+".sh"), 'w') as f: f.write("#!/bin/bash\n") if format == 'sbatch': dataset_mount = list(globals['artifacts'].keys())[0] if 'artifacts' in globals else "/path/to/mount/dataset" sbatch_content = SBATCH_TEMPLATE.format(sbatch_dataset_path, dataset_mount, sbatch_mlm_path, file_content) f.write(sbatch_content) else: f.write(file_content) def main(src_yaml, save_dir, format, sbatch_dataset_path, sbatch_mlm_path): # load yaml with open(src_yaml, 'r') as f: raw_content = yaml.safe_load(f) spec_template = raw_content['spec'] for prod in raw_content['products']: config = spec_template.copy() # expand cartesian products into list of all config overrides for replace in itertools.product(*prod.values()): # update config dict with overrides from products config.update({k: v for k, v in zip(prod.keys(), replace)}) save_script(save_dir, format, sbatch_dataset_path, sbatch_mlm_path, **config) if __name__ == "__main__": parser = argparse.ArgumentParser( prog='Functional tests script generator', description="""Generates bash or sbatch scripts from yamls in this directory to run functional tests locally""") parser.add_argument('src_yaml', help="Yaml file in this directory from which to generate test scripts") parser.add_argument('--save_dir', required=False, default='./scripts', help='Directory where scripts will be saved to. Defaults to ./scripts') parser.add_argument('--format', required=False, default='bash', choices=['bash', 'sbatch'], help="Script format") parser.add_argument('--sbatch-dataset-path', required=False, default='/path/to/dataset') parser.add_argument('--sbatch-megatronlm-path', required=False, default='/path/to/megatron-lm') args = parser.parse_args() main(args.src_yaml, args.save_dir, args.format, args.sbatch_dataset_path, args.sbatch_megatronlm_path)