Commit 8474e6e5 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

fix for sequence parallelism in bert pooling

parent 3f91f09b
...@@ -116,7 +116,9 @@ class Pooler(MegatronModule): ...@@ -116,7 +116,9 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions # gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes # same pooler is run on all tensor parallel nodes
if self.sequence_parallel: if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states) hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states,
to_model_parallel=False)
pooled = hidden_states[sequence_index, :, :] pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled) pooled = self.dense(pooled)
......
...@@ -214,19 +214,25 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function): ...@@ -214,19 +214,25 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate.

    Forward: all-gather the (sequence-sharded) input along the first
    (sequence) dimension so every rank holds the full sequence.

    Backward: depends on ``to_model_parallel``:
      * ``True``  -- the gathered tensor was consumed by a tensor-model-parallel
        region, so gradients coming back are partial sums across ranks and a
        reduce-scatter (sum + re-shard) is required.
      * ``False`` -- the gathered tensor was used identically on every rank
        (e.g. the BERT pooler), so each rank simply takes back its own slice
        via a split along the first dimension.

    The extra ``None`` in the backward return matches the non-tensor
    ``to_model_parallel`` argument of ``forward``.
    """

    @staticmethod
    def symbolic(graph, input_, to_model_parallel=True):
        # ONNX symbolic trace: same gather as forward; the backward-only
        # flag does not affect the traced graph.
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, to_model_parallel=True):
        # Remember how the gradient must be re-sharded in backward.
        ctx.to_model_parallel = to_model_parallel
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.to_model_parallel:
            # Gradients are partial across tensor-parallel ranks: sum and
            # scatter back to the sequence-parallel sharding in one op.
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            # Gradients are replicated: each rank keeps only its own
            # sequence slice; no reduction needed.
            return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
...@@ -269,8 +275,8 @@ def scatter_to_sequence_parallel_region(input_): ...@@ -269,8 +275,8 @@ def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_) return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
    """Gather a sequence-parallel-sharded tensor along the sequence dim.

    Args:
        input_: tensor sharded along its first (sequence) dimension.
        to_model_parallel: if True (default), backward uses reduce-scatter
            (gradients are partial sums from a tensor-model-parallel region);
            if False, backward simply splits the gradient back per rank
            (gradients are replicated, e.g. the BERT pooler path).

    Returns:
        The tensor gathered (concatenated) along the first dimension.
    """
    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_): def reduce_scatter_to_sequence_parallel_region(input_):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment