Unverified commit 70fc197b, authored by Vincent QB, committed by GitHub
Browse files

Remove the --print-freq option and compute the validation loss at each epoch. (#997)

parent 076052f1
......@@ -83,13 +83,6 @@ def parse_args():
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument(
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--reduce-lr-valid",
action="store_true",
......@@ -615,37 +608,35 @@ def main(rank, args):
not args.reduce_lr_valid,
)
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
loss = evaluate(
model,
criterion,
loader_validation,
decoder,
language_model,
devices[0],
epoch,
not_main_rank,
)
is_best = loss < best_loss
best_loss = min(loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
is_best,
args.checkpoint,
not_main_rank,
)
loss = evaluate(
model,
criterion,
loader_validation,
decoder,
language_model,
devices[0],
epoch,
not_main_rank,
)
if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
is_best = loss < best_loss
best_loss = min(loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
is_best,
args.checkpoint,
not_main_rank,
)
logging.info("End time: %s", datetime.now())
if args.distributed:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment