Commit 86eb5bd8 authored by Jared Casper
Browse files

Merge branch 'validation_loss_reporting' into 'main'

fixed validation loss reporting in tensorboard

See merge request ADLR/megatron-lm!207
parents f5eac3d1 5b74f764
......@@ -924,6 +924,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, verbose=False):
"""Helper function to evaluate and dump results on screen."""
+args = get_args()
writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
......@@ -932,11 +933,16 @@ def evaluate_and_print_results(prefix, forward_step_func,
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
-if writer and torch.distributed.get_rank() == 0:
-writer.add_scalar('{} value'.format(key),
+if writer and is_last_rank():
+writer.add_scalar('{} value-validation'.format(key),
total_loss_dict[key].item(),
iteration)
-writer.add_scalar('{} ppl'.format(key), ppl, iteration)
+writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
+writer.add_scalar('{} value-validation vs samples'.format(key),
+total_loss_dict[key].item(),
+args.consumed_train_samples)
+writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
+args.consumed_train_samples)
length = len(string) + 1
print_rank_last('-' * length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment