Commit 417c7f6a authored by Boris Fomitchev

Changes for NeMo/lightning compatibility

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent 07ebf714
@@ -39,12 +39,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
-    # This is temporary WAR to make simple case like pytest calling with same args twice
-    # Need to implement clean factory init.
-    if mpu.model_parallel_is_initialized():
-        return
-
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
......
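Note on the hunk above: the temporary guard that made initialize_megatron() a no-op on repeated calls is removed here. A caller that may invoke initialization more than once (pytest being the case the old comment mentioned) can perform the same check itself with mpu.model_parallel_is_initialized(). A minimal sketch under that assumption; the wrapper name and import paths are illustrative and not part of this commit:

from megatron import mpu
from megatron.initialize import initialize_megatron

def ensure_megatron_initialized(extra_args_provider=None, args_defaults={}):
    # Skip re-initialization if a previous call already set up model parallelism.
    if mpu.model_parallel_is_initialized():
        return
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)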
@@ -88,13 +88,16 @@ def model_parallel_is_initialized():
         return False
     return True
 
 
 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, \
         'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP
 
 
+def set_model_parallel_group(group):
+    """Set model parallel group."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = group
+
+
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
@@ -102,6 +105,10 @@ def get_data_parallel_group():
         'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP
 
 
+def set_data_parallel_group(group):
+    """Set data parallel group."""
+    global _DATA_PARALLEL_GROUP
+    _DATA_PARALLEL_GROUP = group
+
+
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
......
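Note on the two hunks above: set_model_parallel_group() and set_data_parallel_group() let an external launcher hand Megatron process groups it created itself, instead of having Megatron's own initialization build them. A hedged sketch of how a NeMo/Lightning-style caller might use the new setters; the grouping scheme, the model_parallel_size value, and the `from megatron import mpu` import are assumptions for illustration only:

import torch
from megatron import mpu

# Assumes torchrun-style env:// initialization (RANK/WORLD_SIZE/MASTER_ADDR set).
torch.distributed.init_process_group(backend="nccl")
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
model_parallel_size = 2  # illustrative

# Contiguous ranks form a model parallel group; every rank must take part
# in every new_group() call, in the same order.
for start in range(0, world_size, model_parallel_size):
    ranks = range(start, start + model_parallel_size)
    group = torch.distributed.new_group(list(ranks))
    if rank in ranks:
        mpu.set_model_parallel_group(group)

# Ranks holding the same model parallel shard form a data parallel group.
for offset in range(model_parallel_size):
    ranks = range(offset, world_size, model_parallel_size)
    group = torch.distributed.new_group(list(ranks))
    if rank in ranks:
        mpu.set_data_parallel_group(group)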
@@ -127,19 +127,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)
 
     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
......
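Note on the hunk above: the mask remaps global token ids into a rank's local vocabulary shard and zeroes the rows the rank does not own, so the following all-reduce reassembles the full embedding. When one partition holds the entire vocabulary (num_embeddings_per_partition == num_embeddings), every id is already local and the new branch skips the masking work. A small self-contained illustration with made-up numbers:

import torch

# This rank owns global token ids 4..7 out of a larger vocabulary.
vocab_start_index, vocab_end_index = 4, 8
input_ = torch.tensor([1, 5, 7, 9])

input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0
print(masked_input)  # tensor([0, 1, 3, 0]); rows 0 and 3 are placeholders
print(input_mask)    # tensor([ True, False, False,  True]); their output rows get zeroed

# With a single partition the mask would be all False, which is exactly the
# case the new `else: masked_input = input_` branch short-circuits.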
@@ -15,20 +15,19 @@
 
 import torch
 
-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim
 
 
 def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""
-    group = get_model_parallel_group()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if get_model_parallel_world_size()==1:
         return input_
 
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=group)
+    torch.distributed.all_reduce(input_, group=get_model_parallel_group())
     return input_
@@ -36,18 +35,17 @@ def _reduce(input_):
 def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size==1:
         return input_
 
     # Split along last dimension.
-    world_size = torch.distributed.get_world_size(group=group)
     input_list = split_tensor_along_last_dim(input_, world_size)
 
     # Note: torch.split does not create contiguous tensors by default.
-    rank = torch.distributed.get_rank(group=group)
+    rank = get_model_parallel_rank()
     output = input_list[rank].contiguous()
 
     return output
@@ -55,16 +53,15 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size==1:
         return input_
 
     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
+    rank = get_model_parallel_rank()
 
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
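Note on the three hunks above: _reduce, _split and _gather now ask the mpu helpers for the world size and rank instead of querying torch.distributed against an explicit group handle. Presumably those helpers resolve the model parallel group themselves, so behaviour is unchanged while groups injected via set_model_parallel_group() are also honoured. A hedged sketch of what the helpers plausibly look like in the .initialize module the import above points at (not part of this diff; get_model_parallel_group is assumed to be the accessor shown earlier):

import torch

def get_model_parallel_world_size():
    # Sketch: size of the caller's model parallel group.
    return torch.distributed.get_world_size(group=get_model_parallel_group())

def get_model_parallel_rank():
    # Sketch: rank of the caller within its model parallel group.
    return torch.distributed.get_rank(group=get_model_parallel_group())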
@@ -79,6 +76,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
@@ -91,6 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
@@ -103,6 +108,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
@@ -115,6 +124,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
......
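Note on the symbolic() additions above: a symbolic staticmethod is the hook torch.onnx looks for when it meets a custom autograd.Function during export; without it the exporter is typically left with an untranslatable PythonOp. With a model parallel world size of 1 the bodies above all reduce to returning the input, so single-GPU export effectively treats these regions as pass-throughs. A hedged, self-contained toy that mirrors the pattern; the names here are illustrative and not part of the commit:

import io
import torch

class _PassThrough(torch.autograd.Function):
    """Toy analogue of the Functions above."""

    @staticmethod
    def symbolic(graph, input_):
        # Called by the ONNX exporter with graph values, not tensors; returning
        # the input drops the op from the exported graph, as in the single-GPU case.
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return _PassThrough.apply(x) * 2.0

buf = io.BytesIO()
torch.onnx.export(Wrapper(), torch.randn(2, 3), buf, opset_version=13)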