OpenDAS / Megatron-LM · Commit 9dc3c42a

preallocating global buffer to avoid memory fragmentation

Authored May 24, 2022 by Vijay Korthikanti
Parent: 8474e6e5
Showing 6 changed files with 56 additions and 25 deletions (+56 −25)
megatron/__init__.py              +1  −0
megatron/global_vars.py           +36 −2
megatron/model/language_model.py  +1  −1
megatron/model/transformer.py     +4  −7
megatron/mpu/layers.py            +3  −8
megatron/mpu/mappings.py          +11 −7
megatron/__init__.py

```diff
@@ -23,6 +23,7 @@ from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .global_vars import get_global_memory_buffer
 from .initialize import initialize_megatron


 def print_rank_0(message):
```
megatron/global_vars.py

```diff
@@ -18,7 +18,8 @@
 import os
 import sys
 import time
+from functools import reduce
+import operator

 import torch

 from megatron import dist_signal_handler
@@ -33,7 +34,7 @@ _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
+_GLOBAL_MEMORY_BUFFER = None


 def get_args():
     """Return arguments."""
@@ -77,15 +78,23 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS


 def get_signal_handler():
     _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     return _GLOBAL_SIGNAL_HANDLER


+def get_global_memory_buffer():
+    _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
+    return _GLOBAL_MEMORY_BUFFER
+
+
 def _set_signal_handler():
     global _GLOBAL_SIGNAL_HANDLER
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()


 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
@@ -98,6 +107,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers()
+    _set_global_memory_buffer()

     if args.exit_signal_handler:
         _set_signal_handler()
@@ -182,6 +192,12 @@ def _set_timers():
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
     _GLOBAL_TIMERS = Timers()


+def _set_global_memory_buffer():
+    """Initialize global buffer"""
+    global _GLOBAL_MEMORY_BUFFER
+    _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
+    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
+
+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
@@ -273,3 +289,21 @@ class Timers:
                 print(string, flush=True)
         else:
             print(string, flush=True)
+
+
+class GlobalMemoryBuffer:
+    "Global buffer to avoid dynamic memory allocations"
+
+    def __init__(self):
+        self.buffer = {}
+
+    def allocate_tensor(self, tensor_shape, dtype):
+        required_len = reduce(operator.mul, tensor_shape, 1)
+        if self.buffer.get(dtype, None) is None or \
+                self.buffer[dtype].numel() < required_len:
+            self.buffer[dtype] = torch.empty(required_len,
+                                             dtype=dtype,
+                                             device=torch.cuda.current_device(),
+                                             requires_grad=False)
+        return self.buffer[dtype][0:required_len].view(*tensor_shape)
```
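The new class is the heart of the commit: rather than letting every forward pass call torch.empty and churn the CUDA caching allocator, Megatron now grows one slab per dtype and hands out views into it. Below is a minimal sketch of the same pattern, kept CPU-only so it runs anywhere (the real class allocates on torch.cuda.current_device(); the name GlobalMemoryBufferSketch is ours, not Megatron's):

```python
import operator
from functools import reduce

import torch


class GlobalMemoryBufferSketch:
    """CPU-only restatement of the GlobalMemoryBuffer added above:
    one monotonically growing backing tensor per dtype, re-viewed
    on every request instead of freshly allocated."""

    def __init__(self):
        self.buffer = {}

    def allocate_tensor(self, tensor_shape, dtype):
        required_len = reduce(operator.mul, tensor_shape, 1)
        # Grow the slab only when the request exceeds it.
        if self.buffer.get(dtype) is None or self.buffer[dtype].numel() < required_len:
            self.buffer[dtype] = torch.empty(required_len, dtype=dtype,
                                             requires_grad=False)
        # Hand out a view over the front of the slab.
        return self.buffer[dtype][0:required_len].view(*tensor_shape)


buf = GlobalMemoryBufferSketch()
a = buf.allocate_tensor((4, 8), torch.float32)  # first call allocates 32 elements
b = buf.allocate_tensor((2, 8), torch.float32)  # second call reuses the same storage
assert a.data_ptr() == b.data_ptr()             # both views start at the slab's front
```

Note the aliasing this implies: two allocate_tensor calls for the same dtype return overlapping views, so the buffer only suits transient scratch tensors whose contents need not survive the next request.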
megatron/model/language_model.py

```diff
@@ -118,7 +118,7 @@ class Pooler(MegatronModule):
         if self.sequence_parallel:
             hidden_states = mpu.gather_from_sequence_parallel_region(
                 hidden_states,
-                to_model_parallel=False)
+                tensor_parallel_output_grad=False)

         pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
```
megatron/model/transformer.py

```diff
@@ -19,7 +19,7 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args
+from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -234,12 +234,9 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
+        matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
+            (output_size[0]*output_size[1], output_size[2], output_size[3]),
+            dtype=query_layer.dtype)

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
```
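The CoreAttention change is the main payoff: the [b * np, sq, sk] attention-score scratch tensor, previously allocated fresh on every forward pass and large enough to fragment the allocator at long sequence lengths, now comes out of the shared slab. A rough sketch of the new allocation pattern, with toy shapes and the GlobalMemoryBufferSketch from above standing in for get_global_memory_buffer():

```python
import torch

# Illustrative sizes only; output_size is (b, np, sq, sk) in the real code,
# with sq and sk in the thousands for long-sequence training.
output_size = (2, 4, 16, 16)
buf = GlobalMemoryBufferSketch()   # Megatron would call get_global_memory_buffer()

for step in range(3):
    # Every iteration re-views the same slab; only the first one allocates.
    matmul_input_buffer = buf.allocate_tensor(
        (output_size[0] * output_size[1], output_size[2], output_size[3]),
        dtype=torch.float32)
    # The real code then writes into it via torch.baddbmm(matmul_input_buffer, ...)
```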
megatron/mpu/layers.py

```diff
@@ -39,7 +39,7 @@ from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
-from megatron import get_args
+from megatron import get_args, get_global_memory_buffer

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
@@ -221,9 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                torch.empty(dim_size, dtype=input.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
+                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
             torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
@@ -248,10 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                torch.empty(dim_size, dtype=input.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
+                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
             handle = torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
```
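Both all-gather sites follow the same recipe, and it works because torch.distributed._all_gather_base (a private API in this PyTorch era, later surfaced as all_gather_into_tensor) writes into a caller-provided output tensor rather than allocating one. A condensed sketch of the pattern, with hypothetical names (gather_input, memory_buffer) and assuming an initialized process group:

```python
import torch
import torch.distributed as dist


def gather_input(input_, memory_buffer, world_size):
    # The gathered tensor is world_size times larger along the first dim.
    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size
    # Reuse the shared slab instead of calling torch.empty on every pass.
    all_gather_buffer = memory_buffer.allocate_tensor(dim_size, dtype=input_.dtype)
    # The collective fills the preallocated buffer in place.
    dist._all_gather_base(all_gather_buffer, input_)
    return all_gather_buffer
```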
megatron/mpu/mappings.py

```diff
@@ -217,19 +217,23 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
     """Gather the input from sequence parallel region and concatinate."""

     @staticmethod
-    def symbolic(graph, input_, to_model_parallel=True):
+    def symbolic(graph, input_, tensor_parallel_output_grad=True):
         return _gather_along_first_dim(input_)

     @staticmethod
-    def forward(ctx, input_, to_model_parallel=True):
-        ctx.to_model_parallel = to_model_parallel
+    def forward(ctx, input_, tensor_parallel_output_grad=True):
+        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
         return _gather_along_first_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        to_model_parallel = ctx.to_model_parallel
+        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

-        if to_model_parallel:
+        # If the computation graph after the gather operation is
+        # in the tensor parallel mode, output gradients need to reduce
+        # scattered and whereas if the computation is duplicated,
+        # output gradients need to be scattered.
+        if tensor_parallel_output_grad:
             return _reduce_scatter_along_first_dim(grad_output), None
         else:
             return _split_along_first_dim(grad_output), None
@@ -275,8 +279,8 @@ def scatter_to_sequence_parallel_region(input_):
     return _ScatterToSequenceParallelRegion.apply(input_)


-def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
-    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
+def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
+    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)


 def reduce_scatter_to_sequence_parallel_region(input_):
```
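Alongside the buffer work, the commit renames the gather flag from to_model_parallel to tensor_parallel_output_grad, which names what the flag actually controls: the backward behavior. The call-site contrast, assuming an initialized Megatron model-parallel setup:

```python
# Default: the consumer is tensor-parallel, so each rank's output gradient is
# a partial sum and backward must reduce-scatter it.
gathered = gather_from_sequence_parallel_region(hidden_states)

# Pooler case (language_model.py above): the consumer is replicated, so the
# output gradient is identical on every rank and backward just splits it.
gathered = gather_from_sequence_parallel_region(
    hidden_states, tensor_parallel_output_grad=False)
```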