Commit 9993ea25 authored by Jared Casper

Merge branch 'refactor_utils' into 'master'

Moved a few elements common to the BERT and GPT-2 pretraining scripts into utils

See merge request ADLR/megatron-lm!4
parents c882ac61 787c1a0b
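In rough terms, the three helpers that move into megatron/utils.py let both pretraining scripts share their setup code. Below is a minimal sketch of the resulting call pattern; the helper names and the `args` fields they read (rank, local_rank, world_size, distributed_backend, model_parallel_size, DDP_impl, seed) come from the diff, while the `setup` wrapper and `build_model` callback are hypothetical illustrations, not code from this merge request.

```python
# Sketch of the shared setup path after this refactor; `setup` and
# `build_model` are hypothetical, the imported helpers are from the diff below.
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training


def setup(args, build_model):
    # Initialize torch.distributed and the model-/data-parallel groups.
    initialize_distributed(args)
    # Seed python, numpy, torch, and the model-parallel CUDA RNG state.
    set_random_seed(args.seed)
    # Build the network, then wrap it with torch DDP or the local DDP
    # implementation depending on args.DDP_impl ('torch' or 'local').
    model = build_model(args)
    return wrap_model_for_distributed_training(model, args)
```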
megatron/utils.py
@@ -22,6 +22,7 @@ import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.fp16 import FP16_Optimizer
 from megatron import mpu
 from megatron import model
@@ -183,13 +184,67 @@ def report_memory(name):
         torch.cuda.max_memory_cached()/ mega_bytes)
     print_rank_0(string)


-def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
+def initialize_distributed(args):
+    """Initialize torch.distributed."""
+
+    # Manually set the device ids.
+    device = args.rank % torch.cuda.device_count()
+    if args.local_rank is not None:
+        device = args.local_rank
+    torch.cuda.set_device(device)
+
+    # Call the init process
+    init_method = 'tcp://'
+    master_ip = os.getenv('MASTER_ADDR', 'localhost')
+    master_port = os.getenv('MASTER_PORT', '6000')
+    init_method += master_ip + ':' + master_port
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        init_method=init_method)
+
+    # Set the model-parallel / data-parallel communicators.
+    mpu.initialize_model_parallel(args.model_parallel_size)
+
+
+def wrap_model_for_distributed_training(model, args):
+    """Wrap model for distributed training."""
+    if args.DDP_impl == 'torch':
+        i = torch.cuda.current_device()
+        args.DDP_type = torchDDP
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+        return model
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
+        return model
+    else:
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducability."""
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False,
+                        mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
+                        'mp_rank_{:02d}'.format(
+                            mpu.get_model_parallel_rank() if mp_rank is None \
+                            else mp_rank),
                         'model_optim_rng.pt')
...
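For reference, the reformatted `get_checkpoint_name` above still assembles the same layout; a quick example with hypothetical arguments (the checkpoint directory and numbers are illustrative only):

```python
# Hypothetical call; '/checkpoints/bert', the iteration, and mp_rank are made up.
from megatron.utils import get_checkpoint_name

get_checkpoint_name('/checkpoints/bert', iteration=5000, mp_rank=3)
# -> '/checkpoints/bert/iter_0005000/mp_rank_03/model_optim_rng.pt'
```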
...@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR ...@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR
from megatron.model import BertModel from megatron.model import BertModel
from megatron.model import get_params_for_weight_decay_optimization from megatron.model import get_params_for_weight_decay_optimization
from megatron.model import gpt2_get_params_for_weight_decay_optimization from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as LocalDDP
from megatron import mpu from megatron import mpu
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron.utils import Timers from megatron.utils import Timers
...@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm ...@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0 from megatron.utils import print_rank_0
from megatron.utils import enable_adlr_autoresume from megatron.utils import enable_adlr_autoresume
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training
def get_model(args): def get_model(args):
"""Build the model.""" """Build the model."""
...@@ -72,18 +75,7 @@ def get_model(args): ...@@ -72,18 +75,7 @@ def get_model(args):
_module.float() _module.float()
# Wrap model for distributed training. # Wrap model for distributed training.
if args.DDP_impl == 'torch': model = wrap_model_for_distributed_training(model, args)
i = torch.cuda.current_device()
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model return model
...@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model, ...@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
return val_loss return val_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args): def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS.""" """Load the data on rank zero and boradcast number of tokens to all GPUS."""
......
...@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer ...@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR from megatron.learning_rates import AnnealingLR
from megatron.model import GPT2Model from megatron.model import GPT2Model
from megatron.model import gpt2_get_params_for_weight_decay_optimization from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as LocalDDP
from megatron import mpu from megatron import mpu
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron.utils import Timers from megatron.utils import Timers
...@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm ...@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0 from megatron.utils import print_rank_0
from megatron.utils import enable_adlr_autoresume from megatron.utils import enable_adlr_autoresume
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training
from gpt2_data_loader import make_gpt2_dataloaders from gpt2_data_loader import make_gpt2_dataloaders
...@@ -73,18 +75,7 @@ def get_model(args): ...@@ -73,18 +75,7 @@ def get_model(args):
model = FP16_Module(model) model = FP16_Module(model)
# Wrap model for distributed training. # Wrap model for distributed training.
if args.DDP_impl == 'torch': model = wrap_model_for_distributed_training(model, args)
i = torch.cuda.current_device()
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model return model
...@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model, ...@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
return lm_loss return lm_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args): def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS.""" """Load the data on rank zero and boradcast number of tokens to all GPUS."""
......