wuxk1 / Megatron-LM / Commits

Commit ede0a58f authored Aug 17, 2021 by mshoeybi

simplified the iteration read check across ranks

parent a8f4edcb
Showing 1 changed file with 9 additions and 16 deletions
megatron/checkpointing.py (View file @ ede0a58f)

@@ -124,26 +124,19 @@ def read_metadata(tracker_filename):
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)
-    # Make sure all the ranks read the same meta data.
-    iters_cuda = torch.cuda.LongTensor(
-        torch.distributed.get_world_size()).fill_(0)
-    iters_cuda[torch.distributed.get_rank()] = iteration
-    torch.distributed.all_reduce(iters_cuda)
+    # Get the max iteration retrieved across the ranks.
+    iters_cuda = torch.cuda.LongTensor([iteration])
+    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
+    max_iter = iters_cuda[0].item()
     # We should now have all the same iteration.
     # If not, print a warning and chose the maximum
     # iteration across all ranks.
-    max_iter = iters_cuda.max().item()
-    min_iter = iters_cuda.min().item()
-    if max_iter == min_iter:
-        print_rank_0('> meta data was loaded successfully ...')
-    else:
-        for rank in range(torch.distributed.get_world_size()):
-            if iters_cuda[rank] != max_iters:
-                print_rank_0('WARNING: on rank {} found iteration {} in the '
-                             'meta data while max iteration across the ranks '
-                             'is {}, replacing it with max iteration.'.format(
-                                 rank, iters_cuda[rank], max_iter))
+    if iteration != max_iter:
+        print('WARNING: on rank {} found iteration {} in the '
+              'metadata while max iteration across the ranks '
+              'is {}, replacing it with max iteration.'.format(
+                  rank, iteration, max_iter), flush=True)
     return max_iter, release
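Below is a minimal, standalone sketch (not part of the commit) of the simplified check the added lines perform: every rank contributes its local iteration to an all-reduce with ReduceOp.MAX and then compares its own value against the result. The helper name max_iteration_across_ranks, the CPU tensor, and the single-process gloo group are assumptions made only so the snippet runs without GPUs or a multi-process launcher; the commit itself does this inside read_metadata with a torch.cuda.LongTensor on the already-initialized default process group.

import os
import torch
import torch.distributed as dist

def max_iteration_across_ranks(iteration):
    # Same pattern as the new read_metadata code: a one-element tensor,
    # an all-reduce with MAX, then compare the local value against the result.
    iters = torch.LongTensor([iteration])
    dist.all_reduce(iters, op=dist.ReduceOp.MAX)
    max_iter = iters[0].item()
    if iteration != max_iter:
        print('WARNING: on rank {} found iteration {} while the max across '
              'ranks is {}, replacing it with the max iteration.'.format(
                  dist.get_rank(), iteration, max_iter), flush=True)
    return max_iter

if __name__ == '__main__':
    # Single-process CPU group so the sketch runs on its own.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)
    print(max_iteration_across_ranks(1000))  # -> 1000 in a one-rank group
    dist.destroy_process_group()

The gain from the change is that a single-element MAX all-reduce replaces the old world-size tensor, per-rank assignment, and min/max comparison loop: each rank only needs its own value and the reduced result to decide whether its metadata agrees with the other ranks.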