Commit c0f10643 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

layernorm grad sync + name chnages

parent 5d4689c4
......@@ -67,7 +67,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
def __init__(self, normalized_shape, eps=1e-5,
no_persist_layer_norm=True,
sequence_parallel=False):
super(MixedFusedLayerNorm, self).__init__()
global fused_mix_prec_layer_norm_cuda
......@@ -92,6 +94,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel
# set sequence parallelism flag on weight and bias parameters
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel
def reset_parameters(self):
......
......@@ -29,8 +29,13 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
if not args.model_parallel_memory_opt:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
else:
input_parallel = input_
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
......@@ -40,7 +45,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if parallel_output:
return logits_parallel
return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
......
......@@ -447,7 +447,8 @@ class ParallelTransformerLayer(MegatronModule):
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
# Self attention.
self.self_attention = ParallelAttention(
......@@ -464,7 +465,8 @@ class ParallelTransformerLayer(MegatronModule):
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
......@@ -476,7 +478,8 @@ class ParallelTransformerLayer(MegatronModule):
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
# MLP
self.mlp = ParallelMLP(init_method,
......@@ -697,7 +700,8 @@ class ParallelTransformer(MegatronModule):
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
def _get_layer(self, layer_number):
return self.layers[layer_number]
......@@ -775,7 +779,7 @@ class ParallelTransformer(MegatronModule):
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(hidden_states)
hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
else:
# See set_input_tensor()
......@@ -806,6 +810,9 @@ class ParallelTransformer(MegatronModule):
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
# Forward pass.
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
......@@ -829,7 +836,7 @@ class ParallelTransformer(MegatronModule):
hidden_states = self.final_layernorm(hidden_states)
if self.model_parallel_memory_opt:
hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(hidden_states)
hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
output = hidden_states.transpose(0, 1).contiguous()
else:
......
......@@ -21,7 +21,6 @@ import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
......
......@@ -58,12 +58,11 @@ from .layers import (set_tensor_model_parallel_attributes,
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_seqeuence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
......
......@@ -29,11 +29,11 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
......@@ -328,7 +328,7 @@ class ColumnParallelLinear(torch.nn.Module):
else:
# Set up backprop all-reduce.
if self.model_parallel_memory_opt:
input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
input_parallel = gather_from_sequence_parallel_region(input_)
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
......@@ -338,7 +338,7 @@ class ColumnParallelLinear(torch.nn.Module):
if self.gather_output:
# All-gather across the partitions.
assert not self.model_parallel_memory_opt
output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -433,12 +433,12 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = input_
else:
assert not self.model_parallel_memory_opt
input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
if self.model_parallel_memory_opt:
output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
......
......@@ -32,7 +32,6 @@ def _reduce(input_):
return input_
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
......@@ -51,6 +50,7 @@ def _split_along_last_dim(input_):
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
......@@ -174,7 +174,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
return grad_output
class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
......@@ -190,7 +190,7 @@ class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
return _gather_along_last_dim(grad_output)
class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
......@@ -203,10 +203,10 @@ class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
return _reduce_scatter_along_last_dim(grad_output)
return _split_along_last_dim(grad_output)
class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
......@@ -222,7 +222,7 @@ class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
return _gather_along_first_dim(grad_output)
class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate.""" #TODO
@staticmethod
......@@ -238,23 +238,7 @@ class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
return _reduce_scatter_along_first_dim(grad_output)
class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
......@@ -282,25 +266,22 @@ def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_along_last_dim_from_tensor_model_parallel_region(input_):
return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_along_first_dim_from_tensor_model_parallel_region(input_):
return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
def gather_from_seqeuence_parallel_region(input_):
return _GatherFromSequenceParallelRegion.apply(input_)
def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)
......@@ -17,7 +17,6 @@ from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import LayerNorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
......
......@@ -23,6 +23,7 @@ import time
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import get_signal_handler
......@@ -418,6 +419,26 @@ def train_step(forward_step_func, data_iterator,
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if args.get_tensor_model_parallel_world_size > 1 and \
args.model_parallel_memory_opt:
grads = []
for model_module in model:
unwrapped_model = unwrap_model(
model_module, (torchDDP, LocalDDP, Float16Module))
for param in unwrapped_model.parameters():
if param.get_attr('sequence_parallel', False):
assert param.requires_grad and param.grad is not None
grads.append(param.grad.data)
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_tensor_model_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment