OpenDAS / Megatron-LM · Commit a1d04b79

Updating public repo with latest changes.

Authored Oct 04, 2019 by Jared Casper
Parent: 93ab4bea

Changes: 1 changed file, utils.py, with 53 additions and 3 deletions (+53 −3).

utils.py @ a1d04b79
@@ -35,7 +35,34 @@ def print_rank_0(message):
         print(message, flush=True)


-def print_args(args):
+def enable_adlr_autoresume(args):
+    print_rank_0('enabling autoresume ...')
+    import sys
+    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
+    try:
+        from userlib.auto_resume import AutoResume
+    except:
+        print_rank_0('ADLR autoresume is not available, exiting ...')
+        exit(0)
+    args.AutoResume = AutoResume
+    args.AutoResume.init()
+
+
+def check_adlr_autoresume_termination(iteration, model, optimizer,
+                                      lr_scheduler, args):
+    # Add barrier to ensure consistency.
+    torch.distributed.barrier()
+    if args.AutoResume.termination_requested():
+        if args.save:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+        print_rank_0(">>> autoresume termination request found!")
+        if torch.distributed.get_rank() == 0:
+            args.AutoResume.request_resume()
+        print_rank_0(">>> training terminated. Returning")
+        exit(0)
+
+
+def print_args(args, writer=None):
     """Print arguments."""

     print('arguments:', flush=True)
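The two new helpers wrap NVIDIA's internal ADLR autoresume hooks: one imports and initializes userlib.auto_resume, the other checkpoints and exits when the scheduler asks the job to stop. A minimal sketch of how they might be wired into a training loop follows; the loop structure, the import path, train_step(), and the args.adlr_autoresume / args.adlr_autoresume_interval flags are assumptions for illustration, not part of this commit.

# Hypothetical wiring of the autoresume helpers into a training loop.
from utils import enable_adlr_autoresume, check_adlr_autoresume_termination  # assumed import path

def train(model, optimizer, lr_scheduler, train_data_iterator, args):
    if args.adlr_autoresume:                      # assumed CLI flag
        enable_adlr_autoresume(args)              # imports userlib.auto_resume and calls init()
    iteration = 0
    while iteration < args.train_iters:
        train_step(model, optimizer, lr_scheduler, train_data_iterator, args)  # hypothetical helper
        iteration += 1
        # Periodically ask the cluster scheduler whether this job should
        # checkpoint and exit so it can be resubmitted.
        if args.adlr_autoresume and iteration % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler, args)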
@@ -43,6 +70,8 @@ def print_args(args):
         dots = '.' * (29 - len(arg))
         print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
+        if writer:
+            writer.add_text(arg, str(getattr(args, arg)))


 def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""
@@ -119,6 +148,16 @@ class Timers:
             self.timers[name] = self.Timer(name)
         return self.timers[name]

+    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+        """Write timers to a tensorboard writer"""
+        # currently when using add_scalars,
+        # torch.utils.add_scalars makes each timer its own run, which
+        # pollutes the runs list, so we just add each as a scalar
+        assert normalizer > 0.0
+        for name in names:
+            value = self.timers[name].elapsed(reset=reset) / normalizer
+            writer.add_scalar(name + '_time', value, iteration)
+
     def log(self, names, normalizer=1.0, reset=True):
         """Log a group of timers."""
         assert normalizer > 0.0
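The new Timers.write method logs each named timer as its own scalar tag (name + '_time') rather than going through add_scalars, so TensorBoard does not create a separate run per timer. A usage sketch, assuming the Timer objects in this file expose start()/stop() as they do elsewhere in Megatron; the timer name and iteration value are placeholders:

from torch.utils.tensorboard import SummaryWriter
from utils import Timers   # assumed import path

timers = Timers()
timers('forward').start()
# ... forward pass ...
timers('forward').stop()

writer = SummaryWriter(log_dir='runs/example')
# Each timer becomes its own scalar tag, e.g. 'forward_time'.
timers.write(['forward'], writer, iteration=100, normalizer=1.0, reset=True)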
@@ -144,13 +183,13 @@ def report_memory(name):
         torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)


-def get_checkpoint_name(checkpoints_path, iteration, release=False):
+def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
+                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
                         'model_optim_rng.pt')
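The new mp_rank argument lets a caller build the checkpoint path for a model-parallel rank other than its own, which is useful when enumerating or merging per-rank checkpoints outside a distributed job. A hedged sketch; the checkpoint directory, iteration, and model_parallel_size below are made-up values:

from utils import get_checkpoint_name   # assumed import path

checkpoints_path = '/checkpoints/gpt2-example'   # placeholder path
iteration = 50000
model_parallel_size = 8

rank_files = [
    get_checkpoint_name(checkpoints_path, iteration, mp_rank=rank)
    for rank in range(model_parallel_size)
]
# e.g. /checkpoints/gpt2-example/iter_0050000/mp_rank_00/model_optim_rng.pt, ...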
@@ -353,3 +392,14 @@ def move_weights(our, oai, dst2src=False):
     for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
         load_transformer_layer(our_layer, oai_layer, dst2src)
+
+
+def merge_parallel_state_dicts(state_dicts):
+    temp_sd = {}
+    for sd in state_dicts:
+        for k, v in sd.items():
+            temp_sd[k].append()
+    pass
+
+def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
+    pass
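The two merge helpers land in this commit as stubs: temp_sd[k].append() is not yet functional and both functions end in pass. A minimal sketch of one way the merge could be completed, grouping each parameter's per-rank tensors into a list keyed by name; the function names, the iteration argument, and the assumption that the loaded files expose a flat name-to-tensor mapping are all illustrative, and how the grouped tensors are then concatenated (which axis, which layers are actually partitioned) is model-specific and not specified by this commit.

import os
import torch

def merge_parallel_state_dicts_sketch(state_dicts):
    # Collect every rank's copy of each key; real checkpoints nest
    # model/optimizer/rng state, so the keys here are illustrative.
    temp_sd = {}
    for sd in state_dicts:
        for k, v in sd.items():
            temp_sd.setdefault(k, []).append(v)
    return temp_sd

def merge_parallel_checkpoints_sketch(checkpoint_dir, model_parallel_size, iteration):
    # Load every rank's 'model_optim_rng.pt' for one iteration, mirroring the
    # directory layout produced by get_checkpoint_name, and group the tensors.
    paths = [
        os.path.join(checkpoint_dir, 'iter_{:07d}'.format(iteration),
                     'mp_rank_{:02d}'.format(rank), 'model_optim_rng.pt')
        for rank in range(model_parallel_size)
    ]
    state_dicts = [torch.load(p, map_location='cpu') for p in paths]
    return merge_parallel_state_dicts_sketch(state_dicts)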