OpenDAS / Megatron-LM · Commits

Commit 5050203f, authored Mar 28, 2020 by Mohammad

    working on utils

Parent: 94e2ca57

Showing 5 changed files with 141 additions and 221 deletions.
    megatron/arguments.py    +5    -1
    megatron/global_vars.py  +83   -1
    megatron/training.py     +13   -36
    megatron/utils.py        +38   -182
    pretrain_bert.py         +2    -1
megatron/arguments.py  (view file @ 5050203f)

@@ -45,6 +45,11 @@ def parse_args(extra_args_provider=None, defaults={}):

     # Set input defaults.
+    for key in defaults:
+        # For a default to be valid, it should not be provided in the
+        # arguments that are passed to the program. We check this by
+        # ensuring the arg is set to None.
+        assert getattr(args, key) is None, \
+            'defaults can only be overwritten for args with None values.'
+        setattr(args, key, defaults[key])

     # Distributed args.

@@ -60,7 +65,6 @@ def parse_args(extra_args_provider=None, defaults={}):

     if args.loss_scale is None:
         args.dynamic_loss_scale = True

     # Checks.
     assert args.hidden_size % args.num_attention_heads == 0
     assert args.max_position_embeddings >= args.seq_length
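The defaults loop above is what lets callers such as pretrain_bert.py inject per-script argument defaults without clobbering explicit command-line values. A minimal standalone sketch of the same pattern (the argparse flag here is illustrative, not Megatron's full parser):

    import argparse

    parser = argparse.ArgumentParser()
    # default=None lets us distinguish "not passed" from any real value.
    parser.add_argument('--tokenizer-type', type=str, default=None)
    args = parser.parse_args([])  # simulate an empty command line

    defaults = {'tokenizer_type': 'BertWordPieceLowerCase'}
    for key in defaults:
        # Refuse to silently override a value the user passed explicitly.
        assert getattr(args, key) is None, \
            'defaults can only be overwritten for args with None values.'
        setattr(args, key, defaults[key])

    print(args.tokenizer_type)  # BertWordPieceLowerCase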
megatron/global_vars.py  (view file @ 5050203f)

@@ -17,10 +17,12 @@

 import os
 import sys
+import time

+import torch

 from megatron.data.tokenizer import build_tokenizer
 from .arguments import parse_args
-from .utils import Timers

 _GLOBAL_ARGS = None
 _GLOBAL_TOKENIZER = None

@@ -137,3 +139,83 @@ def _ensure_var_is_initialized(var, name):

 def _ensure_var_is_not_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is None, '{} is already initialized.'.format(name)
+
+
+class Timers:
+    """Group of timers."""
+
+    class Timer:
+        """Timer."""
+
+        def __init__(self, name):
+            self.name_ = name
+            self.elapsed_ = 0.0
+            self.started_ = False
+            self.start_time = time.time()
+
+        def start(self):
+            """Start the timer."""
+            assert not self.started_, 'timer has already been started'
+            torch.cuda.synchronize()
+            self.start_time = time.time()
+            self.started_ = True
+
+        def stop(self):
+            """Stop the timer."""
+            assert self.started_, 'timer is not started'
+            torch.cuda.synchronize()
+            self.elapsed_ += (time.time() - self.start_time)
+            self.started_ = False
+
+        def reset(self):
+            """Reset timer."""
+            self.elapsed_ = 0.0
+            self.started_ = False
+
+        def elapsed(self, reset=True):
+            """Calculate the elapsed time."""
+            started_ = self.started_
+            # If timing is in progress, end it first.
+            if self.started_:
+                self.stop()
+            # Get the elapsed time.
+            elapsed_ = self.elapsed_
+            # Reset the elapsed time.
+            if reset:
+                self.reset()
+            # If timing was in progress, set it back on.
+            if started_:
+                self.start()
+            return elapsed_
+
+    def __init__(self):
+        self.timers = {}
+
+    def __call__(self, name):
+        if name not in self.timers:
+            self.timers[name] = self.Timer(name)
+        return self.timers[name]
+
+    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+        """Write timers to a tensorboard writer."""
+        # Currently, when using add_scalars,
+        # torch.utils.add_scalars makes each timer its own run, which
+        # pollutes the runs list, so we just add each as a scalar.
+        assert normalizer > 0.0
+        for name in names:
+            value = self.timers[name].elapsed(reset=reset) / normalizer
+            writer.add_scalar(name + '_time', value, iteration)
+
+    def log(self, names, normalizer=1.0, reset=True):
+        """Log a group of timers."""
+        assert normalizer > 0.0
+        string = 'time (ms)'
+        for name in names:
+            elapsed_time = self.timers[name].elapsed(
+                reset=reset) * 1000.0 / normalizer
+            string += ' | {}: {:.2f}'.format(name, elapsed_time)
+        if torch.distributed.is_initialized():
+            if torch.distributed.get_rank() == 0:
+                print(string, flush=True)
+        else:
+            print(string, flush=True)
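Timer objects are created lazily through Timers.__call__, and elapsed() stops, reads, optionally resets, and restarts a running timer, so call sites stay one-liners. A usage sketch, assuming a CUDA device (Timer synchronizes the GPU around start/stop) and the module layout of this commit:

    import torch
    from megatron.global_vars import Timers  # class added in this commit

    timers = Timers()

    # First call creates the 'matmul' timer; subsequent calls reuse it.
    timers('matmul').start()
    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x
    timers('matmul').stop()

    # Prints e.g. "time (ms) | matmul: 1.23" (on rank 0 if distributed is
    # initialized) and, with reset=True by default, zeroes the timer.
    timers.log(['matmul'])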
megatron/training.py  (view file @ 5050203f)

@@ -22,7 +22,12 @@ import torch

 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
-from megatron.arguments import parse_args
+from megatron.global_vars import get_args
+from megatron.global_vars import get_timers
+from megatron.global_vars import get_tensorboard_writer
+from megatron.global_vars import get_adlr_autoresume
+from megatron.initialize import initialize_megatron
 from megatron import mpu
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer

@@ -30,20 +35,15 @@ from megatron.learning_rates import AnnealingLR

 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import enable_adlr_autoresume
-from megatron.utils import get_tensorboard_writer
-from megatron.utils import initialize_distributed
 from megatron.utils import load_checkpoint
 from megatron.utils import print_args
 from megatron.utils import print_rank_0
 from megatron.utils import report_memory
 from megatron.utils import save_checkpoint
-from megatron.utils import set_random_seed
-from megatron.utils import Timers


 def run(top_level_message, train_val_test_data_provider,
-        model_provider, forward_step_func, extra_args_provider=None):
+        model_provider, forward_step_func, extra_args_provider=None,
+        args_defaults={}):
     """Main training program.

     This function will run the following in the order provided:

@@ -72,8 +72,11 @@ def run(top_level_message, train_val_test_data_provider,

     """
     # Initialize and get arguments, timers, and Tensorboard writer.
-    args = parse_args(extra_args_provider=extra_args_provider)
-    timers, writer = initialize_megatron(top_level_message, args)
+    initialize_megatron(extra_args_provider=extra_args_provider,
+                        args_defaults=args_defaults)
+    args = get_args()
+    timers = get_timers()
+    writer = get_tensorboard_writer()

     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)

@@ -116,32 +119,6 @@ def run(top_level_message, train_val_test_data_provider,

                     args, None, 0, timers, True)


-def initialize_megatron(message, args):
-    """Initialize distributed, random seed, and autoresume."""
-    # Timer.
-    timers = Timers()
-    # Tensorboard writer.
-    writer = get_tensorboard_writer(args)
-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(message, flush=True)
-        print_args(args, writer)
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-    # Random seeds for reproducibility.
-    set_random_seed(args.seed)
-    return timers, writer
-
-
 def get_model(model_provider_func, args):
     """Build the model."""
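With this change, training.py no longer threads args, timers, and the writer through return values; initialize_megatron() populates module-level singletons in global_vars.py, and the get_*() accessors retrieve them anywhere. A minimal sketch of that singleton pattern (the set_args helper is an assumption for illustration; the real module populates its globals during initialization):

    # Sketch of the global-variables pattern global_vars.py implements.
    _GLOBAL_ARGS = None

    def _ensure_var_is_initialized(var, name):
        """Make sure the input variable is not None."""
        assert var is not None, '{} is not initialized.'.format(name)

    def set_args(args):  # hypothetical setter, shown for completeness
        global _GLOBAL_ARGS
        _GLOBAL_ARGS = args

    def get_args():
        """Return arguments; only valid after initialization has run."""
        _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
        return _GLOBAL_ARGS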
megatron/utils.py  (view file @ 5050203f)

@@ -22,7 +22,8 @@ import numpy as np

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
+#from megatron.global_vars import get_args
+#from megatron.global_vars import get_adlr_autoresume
 from megatron import mpu
 from megatron.fp16 import FP16_Module

@@ -31,6 +32,42 @@ from megatron.model import DistributedDataParallel as LocalDDP

 from megatron.model import get_params_for_weight_decay_optimization


+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+def reduce_losses(losses):
+    """Reduce a tensor of losses across all GPUs."""
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    return reduced_losses
+
+
+def check_adlr_autoresume_termination(iteration, model, optimizer,
+                                      lr_scheduler, args):
+    # Add barrier to ensure consistency.
+    torch.distributed.barrier()
+    if args.AutoResume.termination_requested():
+        if args.save:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+        print_rank_0(">>> autoresume termination request found!")
+        if torch.distributed.get_rank() == 0:
+            args.AutoResume.request_resume()
+        print_rank_0(">>> training terminated. Returning")
+        exit(0)
+
+
 def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
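reduce_losses performs one all_reduce over the concatenated scalar losses and divides by the world size, so every rank ends up with the same cross-rank average of each loss. A self-contained single-process demo (the function body is copied from above; gloo backend so it runs on CPU):

    import os
    import torch
    import torch.distributed as dist

    def reduce_losses(losses):
        """Copy of megatron/utils.reduce_losses for a self-contained demo."""
        reduced_losses = torch.cat(
            [loss.clone().detach().view(1) for loss in losses])
        dist.all_reduce(reduced_losses)                  # sum over all ranks
        return reduced_losses / dist.get_world_size()    # sum -> mean

    # Single-process group just to exercise the call path; real training
    # initializes one group per data-parallel worker.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    lm_loss, sop_loss = torch.tensor(2.5), torch.tensor(0.5)
    # One entry per input loss, each averaged across ranks.
    print(reduce_losses([lm_loss, sop_loss]))  # tensor([2.5000, 0.5000])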
@@ -88,78 +125,6 @@ def get_ltor_masks_and_position_ids(data,

     return attention_mask, loss_mask, position_ids


-def reduce_losses(losses):
-    reduced_losses = torch.cat(
-        [loss.clone().detach().view(1) for loss in losses])
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    return reduced_losses
-
-
-def get_tensorboard_writer(args):
-    writer = None
-    if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == 0:
-        try:
-            from torch.utils.tensorboard import SummaryWriter
-            writer = SummaryWriter(log_dir=args.tensorboard_dir)
-        except ModuleNotFoundError:
-            print_rank_0('WARNING: TensorBoard writing requested but is not '
-                         'available (are you using PyTorch 1.1.0 or later?), '
-                         'no TensorBoard logs will be written.')
-            writer = None
-    return writer
-
-
-def print_rank_0(message):
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-
-
-def enable_adlr_autoresume(args):
-    print_rank_0('enabling autoresume ...')
-    import sys
-    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
-    try:
-        from userlib.auto_resume import AutoResume
-    except:
-        print_rank_0('ADLR autoresume is not available, exiting ...')
-        exit()
-    args.AutoResume = AutoResume
-    args.AutoResume.init()
-
-
-def check_adlr_autoresume_termination(iteration, model, optimizer,
-                                      lr_scheduler, args):
-    # Add barrier to ensure consistency.
-    torch.distributed.barrier()
-    if args.AutoResume.termination_requested():
-        if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
-        print_rank_0(">>> autoresume termination request found!")
-        if torch.distributed.get_rank() == 0:
-            args.AutoResume.request_resume()
-        print_rank_0(">>> training terminated. Returning")
-        exit(0)
-
-
 def print_args(args, writer=None):
     """Print arguments."""
     print_rank_0('arguments:')
     str_list = []
     for arg in vars(args):
         dots = '.' * (29 - len(arg))
         str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         if writer:
             writer.add_text(arg, str(getattr(args, arg)))
     for arg in sorted(str_list, key=lambda a: a.lower()):
         print_rank_0(arg)


 def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""

@@ -181,82 +146,6 @@ def print_params_min_max_norm(optimizer, iteration):

         print(string, flush=True)


-class Timers:
-    """Group of timers."""
-    ...

(The removed class is the same Timer/Timers implementation added to
megatron/global_vars.py above, except that its log() method printed via
print_rank_0(string) rather than checking the rank inline.)

 def report_memory(name):
     """Simple GPU memory report."""

@@ -285,39 +174,6 @@ def vocab_size_with_padding(num_tokens, args):

     return after


-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process.
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducibility."""
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_checkpoint_name(checkpoints_path, iteration,
                         release=False, mp_rank=None):
     if release:
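For reference, the rendezvous performed by the removed initialize_distributed is the standard tcp:// pattern: pick this process's GPU, compose MASTER_ADDR and MASTER_PORT into an init method, and pass an explicit rank and world size. A hedged sketch with illustrative values for a two-process, one-GPU-per-process job:

    import os
    import torch

    # Sketch of the tcp:// rendezvous the removed helper performed; the
    # RANK/WORLD_SIZE/backend values here are illustrative, not Megatron's.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '2'))

    # Manually set the device id for this process, as the helper did.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    init_method = 'tcp://' + os.getenv('MASTER_ADDR', 'localhost') + \
                  ':' + os.getenv('MASTER_PORT', '6000')
    torch.distributed.init_process_group(backend='nccl',
                                         world_size=world_size,
                                         rank=rank,
                                         init_method=init_method)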
pretrain_bert.py  (view file @ 5050203f)

@@ -203,4 +203,5 @@ if __name__ == "__main__":

     exit()
     '''
     run('Pretrain BERT model', get_train_val_test_data,
-        model_provider, forward_step)
+        model_provider, forward_step,
+        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})