import argparse
import datetime
import json
import os
import warnings

import omegaconf
import torch
import torch.distributed
from omegaconf import OmegaConf

from sat import mpu
from sat.arguments import (add_data_args, add_evaluation_args,
                           add_training_args, set_random_seed)
from sat.helpers import print_rank0


def add_model_config_args(parser):
    """Model arguments"""
    group = parser.add_argument_group('model', 'model configuration')
    group.add_argument('--base',
                       type=str,
                       nargs='*',
                       help='config for input and saving')
    group.add_argument('--second',
                       type=str,
                       nargs='*',
                       help='config for input and saving')
    group.add_argument(
        '--model-parallel-size',
        type=int,
        default=1,
        help='size of the model parallel. only use if you are an expert.')
    group.add_argument('--force-pretrain', action='store_true')
    group.add_argument('--device', type=int, default=-1)
    group.add_argument('--debug', action='store_true')
    # note: argparse's `type=bool` treats any non-empty string (including
    # "False") as True; pass an empty string to disable.
    group.add_argument('--log-image', type=bool, default=True)
    group.add_argument('--inf-ckpt', type=str, default=None)
    group.add_argument('--inf-ckpt2', type=str, default=None)
    group.add_argument('--skip-second', action='store_true')
    group.add_argument('--first-stage-re',
                       type=int,
                       default=270,
                       help='resolution of first stage')
    return parser


def add_sampling_config_args(parser):
    """Sampling configurations"""
    group = parser.add_argument_group('sampling', 'Sampling Configurations')
    group.add_argument('--output-dir', type=str, default='samples')
    group.add_argument('--input-dir', type=str, default=None)
    group.add_argument('--input-type', type=str, default='cli')
    group.add_argument('--input-file', type=str, default='./ht100.txt')
    group.add_argument('--final-size', type=int, default=2048)
    group.add_argument('--sdedit', action='store_true')
    group.add_argument('--grid-num-rows', type=int, default=1)
    group.add_argument('--force-inference', action='store_true')
    group.add_argument('--lcm_steps', type=int, default=None)
    group.add_argument('--sampling-num-frames', type=int, default=32)
    group.add_argument('--sampling-num-steps', type=int, default=30)
    group.add_argument('--sampling-fps', type=int, default=8)
    group.add_argument('--only-save-latents', type=bool, default=False)
    group.add_argument('--only-log-video-latents', type=bool, default=False)
    group.add_argument('--latent-channels', type=int, default=32)
    group.add_argument('--image2video', action='store_true')
    return parser


def add_extra_config_args(parser):
    """Joint image/video training configurations"""
    group = parser.add_argument_group('joint', 'joint training Configurations')
    group.add_argument('--img-iter', type=int, default=0)
    group.add_argument('--video-iter', type=int, default=0)
    return parser
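
# A hypothetical invocation, for illustration only (the script name and the
# YAML paths are placeholders, not files shipped with this repo):
#
#   python sample_video.py \
#       --base configs/model.yaml configs/inference.yaml \
#       --output-dir samples
#
# Every YAML passed via --base is merged in order by process_config_to_args()
# at the bottom of this file; keys under its `args:` section override the
# argparse defaults registered above.
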
def get_args(args_list=None, parser=None):
    """Parse all the args."""
    if parser is None:
        parser = argparse.ArgumentParser(description='sat')
    else:
        assert isinstance(parser, argparse.ArgumentParser)
    parser = add_model_config_args(parser)
    parser = add_sampling_config_args(parser)
    parser = add_training_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_data_args(parser)
    parser = add_extra_config_args(parser)

    # imported lazily so that deepspeed is only required when parsing args
    import deepspeed
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args(args_list)
    args = process_config_to_args(args)

    if not args.train_data:
        print_rank0('No training data specified', level='WARNING')

    assert (args.train_iters is None) or (
        args.epochs is None), 'at most one of train_iters and epochs can be set.'
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0(
            'No train_iters (recommended) or epochs specified; using default 10k iters.',
            level='WARNING')

    args.cuda = torch.cuda.is_available()

    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv('WORLD_SIZE', '1'))
    if args.local_rank is None:
        args.local_rank = int(os.getenv('LOCAL_RANK', '0'))  # torchrun

    if args.device == -1:
        if torch.cuda.device_count() == 0:
            args.device = 'cpu'
        elif args.local_rank is not None:
            args.device = args.local_rank
        else:
            args.device = args.rank % torch.cuda.device_count()

    if args.local_rank != args.device and args.mode != 'inference':
        raise ValueError(
            'LOCAL_RANK (default 0) and args.device are inconsistent. '
            'This can only happen in inference mode. '
            'Please use CUDA_VISIBLE_DEVICES=x for single-GPU training.')

    if args.rank == 0:
        print_rank0(f'using world size: {args.world_size}')

    if args.train_data_weights is not None:
        assert len(args.train_data_weights) == len(args.train_data)

    if args.mode != 'inference':  # training with deepspeed
        args.deepspeed = True
        if args.deepspeed_config is None:  # not specified
            deepspeed_config_path = os.path.join(
                os.path.dirname(__file__), 'training',
                f'deepspeed_zero{args.zero_stage}.json')
            with open(deepspeed_config_path) as file:
                args.deepspeed_config = json.load(file)
            override_deepspeed_config = True
        else:
            override_deepspeed_config = False

    assert not (args.fp16 and args.bf16), 'cannot specify both fp16 and bf16.'

    if args.zero_stage > 0 and not args.fp16 and not args.bf16:
        print_rank0('Automatically set fp16=True to use ZeRO.')
        args.fp16 = True
        args.bf16 = False

    if args.deepspeed:
        if args.checkpoint_activations:
            args.deepspeed_activation_checkpointing = True
        else:
            args.deepspeed_activation_checkpointing = False
        if args.deepspeed_config is not None:
            deepspeed_config = args.deepspeed_config

            if override_deepspeed_config:  # deepspeed_config not specified, fill it from args
                if args.fp16:
                    deepspeed_config['fp16']['enabled'] = True
                elif args.bf16:
                    deepspeed_config['bf16']['enabled'] = True
                    deepspeed_config['fp16']['enabled'] = False
                else:
                    deepspeed_config['fp16']['enabled'] = False
                deepspeed_config[
                    'train_micro_batch_size_per_gpu'] = args.batch_size
                deepspeed_config[
                    'gradient_accumulation_steps'] = args.gradient_accumulation_steps
                optimizer_params_config = deepspeed_config['optimizer']['params']
                optimizer_params_config['lr'] = args.lr
                optimizer_params_config['weight_decay'] = args.weight_decay
            else:  # override args with values in deepspeed_config
                if args.rank == 0:
                    print_rank0(
                        'Will override arguments with manually specified deepspeed_config!'
                    )
                if 'fp16' in deepspeed_config and deepspeed_config['fp16'][
                        'enabled']:
                    args.fp16 = True
                else:
                    args.fp16 = False
                if 'bf16' in deepspeed_config and deepspeed_config['bf16'][
                        'enabled']:
                    args.bf16 = True
                else:
                    args.bf16 = False
                if 'train_micro_batch_size_per_gpu' in deepspeed_config:
                    args.batch_size = deepspeed_config[
                        'train_micro_batch_size_per_gpu']
                if 'gradient_accumulation_steps' in deepspeed_config:
                    args.gradient_accumulation_steps = deepspeed_config[
                        'gradient_accumulation_steps']
                else:
                    args.gradient_accumulation_steps = None
                if 'optimizer' in deepspeed_config:
                    optimizer_params_config = deepspeed_config['optimizer'].get(
                        'params', {})
                    args.lr = optimizer_params_config.get('lr', args.lr)
                    args.weight_decay = optimizer_params_config.get(
                        'weight_decay', args.weight_decay)
            args.deepspeed_config = deepspeed_config

    # initialize distributed and random seed because it always seems to be necessary.
    initialize_distributed(args)
    args.seed = args.seed + mpu.get_data_parallel_rank()
    set_random_seed(args.seed)
    return args
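
# Shape of the deepspeed config dict that get_args() reads and writes above
# (a minimal sketch with placeholder values; only the keys actually touched
# by the override logic are shown):
#
#   {
#       "train_micro_batch_size_per_gpu": 4,
#       "gradient_accumulation_steps": 1,
#       "fp16": {"enabled": true},
#       "bf16": {"enabled": false},
#       "optimizer": {"params": {"lr": 1e-4, "weight_decay": 1e-2}}
#   }
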
def initialize_distributed(args):
    """Initialize torch.distributed."""
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError(
                    'model_parallel_size is inconsistent with prior configuration. '
                    'We currently do not support changing model_parallel_size.')
            return False
        else:
            if args.model_parallel_size > 1:
                warnings.warn(
                    'model_parallel_size > 1 but torch.distributed is not initialized via SAT. '
                    'Please carefully verify correctness on your own.')
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True
    # the automatic assignment of devices has been moved to arguments.py
    if args.device == 'cpu':
        pass
    else:
        torch.cuda.set_device(args.device)
    # Call the init process
    init_method = 'tcp://'
    args.master_ip = os.getenv('MASTER_ADDR', 'localhost')

    if args.world_size == 1:
        from sat.helpers import get_free_port
        default_master_port = str(get_free_port())
    else:
        default_master_port = '6000'
    args.master_port = os.getenv('MASTER_PORT', default_master_port)
    init_method = None
    # init_method += args.master_ip + ":" + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=init_method,
        timeout=datetime.timedelta(seconds=1200))

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Set vae context parallel group equal to model parallel group
    from sgm.util import (initialize_context_parallel,
                          set_context_parallel_group)
    if args.model_parallel_size <= 2:
        set_context_parallel_group(args.model_parallel_size,
                                   mpu.get_model_parallel_group())
    else:
        initialize_context_parallel(2)
    # mpu.initialize_model_parallel(1)
    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed:
        import deepspeed
        deepspeed.init_distributed(dist_backend=args.distributed_backend,
                                   world_size=args.world_size,
                                   rank=args.rank,
                                   init_method=init_method)
        # It seems that it has no negative influence to configure it even
        # without using checkpointing.
        # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # In model-only mode we don't want to init deepspeed, but we still need
        # to init the rng tracker for model parallel, because we save the seed
        # by default for dropout.
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import (
                _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME)
            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                        1)  # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0
            print_rank0(str(e), level='DEBUG')
    return True
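
# Environment variables consumed by get_args() and initialize_distributed(),
# normally set by a launcher such as torchrun (listed here for reference):
#
#   RANK         global rank of this process (default '0')
#   WORLD_SIZE   total number of processes (default '1')
#   LOCAL_RANK   rank within the local node (default '0')
#   MASTER_ADDR  address of the rank-0 host (default 'localhost')
#   MASTER_PORT  rendezvous port (default: a free port if world_size == 1,
#                otherwise '6000')
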
def process_config_to_args(args):
    """Fetch args from the --base config files."""
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    config = OmegaConf.merge(*configs)

    args_config = config.pop('args', OmegaConf.create())
    for key in args_config:
        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
                args_config[key], omegaconf.ListConfig):
            arg = OmegaConf.to_object(args_config[key])
        else:
            arg = args_config[key]
        if hasattr(args, key):
            setattr(args, key, arg)

    if 'model' in config:
        model_config = config.pop('model', OmegaConf.create())
        args.model_config = model_config
    if 'deepspeed' in config:
        deepspeed_config = config.pop('deepspeed', OmegaConf.create())
        args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
    if 'data' in config:
        data_config = config.pop('data', OmegaConf.create())
        args.data_config = data_config
    if 'img_data' in config:
        img_data_config = config.pop('img_data', OmegaConf.create())
        args.img_data_config = img_data_config
    if 'video_data' in config:
        video_data_config = config.pop('video_data', OmegaConf.create())
        args.video_data_config = video_data_config
    if 'trainable_params' in config:
        trainable_params = config.pop('trainable_params', OmegaConf.create())
        args.trainable_params_config = trainable_params
    if 'share_cache_args' in config:
        share_cache_args = config.pop('share_cache_args', OmegaConf.create())
        args.share_cache_config = share_cache_args
    if 'custom_args' in config:
        custom_args = config.pop('custom_args', OmegaConf.create())
        for k, v in custom_args.items():
            setattr(args, k, v)

    return args
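
# Example layout of a --base YAML, mirroring the sections handled by
# process_config_to_args() above (a sketch; all values are placeholders):
#
#   args:                  # each key overrides the matching argparse attribute
#     mode: inference
#     batch_size: 4
#   model: ...             # stored verbatim as args.model_config
#   deepspeed: ...         # converted to a plain dict -> args.deepspeed_config
#   data: ...              # stored as args.data_config
#   img_data: ...          # stored as args.img_data_config
#   video_data: ...        # stored as args.video_data_config
#   trainable_params: ...  # stored as args.trainable_params_config
#   share_cache_args: ...  # stored as args.share_cache_config
#   custom_args:           # each key is set on args even if argparse lacks it
#     my_flag: true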