OpenDAS / Megatron-LM
Commit dbe6c728, authored Nov 23, 2021 by Szymon Migacz; committed by Jared Casper, Nov 23, 2021
Add support for signal-based dynamic checkpointing
Parent: d6380fda
Showing 6 changed files with 115 additions and 1 deletion (+115, -1)
megatron/__init__.py            (+1, -0)
megatron/arguments.py           (+6, -0)
megatron/dist_signal_handler.py (+81, -0)
megatron/global_vars.py         (+13, -0)
megatron/initialize.py          (+1, -1)
megatron/training.py            (+13, -0)
megatron/__init__.py

@@ -17,6 +17,7 @@ import torch
 from .global_vars import get_args
 from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
+from .global_vars import get_signal_handler
 from .global_vars import update_num_microbatches
 from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
megatron/arguments.py

@@ -377,6 +377,9 @@ def _add_logging_args(parser):
     group.add_argument('--log-memory-to-tensorboard',
                        action='store_true',
                        help='Enable memory logging to tensorboard.')
+    group.add_argument('--log-world-size-to-tensorboard',
+                       action='store_true',
+                       help='Enable world size logging to tensorboard.')

     return parser

@@ -472,6 +475,9 @@ def _add_training_args(parser):
                        'by this value.')
     group.add_argument('--exit-duration-in-mins', type=int, default=None,
                        help='Exit the program after this many minutes.')
+    group.add_argument('--exit-signal-handler', action='store_true',
+                       help='Dynamically save the checkpoint and shutdown the '
+                       'training if SIGTERM is received')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
     group.add_argument('--no-masked-softmax-fusion',
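For reference (not part of the commit), a minimal standalone argparse sketch of how the new flag behaves: with action='store_true', args.exit_signal_handler defaults to False and only flips to True when --exit-signal-handler is passed on the pretraining command line.

# Hedged, standalone sketch; the real parser lives in megatron/arguments.py
# and carries many more options.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--exit-signal-handler', action='store_true',
                    help='Dynamically save the checkpoint and shutdown the '
                         'training if SIGTERM is received')

args = parser.parse_args([])                          # flag omitted
assert args.exit_signal_handler is False
args = parser.parse_args(['--exit-signal-handler'])   # flag given
assert args.exit_signal_handler is True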
megatron/dist_signal_handler.py  (new file, mode 100644)

import signal

import torch


def get_world_size():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    return world_size


def get_device(local_rank=None):
    backend = torch.distributed.get_backend()
    if backend == 'nccl':
        if local_rank is None:
            device = torch.device('cuda')
        else:
            device = torch.device(f'cuda:{local_rank}')
    elif backend == 'gloo':
        device = torch.device('cpu')
    else:
        raise RuntimeError
    return device


def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
    if not torch.distributed.is_available() or \
       not torch.distributed.is_initialized():
        return [item]

    device = get_device(local_rank)
    if group is not None:
        group_size = group.size()
    else:
        group_size = get_world_size()

    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
        for _ in range(group_size)
    ]
    torch.distributed.all_gather(output_tensors, tensor, group, async_op)
    output = [elem.item() for elem in output_tensors]
    return output


class DistributedSignalHandler:
    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig

    def signals_received(self):
        all_received = all_gather_item(
            self._signal_received, dtype=torch.int32
        )
        return all_received

    def __enter__(self):
        self._signal_received = False
        self.released = False

        self.original_handler = signal.getsignal(self.sig)

        def handler(signum, frame):
            self._signal_received = True

        signal.signal(self.sig, handler)

        return self

    def __exit__(self, type, value, tb):
        self.release()

    def release(self):
        if self.released:
            return False

        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
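A short usage sketch of the new DistributedSignalHandler as a context manager (not from the commit; the loop and the os.kill call are hypothetical stand-ins). With torch.distributed uninitialized, signals_received() just returns the local flag as a one-element list, so this runs in a single process; in the commit itself the handler is entered once at startup and polled from the train loop.

# Minimal single-process sketch, assuming the Megatron-LM repo is on PYTHONPATH.
import os
import signal
import time

from megatron.dist_signal_handler import DistributedSignalHandler

with DistributedSignalHandler(sig=signal.SIGTERM) as handler:
    for step in range(10):
        if step == 3:
            os.kill(os.getpid(), signal.SIGTERM)   # simulate the scheduler's SIGTERM
        time.sleep(0.1)                            # stand-in for one training iteration
        # Under torch.distributed this is an all_gather, so every rank sees the
        # flag even if only one rank caught the signal.
        if any(handler.signals_received()):
            print(f'SIGTERM received by step {step}; would checkpoint and exit here.')
            break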
megatron/global_vars.py

@@ -21,6 +21,7 @@ import time
 import torch

+from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator

@@ -31,6 +32,7 @@ _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
+_GLOBAL_SIGNAL_HANDLER = None


 def get_args():

@@ -75,6 +77,14 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS


+def get_signal_handler():
+    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    return _GLOBAL_SIGNAL_HANDLER
+
+
+def _set_signal_handler():
+    global _GLOBAL_SIGNAL_HANDLER
+    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
+
+
 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):

@@ -89,6 +99,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_adlr_autoresume(args)
     _set_timers()

+    if args.exit_signal_handler:
+        _set_signal_handler()
+

 def _parse_args(extra_args_provider=None, defaults={},
                 ignore_unknown_args=False):
megatron/initialize.py

@@ -180,7 +180,7 @@ def _initialize_distributed():
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
             world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(days=7))
+            timeout=timedelta(minutes=10))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
megatron/training.py

@@ -26,6 +26,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
+from megatron import get_signal_handler
 from megatron import get_timers
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size

@@ -544,6 +545,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar('loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
+        if args.log_world_size_to_tensorboard:
+            writer.add_scalar('world-size', args.world_size, iteration)
+            writer.add_scalar('world-size vs samples', args.world_size,
+                              args.consumed_train_samples)
         if grad_norm is not None:
             writer.add_scalar('grad-norm', grad_norm, iteration)
             writer.add_scalar('grad-norm vs samples', grad_norm,

@@ -698,6 +703,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         saved_checkpoint = False
+        if args.exit_signal_handler:
+            signal_handler = get_signal_handler()
+            if any(signal_handler.signals_received()):
+                save_checkpoint_and_time(iteration, model, optimizer,
+                                         lr_scheduler)
+                print_datetime('exiting program after receiving SIGTERM.')
+                sys.exit()
+
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
             save_checkpoint_and_time(iteration, model, optimizer,
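Why the train loop checks any(signal_handler.signals_received()): the flag is all-gathered, so ranks that never caught SIGTERM still agree to save a checkpoint and exit together. Below is a hedged two-process sketch (not from the commit) using the gloo backend, an arbitrary local port, and the repo assumed to be on PYTHONPATH.

# Hypothetical demonstration: only rank 0 receives SIGTERM, yet both ranks
# observe it after the all_gather inside signals_received().
import os
import signal
import time

import torch.distributed as dist
import torch.multiprocessing as mp

from megatron.dist_signal_handler import DistributedSignalHandler


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29511'      # arbitrary free port
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    with DistributedSignalHandler(sig=signal.SIGTERM) as handler:
        if rank == 0:
            os.kill(os.getpid(), signal.SIGTERM)   # only rank 0 gets the signal
        time.sleep(1)                              # let rank 0 run its handler
        # Both ranks print something like [1, 0]; any(...) is True everywhere,
        # which is what lets every rank checkpoint and exit together in train().
        print(f'rank {rank}:', handler.signals_received())

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)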