OpenDAS / Megatron-LM / Commits / d6c4248b

Commit d6c4248b authored Sep 02, 2020 by mohammad
Parent: 42d2be06

    added splitting checkpointed activations

Showing 6 changed files with 234 additions and 1 deletion (+234, -1)
megatron/arguments.py          +4    -0
megatron/initialize.py         +21   -0
megatron/memory.py             +145  -0
megatron/model/transformer.py  +2    -0
megatron/mpu/__init__.py       +2    -0
megatron/mpu/random.py         +60   -1
megatron/arguments.py

@@ -200,6 +200,10 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
+    group.add_argument('--distribute-checkpointed-activations',
+                       action='store_true',
+                       help='If set, distribute checkpointed activations '
+                       'across model parallel group.')
     group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='chunk size (number of layers) for checkpointing.')
     group.add_argument('--train-iters', type=int, default=None,
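For reference, a minimal standalone sketch of how the three activation-checkpointing options parse together; the parser wiring below is illustrative only and does not reproduce the repository's full _add_training_args.

import argparse

# Illustrative parser mirroring only the options touched by this commit.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--checkpoint-activations', action='store_true',
                   help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations', action='store_true',
                   help='If set, distribute checkpointed activations '
                        'across model parallel group.')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
                   help='chunk size (number of layers) for checkpointing.')

# Example invocation: both flags enabled, checkpointing every layer (default).
args = parser.parse_args(['--checkpoint-activations',
                          '--distribute-checkpointed-activations'])
assert args.checkpoint_activations
assert args.distribute_checkpointed_activations
assert args.checkpoint_num_layers == 1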
megatron/initialize.py

@@ -72,6 +72,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     else:
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()

+        # Initialize memory buffers.
+        _initialize_mem_buffs()
+
         # Autoresume.
         _init_autoresume()

@@ -151,3 +154,21 @@ def _write_args_to_tensorboard():
     if writer:
         for arg in vars(args):
             writer.add_text(arg, str(getattr(args, arg)))
+
+
+def _initialize_mem_buffs():
+    """Initialize manually allocated static memory."""
+    args = get_args()
+
+    # Initialize memory for checkpointed activations.
+    if args.distribute_checkpointed_activations:
+        per_layer = args.batch_size * args.max_position_embeddings * \
+                    args.hidden_size // args.model_parallel_size
+        assert args.num_layers % args.checkpoint_num_layers == 0, \
+            'number of layers is not divisible by checkpoint-num-layers'
+        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+        numel = per_layer * num_checkpointer_layers
+        dtype = torch.half
+        if not args.fp16:
+            dtype = torch.float
+        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
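As a sanity check on the buffer sizing above, here is the same arithmetic worked through with hypothetical hyperparameters; none of the numbers below come from this commit.

# Hypothetical configuration, chosen only to illustrate _initialize_mem_buffs().
batch_size = 8
max_position_embeddings = 1024     # sequence length
hidden_size = 4096
model_parallel_size = 8
num_layers = 24
checkpoint_num_layers = 1

# Each rank keeps 1/model_parallel_size of every checkpointed hidden state.
per_layer = batch_size * max_position_embeddings * hidden_size \
            // model_parallel_size                        # 4,194,304 elements
num_checkpointed = num_layers // checkpoint_num_layers    # 24 checkpointed layers
numel = per_layer * num_checkpointed                      # 100,663,296 elements

bytes_fp16 = numel * 2                                    # torch.half is 2 bytes
print('buffer size: {:.1f} MB in fp16'.format(bytes_fp16 / 1024 / 1024))  # 192.0 MB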
megatron/memory.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.

    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
           index of the memory. A memory chunk starting from `_start` index
           can be `allocated` for an input tensor, with the elements of the
           tensor being copied. The buffer can be reused by resetting the
           `_start` index.
    """

    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                  flush=True)


class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
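A minimal usage sketch of the new buffer API, assuming a single process with one GPU; the buffer's constructor calls torch.distributed.get_rank() and torch.cuda.current_device(), so a process group and a CUDA device must exist. The buffer name 'demo' and the sizes below are illustrative only.

import torch
import torch.distributed as dist

from megatron.memory import allocate_mem_buff

# One-process group so that torch.distributed.get_rank() works.
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

# Pre-allocate a 1M-element fp16 buffer on the current GPU.
buff = allocate_mem_buff('demo', numel=1 << 20, dtype=torch.half,
                         track_usage=True)

x = torch.randn(4, 512, dtype=torch.half, device='cuda')
x_view = buff.add(x)                    # copies x into the buffer, returns a view
assert buff.numel_in_use() == x.numel()

buff.reset()                            # reuse the same memory next iteration
assert not buff.is_in_use()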
megatron/model/transformer.py

@@ -411,6 +411,8 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

+        # Make sure memory is freed.
+        mpu.reset_checkpointed_activations_memory_buffer()
         l = 0
         while l < self.num_layers:
             hidden_states = mpu.checkpoint(
megatron/mpu/__init__.py

@@ -45,7 +45,9 @@ from .mappings import scatter_to_model_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
+from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
+from .random import reset_checkpointed_activations_memory_buffer
 from .utils import divide
 from .utils import split_tensor_along_last_dim
megatron/mpu/random.py
@@ -24,14 +24,35 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable

+from megatron.memory import allocate_mem_buff
+
 from .initialize import get_data_parallel_rank
+from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_world_size


 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

+# Whether to apply model parallelism to checkpointed hidden states.
+_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
+
+
+def init_checkpointed_activations_memory_buffer(numel, dtype):
+    """Initialize the memory buffer for the checkpointed activations."""
+    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
+    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
+        'checkpointed activations', numel, dtype, track_usage=False)
+
+
+def reset_checkpointed_activations_memory_buffer():
+    """Reset the memory used for checkpointing."""
+    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
+

 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
@@ -65,6 +86,29 @@ def _set_cuda_rng_state(new_state, device=-1):
     _lazy_call(cb)


+def split_tensor_into_1d_equal_chunks(tensor):
+    """Break a tensor into equal 1D chunks."""
+    data = tensor.view(-1)
+    partition_size = torch.numel(data) // get_model_parallel_world_size()
+    start_index = partition_size * get_model_parallel_rank()
+    end_index = start_index + partition_size
+    return data[start_index:end_index]
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = get_model_parallel_world_size()
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
+    torch.distributed.all_gather(chunks, tensor,
+                                 group=get_model_parallel_group())
+    return gathered
+
+
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
@@ -199,9 +243,21 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
         ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

-        ctx.save_for_backward(*args)
         with torch.no_grad():
             outputs = run_function(*args)
+
+        # Divide hidden states across model parallel group and only keep
+        # the chunk corresponding to the current rank.
+        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+            ctx.input_0_shape = args[0].data.shape
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
+            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
+                args[0].data)
+
+        # Store everything.
+        ctx.save_for_backward(*args)
+
         return outputs

     @staticmethod
@@ -210,6 +266,9 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
+        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
+            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()