Commit 9993ea25 authored by Jared Casper

Merge branch 'refactor_utils' into 'master'

moved a few common elements between bert and gpt to utils

See merge request ADLR/megatron-lm!4
parents c882ac61 787c1a0b
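As context for the diff below, here is a minimal, hypothetical sketch of how a pretraining script calls the helpers this merge relocates into megatron/utils.py; the `model` object and `args` namespace are placeholders, not code from this merge request.

```python
# Illustrative sketch only, not part of this merge request. `model` and `args`
# stand in for whatever the pretraining script has already built.
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training


def setup_training(model, args):
    # Set up torch.distributed and the model-/data-parallel groups.
    initialize_distributed(args)
    # Seed python, numpy, torch, and the model-parallel CUDA RNG.
    set_random_seed(args.seed)
    # Wrap with torch DDP or the local DDP implementation, per args.DDP_impl.
    return wrap_model_for_distributed_training(model, args)
```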
@@ -22,6 +22,7 @@ import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.fp16 import FP16_Optimizer
from megatron import mpu
from megatron import model
@@ -183,13 +184,67 @@ def report_memory(name):
torch.cuda.max_memory_cached()/ mega_bytes)
print_rank_0(string)
def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def wrap_model_for_distributed_training(model, args):
"""Wrap model for distributed training."""
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
return model
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_checkpoint_name(checkpoints_path, iteration, release=False,
mp_rank=None):
if release:
d = 'release'
else:
d = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, d,
'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None \
else mp_rank),
'model_optim_rng.pt')
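For orientation, a small standalone sketch of the directory layout that get_checkpoint_name composes; the paths and iteration number are made up, and mp_rank is passed explicitly so the snippet runs without an initialized model-parallel group.

```python
import os


def example_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=0):
    # Standalone mirror of get_checkpoint_name above, for illustration only.
    d = 'release' if release else 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mp_rank),
                        'model_optim_rng.pt')


print(example_checkpoint_name('/checkpoints/gpt2', 10000))
# /checkpoints/gpt2/iter_0010000/mp_rank_00/model_optim_rng.pt
print(example_checkpoint_name('/checkpoints/gpt2', 10000, release=True))
# /checkpoints/gpt2/release/mp_rank_00/model_optim_rng.pt
```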
@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR
from megatron.model import BertModel
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as LocalDDP
from megatron import mpu
from apex.optimizers import FusedAdam as Adam
from megatron.utils import Timers
@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0
from megatron.utils import enable_adlr_autoresume
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training
def get_model(args):
"""Build the model."""
@@ -72,18 +75,7 @@ def get_model(args):
_module.float()
# Wrap model for distributed training.
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
model = wrap_model_for_distributed_training(model, args)
return model
@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
return val_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import GPT2Model
from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as LocalDDP
from megatron import mpu
from apex.optimizers import FusedAdam as Adam
from megatron.utils import Timers
@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0
from megatron.utils import enable_adlr_autoresume
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training
from gpt2_data_loader import make_gpt2_dataloaders
@@ -73,18 +75,7 @@ def get_model(args):
model = FP16_Module(model)
# Wrap model for distributed training.
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
model = wrap_model_for_distributed_training(model, args)
return model
@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
return lm_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""