Commit 0a74547c authored by Jared Casper

Merge branch 'parallel-nemo' into 'main'

Changes for NeMo/lightning compatibility

See merge request ADLR/megatron-lm!104
parents bdc52c19 4e8ace9b
@@ -30,7 +30,7 @@ from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .initialize import initialize_megatron

 def print_rank_0(message):
     """If distributed is initialized print only on rank 0."""
@@ -38,4 +38,4 @@ def print_rank_0(message):
     if torch.distributed.get_rank() == 0:
         print(message, flush=True)
     else:
         print(message, flush=True)
\ No newline at end of file
@@ -322,7 +322,10 @@ def _add_distributed_args(parser):
                        'to use.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
+    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+                       help='If set to True, initialize_megatron() skips DDP '
+                            'initialization and returns a function to complete '
+                            'it instead. This is for an external DDP manager.')

     return parser
...
@@ -26,7 +26,7 @@ from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
 from megatron.global_vars import set_global_variables
+from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size

 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -34,38 +34,52 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     set autoresume and random seeds.
     `allow_no_cuda` should not be set unless using megatron for cpu only
     data processing. In general this arg should not be set unless you know
-    what you are doing."""
+    what you are doing.
+    Returns a function to finalize distributed env initialization
+    (optionally, only when args.lazy_mpu_init == True).
+    """
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), 'Megatron requires CUDA.'

+    # Temporary workaround so that simple cases, such as pytest calling this
+    # function twice with the same args, do not fail. A clean factory-style
+    # init still needs to be implemented.
+    if mpu.model_parallel_is_initialized():
+        return
+
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
                          args_defaults=args_defaults,
                          ignore_unknown_args=ignore_unknown_args)

-    # Pytorch distributed.
-    _initialize_distributed()
-
-    # Autoresume.
-    _init_autoresume()
-
-    # Random seeds for reproducibility.
-    args = get_args()
-    if args.rank == 0:
-        print('> setting random seeds to {} ...'.format(args.seed))
-    _set_random_seed(args.seed)
-
-    # Write arguments to tensorboard.
-    _write_args_to_tensorboard()
+    # torch.distributed initialization.
+    def finish_mpu_init():
+        args = get_args()
+        # Pytorch distributed.
+        _initialize_distributed()
+
+        # Random seeds for reproducibility.
+        if args.rank == 0:
+            print('> setting random seeds to {} ...'.format(args.seed))
+        _set_random_seed(args.seed)
+
+    args = get_args()
+    if args.lazy_mpu_init:
+        # Delayed initialization of DDP-related state:
+        # only set the basic model-parallel globals here
+        set_model_parallel_world_size(args.model_parallel_size)
+        # and return a function for the external DDP manager to call
+        # once it has DDP initialized.
+        set_model_parallel_rank(args.rank)
+        return finish_mpu_init
+    else:
+        # Megatron's MPU is the master. Complete initialization right away.
+        finish_mpu_init()
+
+        # Autoresume.
+        _init_autoresume()
+
+        # Write arguments to tensorboard.
+        _write_args_to_tensorboard()
+
+        # No continuation function.
+        return None


 def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
@@ -79,11 +93,6 @@ def _initialize_distributed():
               'skipping initialization ...', flush=True)
         args.rank = torch.distributed.get_rank()
         args.world_size = torch.distributed.get_world_size()
-        if device_count > 0:
-            device = torch.cuda.current_device()
-            local_rank = args.rank % device_count
-            assert local_rank == device, \
-                'expected local-rank to be the same as rank % device-count.'
     else:
...
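The lazy path above is what lets a framework such as NeMo/Lightning own DDP setup. A minimal sketch of how an external DDP manager might drive it; only initialize_megatron, the lazy_mpu_init flag, and the returned continuation come from this merge request, while the defaults dict values and the surrounding setup are illustrative only:

    from megatron import initialize_megatron  # re-exported at package level by this MR

    # Ask Megatron to defer DDP/model-parallel setup and hand back a continuation.
    finish_mpu_init = initialize_megatron(
        args_defaults={'lazy_mpu_init': True, 'model_parallel_size': 1})

    # ... the external framework sets up torch.distributed / DDP here ...

    # Once the process group exists, complete Megatron's side of the setup.
    if finish_mpu_init is not None:
        finish_mpu_init()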
@@ -15,6 +15,7 @@
 from .distributed import *
 from .bert_model import BertModel
-from megatron.model.realm_model import ICTBertModel
+from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
+from .language_model import get_language_model
@@ -24,7 +24,7 @@ from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal, scaled_init_method_normal

 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
@@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method, scaled_init_method):
+                       init_method=None, scaled_init_method=None):
     """Build language model and return along with the key to save."""

     args = get_args()
@@ -55,6 +55,12 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     elif args.onnx_safe:
         gelu = erf_gelu

+    if init_method is None:
+        init_method = init_method_normal(args.init_method_std)
+
+    if scaled_init_method is None:
+        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+
     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
...
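With init_method and scaled_init_method now defaulted, a caller no longer has to build the initializers itself. A minimal sketch, assuming the usual Megatron arguments (--num-layers, --init-method-std, ...) have already been parsed; the attention-mask function shown is a placeholder standing in for whatever the calling model normally supplies:

    from megatron.model import get_language_model  # export added by this MR

    def attention_mask_func(attention_scores, attention_mask):
        # Placeholder mask function, for illustration only.
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    # init_method / scaled_init_method are omitted; Megatron derives them
    # from args.init_method_std and args.num_layers.
    language_model, lm_key = get_language_model(
        attention_mask_func=attention_mask_func,
        num_tokentypes=0,
        add_pooler=False)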
@@ -27,9 +27,9 @@ from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
 from .initialize import get_data_parallel_world_size
 from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_rank, set_model_parallel_rank
 from .initialize import get_model_parallel_src_rank
-from .initialize import get_model_parallel_world_size
+from .initialize import get_model_parallel_world_size, set_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
...
@@ -15,20 +15,19 @@
 import torch

-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim


 def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""
-    group = get_model_parallel_group()

     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if get_model_parallel_world_size() == 1:
         return input_

     # All-reduce.
-    torch.distributed.all_reduce(input_, group=group)
+    torch.distributed.all_reduce(input_, group=get_model_parallel_group())

     return input_

@@ -36,18 +35,17 @@ def _reduce(input_):
 def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_

     # Split along last dimension.
-    world_size = torch.distributed.get_world_size(group=group)
     input_list = split_tensor_along_last_dim(input_, world_size)

     # Note: torch.split does not create contiguous tensors by default.
-    rank = torch.distributed.get_rank(group=group)
+    rank = get_model_parallel_rank()
     output = input_list[rank].contiguous()

     return output

@@ -55,16 +53,15 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_

     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
+    rank = get_model_parallel_rank()

     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
...
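The mappings.py refactor routes the communication helpers through the mpu rank/world-size accessors instead of querying torch.distributed directly, which is what allows an external DDP manager to pre-seed that state through the setters exported above. A minimal sketch; the assumption that the accessors return the cached values set here mirrors how the lazy path in initialize_megatron uses them:

    from megatron import mpu

    # An external manager records the model-parallel layout up front.
    mpu.set_model_parallel_world_size(1)
    mpu.set_model_parallel_rank(0)

    # Assumption: the accessors now report the cached values, so the
    # refactored _reduce/_split/_gather helpers see a world size of 1 and
    # simply pass tensors through without touching torch.distributed.
    assert mpu.get_model_parallel_world_size() == 1
    assert mpu.get_model_parallel_rank() == 0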