OpenDAS / Megatron-LM

Commit 5d29769c

addressed Jareds comments

Authored Sep 02, 2020 by mohammad
Parent: d6c4248b
Showing 3 changed files with 22 additions and 11 deletions:

  megatron/arguments.py    +5   -0
  megatron/initialize.py   +1  -10
  megatron/mpu/random.py  +16   -1
megatron/arguments.py

@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
+    # Activation checkpointing.
+    if args.distribute_checkpointed_activations:
+        assert args.checkpoint_activations, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to enable checkpoint-activations'
 
     _print_args(args)
     return args
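The added check ties --distribute-checkpointed-activations to --checkpoint-activations: distributing checkpointed activations only makes sense when activation checkpointing itself is enabled. Below is a minimal standalone sketch (not Megatron code) of how that dependency check behaves, using argparse.Namespace as a stand-in for the parsed args object; the helper name check_activation_checkpointing is hypothetical.

from argparse import Namespace

def check_activation_checkpointing(args):
    # Same rule as the assertion added to parse_args().
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you ' \
            'need to enable checkpoint-activations'

# Both flags on: passes silently.
check_activation_checkpointing(Namespace(
    distribute_checkpointed_activations=True, checkpoint_activations=True))

# Distribution requested without checkpointing: the assertion fires.
try:
    check_activation_checkpointing(Namespace(
        distribute_checkpointed_activations=True, checkpoint_activations=False))
except AssertionError as exc:
    print('rejected:', exc)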
megatron/initialize.py

@@ -162,13 +162,4 @@ def _initialize_mem_buffs():
     # Initialize memory for checkpointed activations.
     if args.distribute_checkpointed_activations:
-        per_layer = args.batch_size * args.max_position_embeddings * \
-                    args.hidden_size // args.model_parallel_size
-        assert args.num_layers % args.checkpoint_num_layers == 0, \
-            'number of layers is not divisible by checkpoint-num-layers'
-        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-        numel = per_layer * num_checkpointer_layers
-        dtype = torch.half
-        if not args.fp16:
-            dtype = torch.float
-        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
+        mpu.init_checkpointed_activations_memory_buffer()
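For a sense of scale, here is a small sketch of the sizing arithmetic that this hunk removes from _initialize_mem_buffs() (the next file moves it into mpu/random.py). The hyperparameter values below are made up purely for illustration and do not come from the commit.

# Hypothetical configuration, chosen only to make the arithmetic concrete.
batch_size = 8
max_position_embeddings = 1024
hidden_size = 3072
model_parallel_size = 2
num_layers = 24
checkpoint_num_layers = 1
fp16 = True

# Elements stored per transformer layer, split across model-parallel ranks.
per_layer = batch_size * max_position_embeddings * hidden_size // model_parallel_size

# One buffer slot per checkpointed chunk of layers.
assert num_layers % checkpoint_num_layers == 0
num_checkpointer_layers = num_layers // checkpoint_num_layers
numel = per_layer * num_checkpointer_layers

bytes_per_element = 2 if fp16 else 4
print(numel, 'elements ->', numel * bytes_per_element / 2**30, 'GiB')
# 301989888 elements -> 0.5625 GiB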
megatron/mpu/random.py

@@ -24,6 +24,7 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
 
+from megatron import get_args
 from megatron.memory import allocate_mem_buff
 
 from .initialize import get_data_parallel_rank

@@ -40,9 +41,23 @@ _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
 
 
-def init_checkpointed_activations_memory_buffer(numel, dtype):
+def init_checkpointed_activations_memory_buffer():
     """Initializ the memory buffer for the checkpointed activations."""
+    args = get_args()
+
+    per_layer = args.batch_size * args.max_position_embeddings * \
+                args.hidden_size // args.model_parallel_size
+    assert args.num_layers % args.checkpoint_num_layers == 0, \
+        'number of layers is not divisible by checkpoint-num-layers'
+    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    numel = per_layer * num_checkpointer_layers
+    dtype = torch.half
+    if not args.fp16:
+        dtype = torch.float
+
     global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
     assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
         'checkpointed activations memory buffer is already allocated.'
     _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
         'checkpointed activations', numel, dtype, track_usage=False)
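After this change the function takes no arguments: it reads the sizing inputs from get_args() and guards a module-level singleton buffer. A stripped-down sketch of that pattern follows; it is illustrative only, not the Megatron implementation, and plain torch.empty stands in for allocate_mem_buff from megatron.memory.

import torch

_BUFFER = None  # module-level singleton, mirroring _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER

def init_buffer(numel, fp16=True):
    """Allocate the shared buffer exactly once."""
    global _BUFFER
    assert _BUFFER is None, 'buffer is already allocated.'
    dtype = torch.half if fp16 else torch.float
    _BUFFER = torch.empty(numel, dtype=dtype)
    return _BUFFER

def get_buffer():
    assert _BUFFER is not None, 'buffer is not allocated; call init_buffer() first.'
    return _BUFFER

init_buffer(1024)
print(get_buffer().shape, get_buffer().dtype)  # torch.Size([1024]) torch.float16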