Unverified Commit 8eb9e29d authored by Dave Berenbaum, committed by GitHub

dvclive callback: warn instead of fail when logging non-scalars (#27608)

* dvclive callback: warn instead of fail when logging non-scalars

* tests: log lr as scalar
parent 38e2633f
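
For context, a minimal standalone sketch of the guarded-logging pattern this commit introduces, assuming a dvclive installation. Live, log_metric(), next_step(), and Metric.could_log() come from the diff below; the logger name and the toy logs dict are illustrative only.

    import logging

    from dvclive import Live
    from dvclive.plots import Metric

    logger = logging.getLogger(__name__)

    live = Live()
    # One scalar and one non-scalar: the scalar is logged, while the
    # list is dropped with a warning instead of raising an error.
    logs = {"loss": 0.42, "learning_rate": [1e-4]}
    for key, value in logs.items():
        if Metric.could_log(value):
            live.log_metric(key, value)
        else:
            logger.warning(f"Dropped non-scalar value {value!r} for key {key!r}.")
    live.next_step()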
@@ -1680,10 +1680,19 @@ class DVCLiveCallback(TrainerCallback):
         if not self._initialized:
             self.setup(args, state, model)
         if state.is_world_process_zero:
+            from dvclive.plots import Metric
             from dvclive.utils import standardize_metric_name

             for key, value in logs.items():
-                self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
+                if Metric.could_log(value):
+                    self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
+                else:
+                    logger.warning(
+                        "Trainer is attempting to log a value of "
+                        f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
+                        "This invocation of DVCLive's Live.log_metric() "
+                        "is incorrect so we dropped this attribute."
+                    )
             self.live.next_step()

     def on_save(self, args, state, control, **kwargs):
@@ -672,7 +672,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
             def log(self, logs):
                 # the LR is computed after metrics and does not exist for the first epoch
                 if hasattr(self.lr_scheduler, "_last_lr"):
-                    logs["learning_rate"] = self.lr_scheduler._last_lr
+                    logs["learning_rate"] = self.lr_scheduler._last_lr[0]
                 super().log(logs)

         train_dataset = RegressionDataset(length=64)
@@ -702,14 +702,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
             if loss > best_loss:
                 bad_epochs += 1
                 if bad_epochs > patience:
-                    self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
+                    self.assertLess(logs[i + 1]["learning_rate"], log["learning_rate"])
                     just_decreased = True
                     bad_epochs = 0
             else:
                 best_loss = loss
                 bad_epochs = 0
             if not just_decreased:
-                self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
+                self.assertEqual(logs[i + 1]["learning_rate"], log["learning_rate"])

     def test_adafactor_lr_none(self):
         # test the special case where lr=None, since Trainer can't not have lr_scheduler
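
The test change tracks the same fix: PyTorch LR schedulers expose the last computed rate as _last_lr, a list with one entry per optimizer parameter group, so the nested Trainer subclass now logs _last_lr[0] (a scalar that Metric.could_log accepts) and the assertions drop their [0] indexing. A small sketch of that behavior, assuming torch is installed; the Linear model and StepLR schedule are illustrative choices, not part of the test.

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    optimizer.step()
    scheduler.step()
    print(scheduler._last_lr)     # a list, e.g. [0.010000000000000002]
    print(scheduler._last_lr[0])  # the scalar now stored under logs["learning_rate"]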