Commit 877f6a1b authored by Jared Casper
Browse files

Merge branch 'sequence_parallel' into 'main'

fix for sequence parallelism in bert pooling

See merge request ADLR/megatron-lm!418
parents 32132c6e 8474e6e5
......@@ -116,7 +116,9 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states,
to_model_parallel=False)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
......
......@@ -214,19 +214,25 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate.

    The forward pass gathers ``input_`` along its first dimension via
    ``_gather_along_first_dim``.  The backward pass depends on
    ``to_model_parallel``:

    * ``True`` (default): the incoming gradient is passed through
      ``_reduce_scatter_along_first_dim``.
    * ``False``: the incoming gradient is only split with
      ``_split_along_first_dim``.

    NOTE(review): ``False`` is presumably meant for callers whose downstream
    computation is replicated on every tensor-parallel rank (the BERT pooler
    passes it with the comment "same pooler is run on all tensor parallel
    nodes"), so no reduction across ranks should be needed -- confirm against
    the helper implementations.
    """

    @staticmethod
    def symbolic(graph, input_, to_model_parallel=True):
        """Symbolic counterpart of ``forward``; the flag is accepted but unused."""
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, to_model_parallel=True):
        """Gather ``input_`` along the first dimension.

        ``to_model_parallel`` is stashed on ``ctx`` because it only affects
        how the gradient is routed in ``backward``.
        """
        ctx.to_model_parallel = to_model_parallel
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input_`` and ``None`` for the flag.

        ``forward`` takes two arguments after ``ctx``, so two values are
        returned; the boolean flag is not differentiable.
        """
        if ctx.to_model_parallel:
            return _reduce_scatter_along_first_dim(grad_output), None
        return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
......@@ -269,8 +275,8 @@ def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
    """Apply ``_GatherFromSequenceParallelRegion`` to ``input_``.

    Args:
        input_: Value forwarded unchanged to the autograd function.
        to_model_parallel: Forwarded to the autograd function. Defaults to
            ``True``, so existing single-argument callers are unaffected.

    Returns:
        The result of ``_GatherFromSequenceParallelRegion.apply``.
    """
    # Passed positionally, matching the (input_, to_model_parallel) order of
    # the autograd function's ``forward``.
    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment