Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
8474e6e5
Commit
8474e6e5
authored
May 24, 2022
by
Vijay Korthikanti
Browse files
fix for sequence parallelism in bert pooling
parent
3f91f09b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
megatron/model/language_model.py
megatron/model/language_model.py
+3
-1
megatron/mpu/mappings.py
megatron/mpu/mappings.py
+12
-6
No files found.
megatron/model/language_model.py
View file @
8474e6e5
...
@@ -116,7 +116,9 @@ class Pooler(MegatronModule):
...
@@ -116,7 +116,9 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
# same pooler is run on all tensor parallel nodes
if
self
.
sequence_parallel
:
if
self
.
sequence_parallel
:
hidden_states
=
mpu
.
gather_from_sequence_parallel_region
(
hidden_states
)
hidden_states
=
mpu
.
gather_from_sequence_parallel_region
(
hidden_states
,
to_model_parallel
=
False
)
pooled
=
hidden_states
[
sequence_index
,
:,
:]
pooled
=
hidden_states
[
sequence_index
,
:,
:]
pooled
=
self
.
dense
(
pooled
)
pooled
=
self
.
dense
(
pooled
)
...
...
megatron/mpu/mappings.py
View file @
8474e6e5
...
@@ -214,19 +214,25 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
...
@@ -214,19 +214,25 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
class
_GatherFromSequenceParallelRegion
(
torch
.
autograd
.
Function
):
class
_GatherFromSequenceParallelRegion
(
torch
.
autograd
.
Function
):
"""Gather the input from
model
parallel region and concatinate."""
#TODO
"""Gather the input from
sequence
parallel region and concatinate."""
@
staticmethod
@
staticmethod
def
symbolic
(
graph
,
input_
):
def
symbolic
(
graph
,
input_
,
to_model_parallel
=
True
):
return
_gather_along_first_dim
(
input_
)
return
_gather_along_first_dim
(
input_
)
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
input_
):
def
forward
(
ctx
,
input_
,
to_model_parallel
=
True
):
ctx
.
to_model_parallel
=
to_model_parallel
return
_gather_along_first_dim
(
input_
)
return
_gather_along_first_dim
(
input_
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
return
_reduce_scatter_along_first_dim
(
grad_output
)
to_model_parallel
=
ctx
.
to_model_parallel
if
to_model_parallel
:
return
_reduce_scatter_along_first_dim
(
grad_output
),
None
else
:
return
_split_along_first_dim
(
grad_output
),
None
class
_ReduceScatterToSequenceParallelRegion
(
torch
.
autograd
.
Function
):
class
_ReduceScatterToSequenceParallelRegion
(
torch
.
autograd
.
Function
):
...
@@ -269,8 +275,8 @@ def scatter_to_sequence_parallel_region(input_):
...
@@ -269,8 +275,8 @@ def scatter_to_sequence_parallel_region(input_):
return
_ScatterToSequenceParallelRegion
.
apply
(
input_
)
return
_ScatterToSequenceParallelRegion
.
apply
(
input_
)
def
gather_from_sequence_parallel_region
(
input_
):
def
gather_from_sequence_parallel_region
(
input_
,
to_model_parallel
=
True
):
return
_GatherFromSequenceParallelRegion
.
apply
(
input_
)
return
_GatherFromSequenceParallelRegion
.
apply
(
input_
,
to_model_parallel
)
def
reduce_scatter_to_sequence_parallel_region
(
input_
):
def
reduce_scatter_to_sequence_parallel_region
(
input_
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment