OpenDAS / Megatron-LM

Commit 5d4689c4, authored Feb 18, 2022 by Vijay Korthikanti

    tensor model parallelism memory optimization

Parent: 0ed2f6ac
Changes: 6 changed files with 209 additions and 28 deletions (+209 -28)

  megatron/arguments.py              +4    -0
  megatron/model/language_model.py   +1    -1
  megatron/model/transformer.py      +13   -2
  megatron/mpu/__init__.py           +9    -5
  megatron/mpu/layers.py             +19   -6
  megatron/mpu/mappings.py           +163  -14
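For orientation: the commit brackets the tensor-parallel transformer stack with scatter/gather collectives along the first (sequence) dimension, so each tensor-parallel rank keeps only 1/tp of those activations between the two collectives. A rough, illustrative back-of-envelope on the hidden states alone; the sizes below are assumptions, not values from the commit:

# Activation footprint of the [s, b, h] hidden states per rank, assumed fp16 sizes.
s, b, h, tp, bytes_per_elem = 2048, 4, 4096, 8, 2

full_copy = s * b * h * bytes_per_elem        # every rank holds the full tensor
sharded = full_copy // tp                     # scattered along the first dim
print(f"per-rank hidden states: {full_copy / 2**20:.0f} MiB -> {sharded / 2**20:.0f} MiB")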
megatron/arguments.py

@@ -530,6 +530,10 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
+    group.add_argument('--model-parallel-memory-opt', action='store_true',
+                       help='Enable model parallel memory optimization.')
     return parser
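A minimal, standalone sketch of the flag's behaviour (plain argparse, not Megatron's actual parse_args entry point): --model-parallel-memory-opt is a store_true option, so it defaults to False and flips args.model_parallel_memory_opt to True when passed.

import argparse

# Standalone illustration only; the commit wires this into _add_training_args(parser).
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--model-parallel-memory-opt', action='store_true',
                   help='Enable model parallel memory optimization.')

assert parser.parse_args([]).model_parallel_memory_opt is False
assert parser.parse_args(['--model-parallel-memory-opt']).model_parallel_memory_opt is True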
megatron/model/language_model.py

@@ -40,7 +40,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
+    return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,
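A shape-only sketch of what the renamed gather does here (single process; the tensor-parallel all-gather is simulated with torch.cat, and the sizes are assumptions): each rank holds logits over its vocabulary partition, and gathering along the last dimension restores the full vocabulary axis.

import torch

tp, s, b, vocab = 4, 8, 2, 1024                                        # assumed sizes
logits_parallel = [torch.randn(s, b, vocab // tp) for _ in range(tp)]  # one shard per rank

logits = torch.cat(logits_parallel, dim=-1)                            # gather along the last dim
assert logits.shape == (s, b, vocab)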
megatron/model/transformer.py

@@ -628,6 +628,8 @@ class ParallelTransformer(MegatronModule):
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
         self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+
         # Number of layers.
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

@@ -771,6 +773,10 @@ class ParallelTransformer(MegatronModule):
             # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+            if self.model_parallel_memory_opt:
+                hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(
+                    hidden_states)
         else:
             # See set_input_tensor()
             hidden_states = self.input_tensor

@@ -820,9 +826,14 @@ class ParallelTransformer(MegatronModule):
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
-            output = self.final_layernorm(hidden_states)
+            hidden_states = self.final_layernorm(hidden_states)
+
+            if self.model_parallel_memory_opt:
+                hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(
+                    hidden_states)
+
+            output = hidden_states.transpose(0, 1).contiguous()
         else:
             output = hidden_states

         return output
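A single-process sketch of the new data flow around the transformer stack (the collectives are simulated with chunk/cat and the sizes are assumptions): when --model-parallel-memory-opt is set, the [s, b, h] hidden states are scattered along the sequence dimension on entry, the layers run on the shard, and the full tensor is gathered back before the output transpose.

import torch

tp, s, b, h = 4, 16, 2, 64                       # assumed sizes
hidden_states = torch.randn(s, b, h)             # [s, b, h] after the input transpose

# scatter_along_first_dim_to_tensor_model_parallel_region (simulated):
shards = list(torch.chunk(hidden_states, tp, dim=0))
assert shards[0].shape == (s // tp, b, h)        # each rank keeps 1/tp of the activations

# ... transformer layers and the final layernorm would run on the shard here ...

# gather_along_first_dim_from_tensor_model_parallel_region (simulated):
restored = torch.cat(shards, dim=0)
assert torch.equal(restored, hidden_states)
output = restored.transpose(0, 1).contiguous()   # back to [b, s, h]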
megatron/mpu/__init__.py

@@ -55,11 +55,15 @@ from .layers import VocabParallelEmbedding
 from .layers import (set_tensor_model_parallel_attributes,
                      set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)

-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
+from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region

 from .random import checkpoint
 from .random import get_cuda_rng_tracker
megatron/mpu/layers.py

@@ -29,9 +29,12 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
+from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim

@@ -307,6 +310,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             not args.no_async_tensor_model_parallel_allreduce and
             world_size > 1)
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

@@ -323,14 +327,18 @@ class ColumnParallelLinear(torch.nn.Module):
                 input_shape[0], input_shape[1], output_parallel.shape[1])
         else:
             # Set up backprop all-reduce.
-            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            if self.model_parallel_memory_opt:
+                input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
+            else:
+                input_parallel = copy_to_tensor_model_parallel_region(input_)

             # Matrix multiply.
             output_parallel = F.linear(input_parallel, self.weight, bias)

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_tensor_model_parallel_region(output_parallel)
+            assert not self.model_parallel_memory_opt
+            output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None

@@ -416,6 +424,7 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

     def forward(self, input_):

@@ -423,11 +432,15 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_tensor_model_parallel_region(input_)
+            assert not self.model_parallel_memory_opt
+            input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.model_parallel_memory_opt:
+            output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
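A single-process numeric check of why RowParallelLinear can swap its all-reduce for a reduce-scatter along the first dimension (collectives are simulated with sums and chunking; sizes are assumptions): summing the per-rank partial products and then keeping one first-dim shard per rank yields shards whose concatenation equals the old all-reduced output.

import torch

tp, s, b, h_in, h_out = 4, 8, 2, 16, 12                          # assumed sizes
x = torch.randn(s, b, h_in)
x_shards = torch.chunk(x, tp, dim=-1)                            # per-rank inputs [s, b, h_in/tp]
w_shards = [torch.randn(h_out, h_in // tp) for _ in range(tp)]   # row-parallel weight shards

partial = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]    # per-rank [s, b, h_out]

y_allreduce = sum(partial)                                       # old path: all-reduce
y_rs_shards = list(torch.chunk(sum(partial), tp, dim=0))         # new path: reduce-scatter (dim 0)

assert torch.allclose(torch.cat(y_rs_shards, dim=0), y_allreduce)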
megatron/mpu/mappings.py

@@ -32,7 +32,8 @@ def _reduce(input_):
     return input_


-def _split(input_):
+def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
@@ -50,8 +51,28 @@ def _split(input_):
     return output


+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along the first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset + local_dim_size]
+
+    return output
+
+
-def _gather(input_):
+def _gather_along_last_dim(input_):
     """Gather tensors and concatenate along the last dimension."""

     world_size = get_tensor_model_parallel_world_size()
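A single-process sketch of the slicing that _split_along_first_dim performs (rank and world size are assumptions; no distributed state is touched): the offset arithmetic selects the same shard that torch.chunk along dim 0 would give.

import torch

world_size, rank = 4, 2                     # assumed tensor-parallel size and rank
input_ = torch.randn(8, 3, 5)               # first dim must divide evenly by world_size

dim_size = input_.size()[0]
assert dim_size % world_size == 0
local_dim_size = dim_size // world_size
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset + local_dim_size]

assert torch.equal(output, torch.chunk(input_, world_size, dim=0)[rank])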
@@ -73,6 +94,54 @@ def _gather(input_):
     return output


+def _gather_along_first_dim(input_):
+    """Gather tensors and concatenate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device(),
+                         requires_grad=False)
+    torch.distributed._all_gather_base(output, input_,
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device(),
+                         requires_grad=False)
+    # Reduce-scatter.
+    torch.distributed._reduce_scatter_base(output, input_,
+                                           group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_last_dim(input_):
+    """All-reduce the input, then keep only this rank's slice of the last dimension."""
+    output = _reduce(input_)
+    output = _split_along_last_dim(output)
+
+    return output
+
+
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
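A single-process sketch of the first-dim collectives these helpers wrap (torch.distributed._all_gather_base and _reduce_scatter_base are simulated with cat and sums over a list of per-rank tensors; sizes are assumptions): gather concatenates the rank shards along dim 0, while reduce-scatter sums the ranks' tensors and hands each rank one dim-0 shard of that sum.

import torch

tp, s_local, b, h = 4, 2, 3, 5                       # assumed sizes
rank_inputs = [torch.randn(s_local, b, h) for _ in range(tp)]

# _gather_along_first_dim (simulated): every rank ends up with [tp * s_local, b, h].
gathered = torch.cat(rank_inputs, dim=0)
assert gathered.shape == (tp * s_local, b, h)

# _reduce_scatter_along_first_dim (simulated): each rank contributes a full
# [tp * s_local, b, h] tensor; the element-wise sum over ranks is split along
# dim 0, one shard per rank.
rank_fulls = [torch.randn(tp * s_local, b, h) for _ in range(tp)]
summed = sum(rank_fulls)
scattered = list(torch.chunk(summed, tp, dim=0))
assert scattered[0].shape == (s_local, b, h)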
@@ -105,36 +174,100 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
     return grad_output


-class _ScatterToModelParallelRegion(torch.autograd.Function):
+class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the chunk corresponding to the rank."""

     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


-class _GatherFromModelParallelRegion(torch.autograd.Function):
+class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate."""

     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _reduce_scatter_along_last_dim(grad_output)


+class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the chunk corresponding to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
+    """Gather the input from model parallel region and concatenate."""
+    # TODO
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce_scatter_along_first_dim(grad_output)
+
+
+class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
+    """Reduce-scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_last_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_last_dim(grad_output)
+
+
+class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+    """Reduce-scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
 # -----------------
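A single-process check of the forward/backward pairing used above: slicing the output gradient is the right backward for a dim-0 concatenation. In the real _GatherAlongFirstDimFromModelParallelRegion every rank holds a replica of the gathered tensor, so the true input gradient is additionally summed over ranks, which is exactly what the reduce-scatter provides; the sketch below (sizes are assumptions) only demonstrates the slicing part.

import torch

tp, s_local, b, h = 4, 2, 3, 5                       # assumed sizes
shards = [torch.randn(s_local, b, h, requires_grad=True) for _ in range(tp)]

gathered = torch.cat(shards, dim=0)                  # forward: gather along the first dim
grad_output = torch.randn_like(gathered)
gathered.backward(grad_output)

# Each shard's gradient is its own dim-0 slice of grad_output; summing such
# slices across the tp replicas is what reduce_scatter_along_first_dim computes.
for r, shard in enumerate(shards):
    assert torch.equal(shard.grad, grad_output[r * s_local:(r + 1) * s_local])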
@@ -149,9 +282,25 @@ def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_to_tensor_model_parallel_region(input_):
-    return _ScatterToModelParallelRegion.apply(input_)
+def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
+    return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
+
+
+def gather_along_last_dim_from_tensor_model_parallel_region(input_):
+    return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
+
+
+def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
+    return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
+
+
+def gather_along_first_dim_from_tensor_model_parallel_region(input_):
+    return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
+
+
+def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
+    return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)


-def gather_from_tensor_model_parallel_region(input_):
-    return _GatherFromModelParallelRegion.apply(input_)
+def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
+    return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)