OpenDAS / Megatron-LM · Commits

Commit a1d04b79, authored Oct 04, 2019 by Jared Casper
Parent: 93ab4bea

    Updating public repo with latest changes.
Showing 1 changed file with 53 additions and 3 deletions.

utils.py  +53 −3
utils.py  (view file @ a1d04b79)

@@ -35,7 +35,34 @@ def print_rank_0(message):
         print(message, flush=True)
 
 
-def print_args(args):
+def enable_adlr_autoresume(args):
+    print_rank_0('enabling autoresume ...')
+    import sys
+    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
+    try:
+        from userlib.auto_resume import AutoResume
+    except ImportError:
+        print_rank_0('ADLR autoresume is not available, exiting ...')
+        exit(0)
+
+    args.AutoResume = AutoResume
+    args.AutoResume.init()
+
+
+def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, args):
+    # Add barrier to ensure consistency.
+    torch.distributed.barrier()
+    if args.AutoResume.termination_requested():
+        if args.save:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+        print_rank_0(">>> autoresume termination request found!")
+        if torch.distributed.get_rank() == 0:
+            args.AutoResume.request_resume()
+        print_rank_0(">>> training terminated. Returning")
+        exit(0)
+
+
+def print_args(args, writer=None):
     """Print arguments."""
     print('arguments:', flush=True)
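As a usage note (not part of the diff): a minimal sketch of how a training driver might call these two hooks. The train_step function and the args.adlr_autoresume / args.adlr_autoresume_interval flags are illustrative placeholders, not defined by this commit; only enable_adlr_autoresume and check_adlr_autoresume_termination come from utils.py.

    # Hypothetical driver loop; train_step and the flag names below are
    # placeholders, not part of this commit.
    from utils import enable_adlr_autoresume, check_adlr_autoresume_termination

    def train(model, optimizer, lr_scheduler, args):
        if args.adlr_autoresume:          # assumed enable flag
            enable_adlr_autoresume(args)  # attaches and inits args.AutoResume
        iteration = 0
        while iteration < args.train_iters:
            train_step(model, optimizer)  # placeholder for one training step
            iteration += 1
            if args.adlr_autoresume and \
                    iteration % args.adlr_autoresume_interval == 0:
                # Checkpoints and exits cleanly if the cluster requested
                # termination; otherwise returns immediately.
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  lr_scheduler, args)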
@@ -43,6 +70,8 @@ def print_args(args):
         dots = '.' * (29 - len(arg))
         print('  {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
+        if writer:
+            writer.add_text(arg, str(getattr(args, arg)))
 
 
 def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""
@@ -119,6 +148,16 @@ class Timers:
             self.timers[name] = self.Timer(name)
         return self.timers[name]
 
+    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+        """Write timers to a tensorboard writer."""
+        # Currently, when using add_scalars,
+        # torch.utils.add_scalars makes each timer its own run, which
+        # pollutes the runs list, so we just add each as a scalar.
+        assert normalizer > 0.0
+        for name in names:
+            value = self.timers[name].elapsed(reset=reset) / normalizer
+            writer.add_scalar(name + '_time', value, iteration)
+
     def log(self, names, normalizer=1.0, reset=True):
         """Log a group of timers."""
         assert normalizer > 0.0
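A corresponding sketch for the new Timers.write, again assuming a SummaryWriter; the timer name and iteration number are arbitrary:

    from torch.utils.tensorboard import SummaryWriter

    from utils import Timers

    timers = Timers()
    writer = SummaryWriter(log_dir='runs/example')  # assumed log location

    timers('forward').start()
    # ... run the timed region, e.g. a forward pass ...
    timers('forward').stop()

    # Emits one scalar named 'forward_time' at step 10; one add_scalar call
    # per timer avoids the per-run clutter that add_scalars would create.
    timers.write(['forward'], writer, 10, normalizer=1.0)
    writer.close()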
@@ -144,13 +183,13 @@ def report_memory(name):
         torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)
 
 
-def get_checkpoint_name(checkpoints_path, iteration, release=False):
+def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
+                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
                         'model_optim_rng.pt')
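To make the new mp_rank parameter concrete, here are the paths it yields (a sketch; the default branch still requires model parallelism to be initialized so mpu can report a rank):

    from utils import get_checkpoint_name

    # Explicit rank: any process can address any rank's shard, which is
    # what an offline checkpoint-merging tool needs.
    # -> 'checkpoints/iter_0005000/mp_rank_03/model_optim_rng.pt'
    path = get_checkpoint_name('checkpoints', 5000, mp_rank=3)

    # Default (mp_rank=None): falls back to this process's own
    # mpu.get_model_parallel_rank(), preserving the old behavior.
    path = get_checkpoint_name('checkpoints', 5000)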
@@ -353,3 +392,14 @@ def move_weights(our, oai, dst2src=False):
 
     for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
         load_transformer_layer(our_layer, oai_layer, dst2src)
+
+
+def merge_parallel_state_dicts(state_dicts):
+    # Unfinished stub: collects each key's per-rank values but does not
+    # yet merge them into a single state dict.
+    temp_sd = {}
+    for sd in state_dicts:
+        for k, v in sd.items():
+            temp_sd.setdefault(k, []).append(v)
+
+
+def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
+    pass
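Both merge functions land as stubs here. One plausible reading, given the mp_rank parameter added above, is that merge_parallel_checkpoints will load every rank's shard and hand the state dicts to merge_parallel_state_dicts. A speculative sketch along those lines, with an assumed extra iteration argument and an assumed 'model' key in each checkpoint, neither of which this diff confirms:

    import torch

    from utils import get_checkpoint_name

    def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size, iteration):
        # Speculative completion, not the author's code: gather one model
        # state dict per model-parallel rank for later merging.
        state_dicts = []
        for rank in range(model_parallel_size):
            path = get_checkpoint_name(checkpoint_dir, iteration, mp_rank=rank)
            checkpoint = torch.load(path, map_location='cpu')  # no GPU needed
            state_dicts.append(checkpoint['model'])  # assumed checkpoint layout
        return state_dicts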