Commit 59cefd4f authored by thomwolf

fix #726 - get_lr in examples

parent ddc2cc61
...
@@ -313,7 +313,8 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
                 if args.local_rank in [-1, 0]:
-                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    if not args.fp16:
+                        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', loss.item(), global_step)
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...
...
@@ -319,7 +319,8 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
                 if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
-                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    if not args.fp16:
+                        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', loss.item(), global_step)
     ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
...
...
@@ -313,7 +313,8 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
                 if args.local_rank in [-1, 0]:
-                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    if not args.fp16:
+                        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', loss.item(), global_step)
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...
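For context, the guard added in each hunk above follows the pattern sketched below. This is a minimal illustration, not code from the commit: it assumes that in the fp16 path the example scripts wrap the optimizer in an object that does not expose get_lr(), so the learning-rate scalar is only written when fp16 is disabled, while the loss is always logged.

# Hypothetical helper illustrating the guarded logging added in this commit.
# Assumption: optimizer.get_lr() is only available on the non-fp16 optimizer
# (e.g. BertAdam); the fp16-wrapped optimizer used in the examples does not provide it.
def log_step(tb_writer, optimizer, loss, global_step, fp16):
    if not fp16:
        # get_lr() returns one learning rate per parameter group; log the first.
        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
    tb_writer.add_scalar('loss', loss.item(), global_step)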