Commit 9aad9203 authored by Jared Casper

Merge branch 'sequence_parallel' into 'main'

Sequence parallelism + attention checkpoint

See merge request ADLR/megatron-lm!413
parents 667c2bcb 3f91f09b
@@ -103,14 +103,20 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
     if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
     del args.checkpoint_activations
+
+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+    del args.recompute_activations
+
     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
@@ -278,19 +284,32 @@ def parse_args(extra_args_provider=None, defaults={},
                   'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                   'Defaulting to no_persist_layer_norm=True')

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    # Activation recomputing.
+    if args.distribute_saved_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
             'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
-            'for distributed checkpoint activations to work you '\
-            'need to use a activation-checkpoint method '
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only '\
+            'application to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you '\
+            'need to use a recompute method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
-            'distributed checkpoint activations are supported for pytorch ' \
+            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recomputing granularity'
+
+    # disable async_tensor_model_parallel_allreduce when
+    # model parallel memory optimization is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False
+
     _print_args(args)
     return args
@@ -471,27 +490,40 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-activations', action='store_true',
-                       help='Checkpoint activation to allow for training '
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='recompute activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
-    group.add_argument('--distribute-checkpointed-activations',
+    group.add_argument('--recompute-granularity', type=str, default=None,
+                       choices=['full', 'selective'],
+                       help='Checkpoint activations to allow for training '
+                       'with larger models, sequences, and batch sizes. '
+                       'It is supported at two granularities 1) full: '
+                       'whole transformer layer is recomputed, '
+                       '2) selective: core attention part of the transformer '
+                       'layer is recomputed.')
+    group.add_argument('--distribute-saved-activations',
                        action='store_true',
-                       help='If set, distribute checkpointed activations '
+                       help='If set, distribute recomputed activations '
                        'across model parallel group.')
-    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+    group.add_argument('--recompute-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
-                       'Transformer layers and checkpoint the input activation of '
-                       'each divided chunk, '
-                       '2) checkpoint the input activations of only a set number of '
+                       'Transformer layers and recompute the input activation of '
+                       'each divided chunk at specified granularity, '
+                       '2) recompute the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing'
-                       'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       'rest without any recomputing at specified granularity'
+                       'default) do not apply activations recompute to any layers')
+    group.add_argument('--recompute-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
-                       'uniformly divided checkpoint unit, '
+                       'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
-                       'to checkpoint within each pipeline stage.')
+                       'to recompute within each pipeline stage.')
+
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
@@ -540,6 +572,8 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
...
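The two --recompute-method policies above read naturally as a layer-selection rule. A minimal self-contained sketch (not Megatron-LM code; the function name and return format are illustrative) of which layers each method would recompute for a pipeline stage of 8 layers:

# Illustrative sketch of the --recompute-method policies described in the help text.
def layers_to_recompute(num_layers_per_stage, method, recompute_num_layers):
    if method == 'uniform':
        # Divide the stage's layers into chunks of `recompute_num_layers` and
        # recompute the input activation of every chunk.
        return [list(range(start, min(start + recompute_num_layers,
                                      num_layers_per_stage)))
                for start in range(0, num_layers_per_stage, recompute_num_layers)]
    if method == 'block':
        # Recompute only the first `recompute_num_layers` individual layers of
        # this pipeline stage; run the remaining layers without recomputation.
        return [[layer] for layer in range(min(recompute_num_layers,
                                               num_layers_per_stage))]
    return []  # default: no recomputation

print(layers_to_recompute(8, 'uniform', 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(layers_to_recompute(8, 'block', 2))    # [[0], [1]]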
@@ -278,9 +278,13 @@ def _warmup_jit_function():
     del bias, input, output

     # Warmup fused bias+dropout+add
-    input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                        dtype=dtype, device='cuda')
-    residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                           dtype=dtype, device='cuda')
     bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
     dropout_rate = 0.1
...
@@ -78,7 +78,12 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
+        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
+
+        self.layernorm = LayerNorm(hidden_size,
+                                   eps=layernorm_epsilon,
+                                   sequence_parallel=args.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
@@ -110,14 +115,20 @@ def post_language_model_processing(lm_output, pooled_output,
         binary_logits = binary_head(pooled_output)

     if lm_labels is None:
-        return lm_logits, binary_logits
+        # [s b h] => [b s h]
+        return lm_logits.transpose(0,1).contiguous(), binary_logits
     else:
+        # [b s] => [s b]
+        lm_labels = lm_labels.transpose(0,1).contiguous()
+        # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                        lm_labels)
+        # [s, b] => [b s]
+        lm_loss = lm_loss.transpose(0,1).contiguous()
         return lm_loss, binary_logits
...
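The head changes here and in the GPT and T5 files below follow one layout convention: activations stay in [sequence, batch, hidden] inside the model and are only transposed at the loss/return boundary. A small stand-alone illustration of the shapes (sizes are arbitrary, not tied to any config):

import torch

s, b, h = 4, 2, 8                          # sequence, micro-batch, hidden (arbitrary)
lm_logits = torch.randn(s, b, h)           # internal format: [s, b, h]
lm_labels = torch.randint(0, h, (b, s))    # labels arrive as [b, s]

labels_sb = lm_labels.transpose(0, 1).contiguous()     # [s, b], matches the logits
per_token_loss = torch.randn(s, b)                     # stand-in for the [s, b] loss
loss_bs = per_token_loss.transpose(0, 1).contiguous()  # [b, s], returned to the caller
logits_bsh = lm_logits.transpose(0, 1).contiguous()    # [b, s, h], returned when labels is None
print(labels_sb.shape, loss_bs.shape, logits_bsh.shape)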
@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
...
@@ -69,7 +69,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
+    def __init__(self, normalized_shape, eps=1e-5,
+                 no_persist_layer_norm=True,
+                 sequence_parallel=False):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda
@@ -94,6 +96,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
+        self.sequence_parallel = sequence_parallel
+
+        # set sequence parallelism flag on weight and bias parameters
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

     def reset_parameters(self):
...
@@ -32,20 +32,26 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

     if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0,1).contiguous()
     else:
+        # [b s] => [s b]
+        labels = labels.transpose(0,1).contiguous()
+
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
         else:
             loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+
+        # [s b] => [b, s]
+        loss = loss.transpose(0,1).contiguous()
         return loss
...
@@ -26,23 +26,29 @@ from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
     args = get_args()
     # Parallel logits.
-    if args.async_tensor_model_parallel_allreduce:
+    if args.async_tensor_model_parallel_allreduce or\
+            args.sequence_parallel:
         input_parallel = input_
-        async_grad_allreduce = mpu.get_tensor_model_parallel_world_size() > 1
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+            model_parallel and not args.sequence_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

     # Matrix multiply.
-    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce)
+        async_grad_allreduce, args.sequence_parallel)
     # Gather if needed.

     if parallel_output:
         return logits_parallel
@@ -98,12 +104,21 @@ class Pooler(MegatronModule):

     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
+        args = get_args()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        self.sequence_parallel = args.sequence_parallel

     def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+
+        # gather data along sequence dimensions
+        # same pooler is run on all tensor parallel nodes
+        if self.sequence_parallel:
+            hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
+
+        pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
         return pooled
@@ -164,6 +179,8 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None

+        self.fp32_residual_connection = args.fp32_residual_connection
+        self.sequence_parallel = args.sequence_parallel
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -205,8 +222,20 @@ class Embedding(MegatronModule):
         else:
             assert self.tokentype_embeddings is None

+        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
+        # If the input flag for fp32 residual connection is set, convert for float.
+        if self.fp32_residual_connection:
+            embeddings = embeddings.float()
+
         # Dropout.
-        embeddings = self.embedding_dropout(embeddings)
+        if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+            with mpu.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)

         return embeddings
...
@@ -152,19 +152,24 @@ class T5Model(MegatronModule):
         if self.post_process and self.add_decoder:
             decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
             lm_logits = self.lm_head(decoder_output,
                                      self.word_embeddings_weight())

             if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0,1).contiguous()
             else:
+                # [b s] => [s b]
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                 else:
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0,1).contiguous()
                 return lm_loss
         elif self.add_decoder and not self.add_encoder:
             decoder_output, encoder_output = lm_output
...
This diff is collapsed.
@@ -21,7 +21,6 @@ import torch
 import apex
 import torch.nn.functional as F
 from megatron import get_args
-from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,
...
@@ -49,18 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
 from .layers import (set_tensor_model_parallel_attributes,
                      set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)

 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import scatter_to_sequence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region

 from .random import checkpoint
 from .random import get_cuda_rng_tracker
...
@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (
@@ -199,19 +200,39 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


-class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """
-    Linear layer execution with asynchronous all-reduce and gradient accumulation
+    Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """

     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce):
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
-        output = torch.matmul(input, weight.t())
+        ctx.sequence_parallel = sequence_parallel
+
+        if sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
+            torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group())
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
+        output = torch.matmul(total_input, weight.t())
         if bias is not None:
             output = output + bias
         return output
@@ -220,13 +241,39 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
+
+        if ctx.sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+
+            # Delay the start of intput gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
         grad_input = grad_output.matmul(weight)
+
+        if ctx.sequence_parallel:
+            handle.wait()
+
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                        grad_output.shape[2])
-        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
+        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                       total_input.shape[2])
+
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = torch.distributed.all_reduce(
@@ -234,16 +281,38 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             # Delay the start of weight gradient computation shortly (3us) to have
             # all-reduce scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.sequence_parallel:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+                                                            group=get_tensor_model_parallel_group(),
+                                                            async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
         if ctx.gradient_accumulation_fusion:
             import fused_dense_cuda
-            fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
             grad_weight = None
         else:
-            grad_weight = grad_output.t().matmul(input)
+            grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.sequence_parallel:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
+
         if ctx.async_grad_allreduce:
             handle.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+
+        return grad_input, grad_weight, grad_bias, None, None, None


 class ColumnParallelLinear(torch.nn.Module):
@@ -323,23 +392,28 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             args.async_tensor_model_parallel_allreduce and
             world_size > 1)
+        self.sequence_parallel = (
+            args.sequence_parallel and
+            world_size > 1)
+        assert not self.async_tensor_model_parallel_allreduce or \
+            not self.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
+        if self.async_tensor_model_parallel_allreduce or \
+                self.sequence_parallel:
             input_parallel = input_
         else:
-            # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-            self.async_tensor_model_parallel_allreduce)
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
@@ -420,26 +494,34 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
+
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+
+        self.sequence_parallel = args.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         # Set up backprop all-reduce.
         if self.input_is_parallel:
             input_parallel = input_
         else:
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+            self.gradient_accumulation_fusion, None, None)
+
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.sequence_parallel:
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
...
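The communication pattern added above can be summarized as: with sequence parallelism, a column-parallel linear all-gathers its [s/t, b, h] input along the sequence dimension before the matmul, and its backward reduce-scatters the input gradient back to [s/t, b, h]. A single-process sketch of those shapes, simulating the tensor-parallel ranks with a Python list instead of torch.distributed:

import torch

tp = 4                         # simulated tensor-parallel world size
s, b, h = 16, 2, 8             # full sequence length, micro-batch, hidden (arbitrary)

# Each rank holds a [s/tp, b, h] slice of the layer input.
local_inputs = list(torch.randn(s, b, h).chunk(tp, dim=0))

# Forward on one rank: all-gather the slices along dim 0, then matmul,
# mirroring the _all_gather_base call in the autograd function above.
weight = torch.randn(h, h)                      # this rank's weight partition
total_input = torch.cat(local_inputs, dim=0)    # [s, b, h]
output = torch.matmul(total_input, weight.t())  # [s, b, h]

# Backward: every rank produces a full-sequence grad_input; the reduce-scatter
# sums them across ranks and hands each rank back its [s/tp, b, h] slice.
grad_inputs = [torch.randn(s, b, h) for _ in range(tp)]  # one per simulated rank
summed = torch.stack(grad_inputs).sum(dim=0)             # the "reduce" part
sub_grad_input = summed.chunk(tp, dim=0)[0]              # the "scatter" part (rank 0)
print(total_input.shape, output.shape, sub_grad_input.shape)
# torch.Size([16, 2, 8]) torch.Size([16, 2, 8]) torch.Size([4, 2, 8])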
@@ -32,13 +32,13 @@ def _reduce(input_):
     return input_


-def _split(input_):
+def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     # Split along last dimension.
@@ -51,12 +51,34 @@ def _split(input_):
     return output


-def _gather(input_):
+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
+
+    return output
+
+
+def _gather_along_last_dim(input_):
     """Gather tensors and concatinate along the last dimension."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     # Size and dimension.
@@ -73,6 +95,44 @@ def _gather(input_):
     return output


+def _gather_along_first_dim(input_):
+    """Gather tensors and concatinate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
+                                           group=get_tensor_model_parallel_group())
+    return output
+
+
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):

     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


 class _GatherFromModelParallelRegion(torch.autograd.Function):
@@ -126,15 +186,63 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):

     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split_along_last_dim(grad_output)
+
+
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chuck to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from model parallel region and concatinate.""" #TODO
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather_along_first_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_first_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _reduce_scatter_along_first_dim(grad_output)
+
+
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)


 # -----------------
@@ -155,3 +263,16 @@ def scatter_to_tensor_model_parallel_region(input_):

 def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_sequence_parallel_region(input_):
+    return _ScatterToSequenceParallelRegion.apply(input_)
+
+
+def gather_from_sequence_parallel_region(input_):
+    return _GatherFromSequenceParallelRegion.apply(input_)
+
+
+def reduce_scatter_to_sequence_parallel_region(input_):
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
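The new region ops pair up as forward/backward duals: the scatter's backward is a gather, the gather's backward is a reduce-scatter, and the reduce-scatter's backward is a gather. A short single-process sketch (simulated ranks, no torch.distributed) of why the gather's backward must reduce before scattering:

import torch

tp, s, b, h = 2, 4, 1, 3   # simulated world size and an arbitrary activation shape

# After gather_from_sequence_parallel_region every rank sees the full [s, b, h]
# tensor and runs the same computation on it (e.g. the Pooler), so every rank
# produces a gradient with respect to the *whole* gathered tensor.
grads_per_rank = [torch.randn(s, b, h) for _ in range(tp)]

# The gradient of a rank's local [s/tp, b, h] slice is the sum of all ranks'
# gradients restricted to that slice: a reduce-scatter along dim 0.
reduced = torch.stack(grads_per_rank).sum(dim=0)
local_grads = reduced.chunk(tp, dim=0)
print([tuple(g.shape) for g in local_grads])   # [(2, 1, 3), (2, 1, 3)]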
@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
        tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
         ctx.run_function = run_function
-        ctx.distribute_checkpointed_activations \
-            = distribute_checkpointed_activations
+        ctx.distribute_saved_activations \
+            = distribute_saved_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):

         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if distribute_checkpointed_activations:
+        if distribute_saved_activations:
             ctx.input_0_shape = args[0].data.shape
             safely_set_viewless_tensor_data(
                 args[0],
@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if ctx.distribute_checkpointed_activations:
+        if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
                 inputs[0],
                 gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads


-def checkpoint(function, distribute_checkpointed_activations, *args):
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
     return CheckpointFunction.apply(function,
-                                    distribute_checkpointed_activations, *args)
+                                    distribute_saved_activations, *args)
@@ -17,7 +17,6 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
...
@@ -264,7 +264,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \
                                 = self.optimizer.state.pop(param)
-
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
                     fp32_params_this_group.append(param)
@@ -282,10 +281,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
...
@@ -61,7 +61,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

     override_scatter_gather_tensors_in_pipeline = False
-    if args.scatter_gather_tensors_in_pipeline:
+    if args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \
@@ -93,7 +94,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -138,7 +140,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
...
@@ -279,8 +279,12 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

     args = get_args()
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+    tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
     num_microbatches = get_num_microbatches() * num_model_chunks
@@ -514,18 +518,25 @@ def get_tensor_shapes(rank, model_type):
     # Otherwise, send one tensor (pre-transpose).
     args = get_args()
     tensor_shapes = []
+
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+
     if model_type == ModelType.encoder_and_decoder:
+        if args.sequence_parallel:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            decoder_seq_length = args.decoder_seq_length
+
         if mpu.is_pipeline_stage_before_split(rank):
-            # If next rank is after split, then need transpose for encoder_hidden_state.
-            if mpu.is_pipeline_stage_before_split(rank+1):
-                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
         else:
-            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
-            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+            tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     else:
-        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+        tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     return tensor_shapes
...
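The shape arithmetic above is the core of the pipeline change: with sequence parallelism every stage already holds a sequence-split activation, so the tensors exchanged between stages shrink by the tensor-parallel world size. A tiny stand-alone sketch with made-up sizes:

import torch

seq_length, micro_batch_size, hidden_size = 2048, 4, 1024   # arbitrary
tp_world_size = 8
sequence_parallel = True

local_seq = seq_length // tp_world_size if sequence_parallel else seq_length
tensor_shape = (local_seq, micro_batch_size, hidden_size)
print(tensor_shape)                      # (256, 4, 1024)
recv_buffer = torch.empty(tensor_shape)  # what a pipeline recv would allocate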
@@ -23,6 +23,7 @@ import time
 _TRAIN_START_TIME = time.time()
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_args
 from megatron import get_signal_handler
@@ -421,6 +422,25 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

+    # All-reduce layernorm parameters across model parallel nodes
+    # when sequence parallelism is used
+    if mpu.get_tensor_model_parallel_world_size() > 1 and \
+            args.sequence_parallel:
+        grads = []
+        for model_module in model:
+            unwrapped_model = unwrap_model(
+                model_module, (torchDDP, LocalDDP, Float16Module))
+            for param in unwrapped_model.parameters():
+                if getattr(param, 'sequence_parallel', False):
+                    grad = param.main_grad if args.DDP_impl == 'local' else param.grad
+                    grads.append(grad.data)
+        coalesced = _flatten_dense_tensors(grads)
+        torch.distributed.all_reduce(
+            coalesced, group=mpu.get_tensor_model_parallel_group())
+        for buf, synced in zip(grads, _unflatten_dense_tensors(
+                coalesced, grads)):
+            buf.copy_(synced)
+
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
...
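Layernorm weights and biases are replicated across tensor-parallel ranks, so with sequence parallelism their gradients must be all-reduced explicitly, which the block above does through one coalesced buffer. A minimal sketch of the flatten/reduce/unflatten pattern; the all_reduce itself is commented out because it needs an initialized process group:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.randn(16), torch.randn(16), torch.randn(32, 16)]  # stand-in layernorm grads

coalesced = _flatten_dense_tensors(grads)          # one contiguous buffer
# torch.distributed.all_reduce(coalesced, group=tensor_model_parallel_group)  # with a real group
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                              # write the reduced values back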