"""Megatron initialization."""

from datetime import timedelta

import torch

from megatron.core import mpu
from megatron.training import get_args


def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
    """Initialize torch.distributed and Megatron-Core model-parallel state.

    Idempotent with respect to both torch.distributed (skips process-group
    creation when one already exists, syncing ``args.rank``/``args.world_size``
    from it) and mpu model-parallel state (skips re-initialization).

    Args:
        get_embedding_ranks: Callable forwarded unchanged to
            ``mpu.initialize_model_parallel`` to select embedding ranks.
        get_position_embedding_ranks: Callable forwarded unchanged to
            ``mpu.initialize_model_parallel`` to select position-embedding
            ranks.
    """
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        # Trust the live process group over whatever was parsed into args.
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
    else:
        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)

        # Manually set the device id so each rank is pinned to one local GPU
        # before the process group (and hence NCCL communicators) is created.
        if device_count > 0:
            torch.cuda.set_device(args.local_rank)
        # NOTE(review): the original also built an unused
        # ``device_id = torch.device(f'cuda:{args.local_rank}')`` here; it was
        # never passed anywhere (dead local), so it is removed. On
        # torch >= 2.3 it could be forwarded as
        # ``init_process_group(device_id=...)`` to eagerly bind the NCCL
        # communicator — confirm desired torch floor before adding that.

        # Call the init process.
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
            'world_size': args.world_size,
            'rank': args.rank,
            'init_method': args.dist_url,
            'timeout': timedelta(minutes=args.distributed_timeout_minutes),
        }
        torch.distributed.init_process_group(**init_process_group_kwargs)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
                num_distributed_optimizer_instances=args.num_distributed_optimizer_instances,
                expert_tensor_parallel_size=args.expert_tensor_parallel_size,
                distributed_timeout_minutes=args.distributed_timeout_minutes,
                nccl_communicator_config_path=args.nccl_communicator_config_path,
                # tp-cp-ep-pp-dp swaps the outer two dims when the user opts
                # into the alternate TP/PP/DP rank mapping.
                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
                encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
                encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )