OpenDAS / Megatron-LM · Commits

Commit e923ec52, authored Aug 19, 2021 by mshoeybi

removed contiguous buffer for checkpointed activation

Parent: 6a0ef5b1

Showing 4 changed files with 23 additions and 59 deletions:

megatron/initialize.py         +0  -11
megatron/model/transformer.py  +3  -2
megatron/mpu/__init__.py       +0  -2
megatron/mpu/random.py         +20 -44
megatron/initialize.py

@@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()
 
-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()
 
@@ -224,11 +221,3 @@ def write_args_to_tensorboard():
             writer.add_text(arg, str(getattr(args, arg)),
                             global_step=args.iteration)
 
-
-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
megatron/model/transformer.py

@@ -544,6 +544,8 @@ class ParallelTransformer(MegatronModule):
         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.distribute_checkpointed_activations \
+            = args.distribute_checkpointed_activations
 
         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \

@@ -607,12 +609,11 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
         l = 0
         while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
+                self.distribute_checkpointed_activations,
                 hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
             l += self.checkpoint_num_layers
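Together with the megatron/mpu changes below, this removes the dependence on a preallocated contiguous buffer: the distribute flag is now threaded through mpu.checkpoint as the second positional argument, ahead of the tensor arguments. A minimal sketch of the resulting calling convention, purely illustrative (the wrapper name and tensor arguments are assumptions, not part of the commit):

    # Illustrative only: calling convention for mpu.checkpoint after this
    # commit. The run function comes first, then the flag controlling whether
    # the saved input is split across tensor-model-parallel ranks, then the
    # real tensor arguments.
    def checkpoint_chunk(chunk_forward, hidden_states, attention_mask,
                         distribute_saved_activations=False):
        return mpu.checkpoint(chunk_forward,
                              distribute_saved_activations,
                              hidden_states, attention_mask)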
megatron/mpu/__init__.py

@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
megatron/mpu/random.py

@@ -37,37 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 
-# Whether apply model parallelsim to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initializ the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
-
-
 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
@@ -101,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)
 
 
-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
     start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
 
 
 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
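A hedged sketch of what the new keyword changes (assumes torch.distributed and the mpu model-parallel groups are already initialized, and that the tensor's element count divides evenly across the tensor-model-parallel world size; the shape and dtype are made up for illustration):

    # new_buffer=False (the default) returns a view into the input's storage;
    # new_buffer=True copies the local chunk into a freshly allocated 1-D
    # tensor, so the chunk stays valid even if the original storage is later
    # resized or freed.
    full = torch.randn(4, 1024, dtype=torch.half,
                       device=torch.cuda.current_device())
    view_chunk = mpu.split_tensor_into_1d_equal_chunks(full)
    owned_chunk = mpu.split_tensor_into_1d_equal_chunks(full, new_buffer=True)
    assert owned_chunk.data_ptr() != full.data_ptr()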
@@ -250,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -263,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
             ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)
 
         # Store everything.
         ctx.save_for_backward(*args)
 
         return outputs
 
     @staticmethod
@@ -281,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
@@ -310,7 +286,7 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads
 
 
 def checkpoint(function, *args):
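The change from (None,) to (None, None) follows from torch.autograd.Function's contract: backward must return one gradient per positional argument of forward, and forward now takes distribute_checkpointed_activations in addition to run_function, neither of which receives a gradient. A tiny self-contained sketch of the same pattern (toy code, not Megatron's):

    import torch

    class Scale(torch.autograd.Function):
        """Toy Function whose forward takes a non-tensor flag plus a tensor."""

        @staticmethod
        def forward(ctx, factor, x):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # One return value per forward input: None for the Python scalar
            # `factor`, and a real gradient for the tensor `x`.
            return None, grad_output * ctx.factor

    x = torch.ones(3, requires_grad=True)
    Scale.apply(2.0, x).sum().backward()  # x.grad is now a tensor of 2.0s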