Commit 6658158b authored by Vijay Korthikanti

get sequence parallelism to work with pipeline parallelism

parent c0f10643
@@ -619,6 +619,8 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()
 
+        self.layer_type = layer_type
+        self.model_type = args.model_type
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
         self.pre_process = pre_process

@@ -629,7 +631,8 @@ class ParallelTransformer(MegatronModule):
         # Store activation checkpointing flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
-        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
+        self.distribute_checkpointed_activations = \
+            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
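
The change above makes the two memory optimizations mutually exclusive: with sequence parallelism (`model_parallel_memory_opt`) enabled, each tensor-parallel rank already checkpoints only its shard of the sequence dimension, so `distribute_checkpointed_activations` would have nothing further to split. A back-of-the-envelope check, with all sizes assumed for illustration:

```python
# Illustrative activation-size arithmetic; all numbers are assumptions.
seq_len, micro_batch, hidden, tp = 2048, 4, 4096, 8

full_numel = seq_len * micro_batch * hidden    # checkpointed input per layer
per_rank_with_seq_parallel = full_numel // tp  # already sharded along seq dim

print(full_numel, per_rank_with_seq_parallel)  # 33554432 4194304
# distribute_checkpointed_activations would also shard the checkpoint
# tp ways, so stacking it on top of sequence parallelism saves nothing;
# the commit therefore disables it when model_parallel_memory_opt is set.
```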

@@ -807,9 +810,9 @@ class ParallelTransformer(MegatronModule):
             )
 
         # Transpose encoder output.
-        if encoder_output is not None:
+        if encoder_output is not None and \
+                not self.model_parallel_memory_opt:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
+        if self.model_parallel_memory_opt:
+            encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)

@@ -835,10 +838,15 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)
-            if self.model_parallel_memory_opt:
-                hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
-            output = hidden_states.transpose(0, 1).contiguous()
+            if self.layer_type==LayerType.encoder and \
+                    self.model_type==ModelType.encoder_and_decoder and \
+                    self.model_parallel_memory_opt:
+                output = hidden_states
+            else:
+                if self.model_parallel_memory_opt:
+                    hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
+                output = hidden_states.transpose(0, 1).contiguous()
         else:
             output = hidden_states
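
For context on the two `mpu` collectives used in the hunks above: both operate on the sequence dimension (dim 0 of the `[s, b, h]` layout). The following is a minimal sketch of their forward semantics only, assuming a `torch.distributed` process group `tp_group` over the tensor-model-parallel ranks; the real versions in `mappings.py` are autograd `Function`s (`_ScatterToSequenceParallelRegion`, `_GatherFromSequenceParallelRegion`) whose backward passes mirror each other.

```python
import torch
import torch.distributed as dist

def scatter_to_sequence_parallel_region(x, tp_group):
    # Keep only this rank's slice of the sequence dimension (dim 0).
    world = dist.get_world_size(group=tp_group)
    rank = dist.get_rank(group=tp_group)
    chunk = x.size(0) // world          # assumes seq length divides evenly
    return x[rank * chunk:(rank + 1) * chunk].contiguous()

def gather_from_sequence_parallel_region(x, tp_group):
    # All-gather the per-rank sequence shards back into the full tensor.
    world = dist.get_world_size(group=tp_group)
    shards = [torch.empty_like(x) for _ in range(world)]
    dist.all_gather(shards, x.contiguous(), group=tp_group)
    return torch.cat(shards, dim=0)
```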

@@ -61,7 +61,7 @@ from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import scatter_to_sequence_parallel_region
-from .mappings import gather_from_seqeuence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import checkpoint

@@ -278,7 +278,7 @@ def scatter_to_sequence_parallel_region(input_):
     return _ScatterToSequenceParallelRegion.apply(input_)
 
 
-def gather_from_seqeuence_parallel_region(input_):
+def gather_from_sequence_parallel_region(input_):
     return _GatherFromSequenceParallelRegion.apply(input_)

@@ -61,7 +61,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
 
     override_scatter_gather_tensors_in_pipeline = False
-    if args.scatter_gather_tensors_in_pipeline:
+    if args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \

@@ -93,7 +94,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

@@ -138,7 +140,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
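
The three guards added to `_communicate` all encode the same fact: the scatter-gather pipeline optimization chunks the stage-boundary tensor across the tensor-parallel ranks before sending, but under sequence parallelism the boundary tensor is already a `1/tp`-sized sequence shard. Rough arithmetic, with all sizes assumed:

```python
# Assumed sizes, for illustration only.
seq_len, micro_batch, hidden, tp = 2048, 4, 4096, 8

# Without sequence parallelism, scatter_gather_tensors_in_pipeline chunks
# the full [s, b, h] boundary tensor before the P2P send:
full_numel = seq_len * micro_batch * hidden   # 33554432 elements
chunk_numel = full_numel // tp                # 4194304 per rank

# With model_parallel_memory_opt, the boundary tensor is already the
# [s/tp, b, h] shard -- the same 4194304 elements -- so chunking it again
# would save no bandwidth, hence the added `not args.model_parallel_memory_opt`.
seq_parallel_numel = (seq_len // tp) * micro_batch * hidden
assert seq_parallel_numel == chunk_numel
```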

@@ -514,6 +514,21 @@ def get_tensor_shapes(rank, model_type):
     # Otherwise, send one tensor (pre-transpose).
     args = get_args()
     tensor_shapes = []
+
+    if args.model_parallel_memory_opt:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+        if model_type == ModelType.encoder_and_decoder:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+            if mpu.is_pipeline_stage_before_split(rank):
+                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+            else:
+                tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        else:
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        return tensor_shapes
+
     if model_type == ModelType.encoder_and_decoder:
         if mpu.is_pipeline_stage_before_split(rank):
             # If next rank is after split, then need transpose for encoder_hidden_state.
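
A worked example of the new early-return branch in `get_tensor_shapes`, with an assumed configuration:

```python
# Assumed configuration, for illustration only.
seq_length, decoder_seq_length = 2048, 512
micro_batch_size, hidden_size, tp = 4, 4096, 8

s = seq_length // tp            # 256 -- per-rank encoder sequence shard
d = decoder_seq_length // tp    # 64  -- per-rank decoder sequence shard

encoder_only = [(s, micro_batch_size, hidden_size)]   # e.g. GPT: [(256, 4, 4096)]
before_split = [(s, micro_batch_size, hidden_size)]   # encoder-side pipeline stages
after_split = [(d, micro_batch_size, hidden_size),    # decoder hidden state ...
               (s, micro_batch_size, hidden_size)]    # ... plus the encoder output
```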

@@ -421,16 +421,16 @@ def train_step(forward_step_func, data_iterator,
     # All-reduce layernorm parameters across model parallel nodes
     # when sequence parallelism is used
-    if args.get_tensor_model_parallel_world_size > 1 and \
+    if mpu.get_tensor_model_parallel_world_size() > 1 and \
             args.model_parallel_memory_opt:
         grads = []
         for model_module in model:
             unwrapped_model = unwrap_model(
                 model_module, (torchDDP, LocalDDP, Float16Module))
             for param in unwrapped_model.parameters():
-                if param.get_attr('sequence_parallel', False):
-                    assert param.requires_grad and param.grad is not None
-                    grads.append(param.grad.data)
+                if getattr(param, 'sequence_parallel', False):
+                    grad = param.main_grad if args.DDP_impl == 'local' else param.grad
+                    grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
         coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
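
The hunk is truncated just as the all-reduce is issued; the visible lines (flatten, divide by world size, `all_reduce`) follow the standard coalesced-collective pattern, which presumably finishes by copying the synced values back into the per-parameter gradients. A self-contained sketch of that pattern, with the process group and world size passed in explicitly rather than taken from `mpu`:

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_layernorm_grads(grads, group, world_size):
    # Flatten the per-parameter grads into one buffer, average across the
    # tensor-parallel group, then copy the synced values back in place.
    coalesced = _flatten_dense_tensors(grads)
    coalesced /= world_size
    torch.distributed.all_reduce(coalesced, group=group)
    for buf, synced in zip(grads,
                           _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```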