"vscode:/vscode.git/clone" did not exist on "987d34b0cf8d6cd8725258332fcfc8c54529b1ab"
Commit 417c7f6a authored by Boris Fomitchev

Changes for NeMo/lightning compatibility


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent 07ebf714
@@ -39,12 +39,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
-    # This is temporary WAR to make simple case like pytest calling with same args twice
-    # Need to implement clean factory init.
-    if mpu.model_parallel_is_initialized():
-        return
-
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
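With the early-return workaround removed from initialize_megatron, a caller that may run initialization more than once (for example a pytest suite or a Lightning trainer) can apply the same guard itself. A minimal sketch, assuming the usual Megatron import paths (not part of this commit):

    from megatron import mpu                       # import path assumed
    from megatron.initialize import initialize_megatron  # import path assumed

    # Only initialize if the model parallel state has not been set up yet,
    # mirroring the guard that was removed from initialize_megatron itself.
    if not mpu.model_parallel_is_initialized():
        initialize_megatron(extra_args_provider=None, args_defaults={})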
@@ -88,13 +88,16 @@ def model_parallel_is_initialized():
         return False
     return True
 
 
 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, \
         'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP
 
+def set_model_parallel_group(group):
+    """Set model parallel group."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = group
 
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
@@ -102,6 +105,10 @@ def get_data_parallel_group():
         'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP
 
+def set_data_parallel_group(group):
+    """Set data parallel group."""
+    global _DATA_PARALLEL_GROUP
+    _DATA_PARALLEL_GROUP = group
 
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
@@ -127,19 +127,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)
 
     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
-        # Get the embeddings.
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+        # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
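The guard matters because input_mask only exists on the partitioned path: when a rank holds the full vocabulary (num_embeddings_per_partition == num_embeddings) no token can fall outside the local range, so both the input masking and the output zeroing are skipped. A standalone illustration of the same control flow, with simplified names (not the module code):

    import torch
    import torch.nn.functional as F

    def partitioned_lookup(input_, weight, vocab_start, vocab_end, is_partitioned):
        if is_partitioned:
            # Tokens owned by other ranks are mapped to index 0 and zeroed later.
            mask = (input_ < vocab_start) | (input_ >= vocab_end)
            local = input_.clone() - vocab_start
            local[mask] = 0
        else:
            # Full vocabulary on this rank: nothing to mask.
            local = input_
        out = F.embedding(local, weight)
        if is_partitioned:
            out[mask, :] = 0.0
        return out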
@@ -15,20 +15,19 @@
 import torch
 
-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim
 
 
 def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""
-    group = get_model_parallel_group()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if get_model_parallel_world_size()==1:
         return input_
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=group)
+    torch.distributed.all_reduce(input_, group=get_model_parallel_group())
     return input_
@@ -36,18 +35,17 @@ def _reduce(input_):
 def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size==1:
         return input_
     # Split along last dimension.
-    world_size = torch.distributed.get_world_size(group=group)
     input_list = split_tensor_along_last_dim(input_, world_size)
     # Note: torch.split does not create contiguous tensors by default.
-    rank = torch.distributed.get_rank(group=group)
+    rank = get_model_parallel_rank()
     output = input_list[rank].contiguous()
     return output
@@ -55,16 +53,15 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size==1:
         return input_
     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
+    rank = get_model_parallel_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
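Switching to get_model_parallel_world_size() and get_model_parallel_rank() lets the mapping helpers use values cached at initialization instead of querying torch.distributed with a group handle on every call. A sketch of how such a cached getter is commonly backed in mpu-style code (illustrative, not this commit's diff):

    _MPU_WORLD_SIZE = None  # set once during initialization

    def set_model_parallel_world_size(world_size):
        """Cache the model parallel world size, e.g. handed in by an external launcher."""
        global _MPU_WORLD_SIZE
        _MPU_WORLD_SIZE = world_size

    def get_model_parallel_world_size():
        """Return the cached world size without touching torch.distributed."""
        assert _MPU_WORLD_SIZE is not None, 'model parallel world size is not set'
        return _MPU_WORLD_SIZE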
@@ -79,6 +76,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
@@ -91,6 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
@@ -103,6 +108,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
@@ -115,6 +124,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
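The symbolic() staticmethods are the hook torch.onnx uses when it traces an autograd.Function, so graphs that pass through these mapping regions can be exported instead of aborting on an unknown Function; with a model parallel world size of 1 each of these calls simply returns its input. A minimal sketch of the general pattern (the Identity op below is illustrative, not what this commit emits):

    import torch

    class _PassThrough(torch.autograd.Function):
        """Toy Function showing the forward/symbolic pairing used above."""

        @staticmethod
        def symbolic(graph, input_):
            # Tell the ONNX exporter to emit a plain Identity node here.
            return graph.op('Identity', input_)

        @staticmethod
        def forward(ctx, input_):
            return input_

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output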