Commit c0f10643 authored by Vijay Korthikanti

layernorm grad sync + name changes

parent 5d4689c4
@@ -67,7 +67,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 
 class MixedFusedLayerNorm(torch.nn.Module):
 
-  def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
+  def __init__(self, normalized_shape, eps=1e-5,
+               no_persist_layer_norm=True,
+               sequence_parallel=False):
     super(MixedFusedLayerNorm, self).__init__()
 
     global fused_mix_prec_layer_norm_cuda
@@ -92,6 +94,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
       self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()
     self.no_persist_layer_norm = no_persist_layer_norm
+    self.sequence_parallel = sequence_parallel
+
+    # set sequence parallelism flag on weight and bias parameters
+    self.weight.sequence_parallel = self.sequence_parallel
+    self.bias.sequence_parallel = self.sequence_parallel
 
   def reset_parameters(self):
...
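Under sequence parallelism each tensor-parallel rank normalizes a different slice of the sequence, so the layernorm weight and bias gradients on each rank are partial; the sequence_parallel attribute set above is what the training step later uses to find these parameters and all-reduce their gradients. A minimal standalone sketch of the tag-and-collect pattern in plain PyTorch (everything except the sequence_parallel attribute name is illustrative, not Megatron code):

import torch

# Stand-in layernorm whose parameters are tagged the same way the commit
# tags MixedFusedLayerNorm's weight and bias.
ln = torch.nn.LayerNorm(1024)
for p in (ln.weight, ln.bias):
    p.sequence_parallel = True

# Later, e.g. in the training step, the tagged parameters can be picked out
# purely by this attribute and their grads handed to an all-reduce.
tagged_grads = [p.grad for p in ln.parameters()
                if getattr(p, 'sequence_parallel', False) and p.grad is not None]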
@@ -29,8 +29,13 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
+
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    if not args.model_parallel_memory_opt:
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    else:
+        input_parallel = input_
     # Matrix multiply.
     if bias is None:
         logits_parallel = F.linear(input_parallel, word_embeddings_weight)
@@ -40,7 +45,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel
 
-    return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)
+    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
 
 
 def get_language_model(num_tokentypes, add_pooler,
...
@@ -447,7 +447,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.model_parallel_memory_opt)
 
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -464,7 +465,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.model_parallel_memory_opt)
 
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -476,7 +478,8 @@ class ParallelTransformerLayer(MegatronModule):
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.model_parallel_memory_opt)
 
         # MLP
         self.mlp = ParallelMLP(init_method,
@@ -697,7 +700,8 @@ class ParallelTransformer(MegatronModule):
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.model_parallel_memory_opt)
 
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
@@ -775,7 +779,7 @@ class ParallelTransformer(MegatronModule):
             hidden_states = hidden_states.transpose(0, 1).contiguous()
 
             if self.model_parallel_memory_opt:
-                hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(hidden_states)
+                hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
         else:
             # See set_input_tensor()
@@ -806,6 +810,9 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
+            if self.model_parallel_memory_opt:
+                encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
 
         # Forward pass.
         if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
@@ -829,7 +836,7 @@ class ParallelTransformer(MegatronModule):
             hidden_states = self.final_layernorm(hidden_states)
 
             if self.model_parallel_memory_opt:
-                hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(hidden_states)
+                hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
             output = hidden_states.transpose(0, 1).contiguous()
         else:
...
@@ -21,7 +21,6 @@ import torch
 import apex
 import torch.nn.functional as F
 
 from megatron import get_args
-from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,
...
@@ -58,12 +58,11 @@ from .layers import (set_tensor_model_parallel_attributes,
 
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
-from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
-from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
-from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import scatter_to_sequence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
...
@@ -29,11 +29,11 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
-from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
@@ -328,7 +328,7 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             # Set up backprop all-reduce.
             if self.model_parallel_memory_opt:
-                input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
+                input_parallel = gather_from_sequence_parallel_region(input_)
             else:
                 input_parallel = copy_to_tensor_model_parallel_region(input_)
@@ -338,7 +338,7 @@ class ColumnParallelLinear(torch.nn.Module):
         if self.gather_output:
             # All-gather across the partitions.
             assert not self.model_parallel_memory_opt
-            output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
+            output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -433,12 +433,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             assert not self.model_parallel_memory_opt
-            input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
+            input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
-            output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
...
@@ -32,7 +32,6 @@ def _reduce(input_):
 
     return input_
 
-
 def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
@@ -51,6 +50,7 @@ def _split_along_last_dim(input_):
 
     return output
 
+
 def _split_along_first_dim(input_):
     """Split the tensor along its first dimension and keep the
     corresponding slice."""
@@ -174,7 +174,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
         return grad_output
 
 
-class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
+class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""
 
     @staticmethod
@@ -190,7 +190,7 @@ class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
         return _gather_along_last_dim(grad_output)
 
 
-class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
+class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
 
     @staticmethod
@@ -203,10 +203,10 @@ class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        return _reduce_scatter_along_last_dim(grad_output)
+        return _split_along_last_dim(grad_output)
 
 
-class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""
 
     @staticmethod
@@ -222,7 +222,7 @@ class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
         return _gather_along_first_dim(grad_output)
 
 
-class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate.""" #TODO
 
     @staticmethod
@@ -238,23 +238,7 @@ class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
         return _reduce_scatter_along_first_dim(grad_output)
 
 
-class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
-    """Reduce scatter the input from the model parallel region."""
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return _reduce_scatter_along_last_dim(input_)
-
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter_along_last_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_last_dim(grad_output)
-
-
-class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
     """Reduce scatter the input from the model parallel region."""
 
     @staticmethod
@@ -282,25 +266,22 @@ def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)
 
 
-def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
-    return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
+def scatter_to_tensor_model_parallel_region(input_):
+    return _ScatterToModelParallelRegion.apply(input_)
 
 
-def gather_along_last_dim_from_tensor_model_parallel_region(input_):
-    return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
+def gather_from_tensor_model_parallel_region(input_):
+    return _GatherFromModelParallelRegion.apply(input_)
 
 
-def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
-    return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
+def scatter_to_sequence_parallel_region(input_):
+    return _ScatterToSequenceParallelRegion.apply(input_)
 
 
-def gather_along_first_dim_from_tensor_model_parallel_region(input_):
-    return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
+def gather_from_sequence_parallel_region(input_):
+    return _GatherFromSequenceParallelRegion.apply(input_)
 
 
-def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
-    return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)
-
-
-def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
-    return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)
+def reduce_scatter_to_sequence_parallel_region(input_):
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
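All of the renamed sequence-parallel mappings operate on the first (sequence) dimension: the scatter splits it across the tensor-model-parallel group, the gather all-gathers it back, and the reduce-scatter fuses a sum with the split; in the autograd.Function wrappers above, gather and reduce-scatter serve as each other's backward. A rough sketch of the three underlying collectives with plain torch.distributed (assumes an initialized default process group, an evenly divisible first dimension, and a PyTorch recent enough to have the *_tensor collectives; the function names are illustrative, not Megatron's):

import torch
import torch.distributed as dist

def split_along_first_dim(x):
    # Keep only this rank's contiguous chunk of dim 0.
    world, rank = dist.get_world_size(), dist.get_rank()
    chunk = x.size(0) // world
    return x[rank * chunk:(rank + 1) * chunk].contiguous()

def gather_along_first_dim(x):
    # All-gather every rank's chunk and concatenate them back along dim 0.
    world = dist.get_world_size()
    out = torch.empty((world * x.size(0),) + tuple(x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out, x.contiguous())
    return out

def reduce_scatter_along_first_dim(x):
    # Sum across ranks, then keep only this rank's chunk of dim 0.
    world = dist.get_world_size()
    out = torch.empty((x.size(0) // world,) + tuple(x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, x.contiguous())
    return out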
@@ -17,7 +17,6 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD
 
 from megatron import get_args
-from megatron.model import LayerNorm
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
...
@@ -23,6 +23,7 @@ import time
 _TRAIN_START_TIME = time.time()
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from megatron import get_args
 from megatron import get_signal_handler
@@ -418,6 +419,26 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()
 
+    # All-reduce layernorm grads across tensor-model-parallel ranks
+    # when sequence parallelism is used.
+    if mpu.get_tensor_model_parallel_world_size() > 1 and \
+            args.model_parallel_memory_opt:
+        grads = []
+        for model_module in model:
+            unwrapped_model = unwrap_model(
+                model_module, (torchDDP, LocalDDP, Float16Module))
+            for param in unwrapped_model.parameters():
+                if getattr(param, 'sequence_parallel', False):
+                    assert param.requires_grad and param.grad is not None
+                    grads.append(param.grad.data)
+        coalesced = _flatten_dense_tensors(grads)
+        coalesced /= mpu.get_tensor_model_parallel_world_size()
+        torch.distributed.all_reduce(
+            coalesced, group=mpu.get_tensor_model_parallel_group())
+        for buf, synced in zip(grads, _unflatten_dense_tensors(
+                coalesced, grads)):
+            buf.copy_(synced)
+
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
...
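The flatten/unflatten pair in the block above coalesces the many small layernorm gradients into one flat buffer so a single all-reduce covers all of them, then copies the synchronized values back into the original gradient tensors. A tiny single-process illustration of just that round trip (the scaling stands in for the collective; not Megatron code):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(4), torch.ones(2, 3)]
coalesced = _flatten_dense_tensors(grads)   # one flat 10-element buffer
coalesced *= 2.0                            # stand-in for torch.distributed.all_reduce
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                       # originals now hold the synced values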