OpenDAS / Megatron-LM
Commit b81cad66 (parent 5a304ede)
Authored Dec 12, 2020 by mohammad; committed by Deepak Narayanan on Dec 19, 2020

Fix TensorBoard writes
Showing 6 changed files with 54 additions and 41 deletions:
  megatron/global_vars.py    +4   -3
  megatron/training.py       +42  -30
  pretrain_bert.py           +2   -2
  pretrain_gpt2.py           +2   -2
  pretrain_ict.py            +2   -2
  tasks/finetune_utils.py    +2   -2
megatron/global_vars.py

@@ -131,7 +131,7 @@ def _set_tensorboard_writer(args):
                                    'tensorboard writer')

     if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == 0:
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
         try:
             from torch.utils.tensorboard import SummaryWriter
             print('> setting tensorboard ...')

@@ -242,7 +242,7 @@ class Timers:
         assert normalizer > 0.0
         for name in names:
             value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '_time', value, iteration)
+            writer.add_scalar(name + '-time', value, iteration)

     def log(self, names, normalizer=1.0, reset=True):
         """Log a group of timers."""

@@ -253,7 +253,8 @@ class Timers:
                 reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
         if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == 0:
+            if torch.distributed.get_rank() == (
+                    torch.distributed.get_world_size() - 1):
                 print(string, flush=True)
         else:
             print(string, flush=True)
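Both global_vars.py hunks move TensorBoard and timer-log output from global rank 0 to the last rank, presumably because with pipeline parallelism the loss is produced on the last pipeline stage. A minimal sketch of the resulting guard, assuming torch.distributed has already been initialized; the helper name below is illustrative and is not part of the commit:

import torch


def write_from_last_rank():
    """Illustrative only: the rank condition the hunks above switch to."""
    if not torch.distributed.is_initialized():
        # Single-process runs have nothing to coordinate.
        return True
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)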
megatron/training.py

@@ -31,6 +31,7 @@ from megatron import get_timers
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
+from megatron import is_last_rank
 from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0

@@ -675,12 +676,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     timers = get_timers()
     writer = get_tensorboard_writer()

-    # Update losses.
+    # Advanced, skipped, and Nan iterations.
+    advanced_iters_key = 'advanced iterations'
     skipped_iters_key = 'skipped iterations'
+    nan_iters_key = 'nan iterations'
+    # Advanced iterations.
+    if not skipped_iter:
+        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
+            advanced_iters_key, 0) + 1
+    else:
+        if advanced_iters_key not in total_loss_dict:
+            total_loss_dict[advanced_iters_key] = 0
+    # Skipped iterations.
     total_loss_dict[skipped_iters_key] = total_loss_dict.get(
         skipped_iters_key, 0) + skipped_iter
-    got_nan_key = 'got nan'
-
+    # Update losses and set nan iterations
     got_nan = False
     for key in loss_dict:
         if not skipped_iter:

@@ -692,9 +702,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 value == -float('inf') or \
                 value != value
             got_nan = got_nan or is_nan
-
-    total_loss_dict[got_nan_key] = total_loss_dict.get(
-        got_nan_key, 0) + int(got_nan)
+    total_loss_dict[nan_iters_key] = total_loss_dict.get(
+        nan_iters_key, 0) + int(got_nan)

     # Logging.
     timers_to_log = []

@@ -715,51 +724,53 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-embedding-all-reduce')
     add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
-    add_to_logging('batch generator')
+    add_to_logging('batch-generator')

+    # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \
         get_num_microbatches()
+    total_iterations = total_loss_dict[advanced_iters_key] + \
+        total_loss_dict[skipped_iters_key]

     # Tensorboard values.
-    if writer and torch.distributed.get_rank() == 0:
-        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
-        writer.add_scalar('learning_rate-samples', learning_rate,
+    if writer and is_last_rank():
+        writer.add_scalar('learning-rate', learning_rate, iteration)
+        writer.add_scalar('learning-rate vs samples', learning_rate,
                           args.consumed_train_samples)
-        writer.add_scalar('batch_size-iterations', batch_size, iteration)
-        writer.add_scalar('batch_size-samples', batch_size,
+        writer.add_scalar('batch-size', batch_size, iteration)
+        writer.add_scalar('batch-size vs samples', batch_size,
                           args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
-            writer.add_scalar(key + '-samples', loss_dict[key],
+            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.fp16:
-            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
-            writer.add_scalar('loss_scale-samples', loss_scale,
+            writer.add_scalar('loss-scale', loss_scale, iteration)
+            writer.add_scalar('loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
-        normalizer = iteration % args.log_interval
-        if normalizer == 0:
-            normalizer = args.log_interval
         timers.write(timers_to_log, writer, iteration,
-                     normalizer=normalizer)
+                     normalizer=total_iterations)

     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval time').elapsed()
+        elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('iteration_time',
-                              elapsed_time / args.log_interval, iteration)
+            writer.add_scalar('iteration-time',
+                              elapsed_time_per_iteration, iteration)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
-            elapsed_time * 1000.0 / args.log_interval)
+            elapsed_time_per_iteration * 1000.0)
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
-        log_string += ' global batch size: {:6d} |'.format(batch_size)
-        num_iterations = max(
-            1, args.log_interval - total_loss_dict[skipped_iters_key])
+        log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:
-            if key not in [skipped_iters_key, got_nan_key]:
-                avg = total_loss_dict[key].item() / float(num_iterations)
+            if key not in [advanced_iters_key, skipped_iters_key,
+                           nan_iters_key]:
+                avg = total_loss_dict[key].item() / \
+                    float(max(1, total_loss_dict[advanced_iters_key]))
                 if avg > 0.0:
                     log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])

@@ -768,9 +779,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' number of skipped iterations: {:3d} |'.format(
             total_loss_dict[skipped_iters_key])
         log_string += ' number of nan iterations: {:3d} |'.format(
-            total_loss_dict[got_nan_key])
+            total_loss_dict[nan_iters_key])
+        total_loss_dict[advanced_iters_key] = 0
         total_loss_dict[skipped_iters_key] = 0
-        total_loss_dict[got_nan_key] = 0
+        total_loss_dict[nan_iters_key] = 0
         print_rank_last(log_string)
         if report_memory_flag and learning_rate > 0.:
             # Report memory after optimizer state has been initialized.
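The heart of the training.py change is the bookkeeping: the log path now counts advanced (non-skipped), skipped, and nan iterations separately, averages losses over the advanced-iteration count instead of over log_interval, and normalizes timer writes by advanced + skipped iterations. A simplified sketch of that accounting using plain Python floats in place of CUDA tensors; it illustrates the logic in the hunks above and is not code from the repository:

ADVANCED_KEY = 'advanced iterations'
SKIPPED_KEY = 'skipped iterations'
NAN_KEY = 'nan iterations'
COUNTER_KEYS = (ADVANCED_KEY, SKIPPED_KEY, NAN_KEY)


def accumulate(total_loss_dict, loss_dict, skipped_iter, got_nan):
    """Mirror the new counter updates with plain floats (illustrative)."""
    total_loss_dict[ADVANCED_KEY] = total_loss_dict.get(ADVANCED_KEY, 0) + \
        int(not skipped_iter)
    total_loss_dict[SKIPPED_KEY] = total_loss_dict.get(SKIPPED_KEY, 0) + \
        int(skipped_iter)
    total_loss_dict[NAN_KEY] = total_loss_dict.get(NAN_KEY, 0) + int(got_nan)
    if not skipped_iter:
        # Losses only accumulate on advanced iterations.
        for key, value in loss_dict.items():
            total_loss_dict[key] = total_loss_dict.get(key, 0.0) + value


def averaged_losses(total_loss_dict):
    """Average over advanced iterations only, as the new log path does."""
    advanced = max(1, total_loss_dict.get(ADVANCED_KEY, 0))
    return {key: value / float(advanced)
            for key, value in total_loss_dict.items()
            if key not in COUNTER_KEYS}


totals = {}
accumulate(totals, {'lm loss': 6.2}, skipped_iter=False, got_nan=False)
accumulate(totals, {'lm loss': 0.0}, skipped_iter=True, got_nan=False)
print(averaged_losses(totals))  # {'lm loss': 6.2}: one advanced, one skipped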
pretrain_bert.py

@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
         = get_batch(data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
pretrain_gpt2.py

@@ -87,10 +87,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
pretrain_ict.py

@@ -79,10 +79,10 @@ def forward_step(data_iterator, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     query_tokens, query_pad_mask, \
     block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward model.
tasks/finetune_utils.py

@@ -50,13 +50,13 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch-generator').start()
     try:
         batch_ = next(batch)
     except BaseException:
         batch_ = batch
     tokens, types, labels, attention_mask = process_batch(batch_)
-    timers('batch generator').stop()
+    timers('batch-generator').stop()

     # Forward model.
     if mpu.is_pipeline_first_stage():
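The remaining four files only rename the 'batch generator' timer to 'batch-generator'. The name matters because Timers.write (see the megatron/global_vars.py hunk above) uses it verbatim as the TensorBoard tag with '-time' appended, so the rename keeps the timing tags in the same hyphenated style as the other scalars. A toy sketch of that coupling, with a stub writer standing in for SummaryWriter; neither class below exists in the repository:

import time


class ToyTimers:
    """Stripped-down stand-in for megatron's Timers (illustrative only)."""

    def __init__(self):
        self._elapsed = {}
        self._started = {}

    def __call__(self, name):
        self._name = name
        return self

    def start(self):
        self._started[self._name] = time.time()

    def stop(self):
        name = self._name
        self._elapsed[name] = self._elapsed.get(name, 0.0) + \
            time.time() - self._started.pop(name)

    def write(self, names, writer, iteration, normalizer=1.0):
        # The timer name becomes the TensorBoard tag, suffixed with '-time'.
        for name in names:
            writer.add_scalar(name + '-time',
                              self._elapsed.get(name, 0.0) / normalizer,
                              iteration)


class PrintWriter:
    """Stub for torch.utils.tensorboard.SummaryWriter.add_scalar."""

    def add_scalar(self, tag, value, step):
        print('{}[{}] = {:.6f}'.format(tag, step, value))


timers = ToyTimers()
timers('batch-generator').start()
time.sleep(0.01)                      # pretend to fetch a batch
timers('batch-generator').stop()
timers.write(['batch-generator'], PrintWriter(), iteration=1)
# prints something like: batch-generator-time[1] = 0.010xxx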