Unverified Commit 63c295ac authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Ensure metric results are JSON-serializable (#10632)

parent 27d9e05c
...@@ -101,6 +101,7 @@ from .trainer_utils import ( ...@@ -101,6 +101,7 @@ from .trainer_utils import (
TrainOutput, TrainOutput,
default_compute_objective, default_compute_objective,
default_hp_space, default_hp_space,
denumpify_detensorize,
get_last_checkpoint, get_last_checkpoint,
set_seed, set_seed,
speed_metrics, speed_metrics,
...@@ -1831,6 +1832,9 @@ class Trainer: ...@@ -1831,6 +1832,9 @@ class Trainer:
else: else:
metrics = {} metrics = {}
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)
if eval_loss is not None: if eval_loss is not None:
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
......
...@@ -38,6 +38,13 @@ from .file_utils import ( ...@@ -38,6 +38,13 @@ from .file_utils import (
) )
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
def set_seed(seed: int): def set_seed(seed: int):
""" """
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
...@@ -49,14 +56,10 @@ def set_seed(seed: int): ...@@ -49,14 +56,10 @@ def set_seed(seed: int):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
if is_torch_available(): if is_torch_available():
import torch
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available # ^^ safe to call this function even if cuda is not available
if is_tf_available(): if is_tf_available():
import tensorflow as tf
tf.random.set_seed(seed) tf.random.set_seed(seed)
...@@ -423,6 +426,21 @@ class TrainerMemoryTracker: ...@@ -423,6 +426,21 @@ class TrainerMemoryTracker:
self.update_metrics(stage, metrics) self.update_metrics(stage, metrics)
def denumpify_detensorize(metrics):
"""
Recursively calls `.item()` on the element of the dictionary passed
"""
if isinstance(metrics, (list, tuple)):
return type(metrics)(denumpify_detensorize(m) for m in metrics)
elif isinstance(metrics, dict):
return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
elif isinstance(metrics, np.generic):
return metrics.item()
elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
return metrics.item()
return metrics
class ShardedDDPOption(ExplicitEnum): class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple" SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2" ZERO_DP_2 = "zero_dp_2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment