Unverified Commit 9de4afa8 authored by Rakesh Chada's avatar Rakesh Chada Committed by GitHub

Make get_last_lr in trainer backward compatible (#4446)

* makes fetching the last learning rate in trainer backward compatible

* split comment to multiple lines

* fixes black styling issue

* uses version to create a more explicit logic
parent 42e8fbfc
@@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
+from packaging import version
 from torch import nn
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset
@@ -440,7 +441,14 @@ class Trainer:
             ):
                 logs: Dict[str, float] = {}
                 logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
-                logs["learning_rate"] = scheduler.get_last_lr()[0]
+                # maintaining backward compatibility.
+                # could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0
+                logs["learning_rate"] = (
+                    scheduler.get_last_lr()[0]
+                    if version.parse(torch.__version__) >= version.parse("1.4")
+                    else scheduler.get_lr()[0]
+                )
                 logging_loss = tr_loss
                 self._log(logs)
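As a standalone illustration of the same guard, the snippet below fetches the current learning rate in a way that works on both older and newer PyTorch releases. The linear model, SGD optimizer, and LambdaLR scheduler are illustrative assumptions for the sketch, not part of this patch.

# Minimal sketch of the version-guarded learning-rate lookup used in the patch.
# The model/optimizer/scheduler here are placeholders chosen for illustration.
import torch
from packaging import version
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.95 ** step)

# get_last_lr() was only added to the LR schedulers in PyTorch 1.4.0;
# older releases expose the same information through get_lr().
if version.parse(torch.__version__) >= version.parse("1.4"):
    current_lr = scheduler.get_last_lr()[0]
else:
    current_lr = scheduler.get_lr()[0]

print(current_lr)  # 0.1 before any scheduler.step() call

Once the minimum supported PyTorch version is 1.4.0 or newer, the conditional can be collapsed to the plain scheduler.get_last_lr()[0] call, which is what the in-diff comment alludes to.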