chenpangpang / transformers · commit 63c295ac

Unverified commit 63c295ac, authored Mar 11, 2021 by Sylvain Gugger, committed via GitHub on Mar 11, 2021

Ensure metric results are JSON-serializable (#10632)

Parent: 27d9e05c
Showing 2 changed files with 26 additions and 4 deletions:

  src/transformers/trainer.py        +4   -0
  src/transformers/trainer_utils.py  +22  -4
src/transformers/trainer.py
...
@@ -101,6 +101,7 @@ from .trainer_utils import (
     TrainOutput,
     default_compute_objective,
     default_hp_space,
+    denumpify_detensorize,
     get_last_checkpoint,
     set_seed,
     speed_metrics,
...
@@ -1831,6 +1832,9 @@ class Trainer:
         else:
             metrics = {}
 
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
         if eval_loss is not None:
             metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
 
...
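
For context (not part of the diff): compute_metrics typically returns NumPy scalars, and the standard-library json module refuses to encode those, which is the failure the new denumpify_detensorize call guards against. A minimal sketch of the problem and the plain-Python workaround, with the metric name chosen only for illustration:

import json

import numpy as np

metrics = {"eval_accuracy": np.float32(0.91)}  # NumPy scalar, as returned by a typical compute_metrics

try:
    json.dumps(metrics)
except TypeError as err:
    print(err)  # "Object of type float32 is not JSON serializable"

# Converting to a plain Python float first makes the dict serializable
print(json.dumps({k: float(v) for k, v in metrics.items()}))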
src/transformers/trainer_utils.py
...
@@ -38,6 +38,13 @@ from .file_utils import (
 )
 
 
+if is_torch_available():
+    import torch
+
+if is_tf_available():
+    import tensorflow as tf
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
...
@@ -49,14 +56,10 @@ def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
     if is_torch_available():
-        import torch
-
         torch.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)
         # ^^ safe to call this function even if cuda is not available
     if is_tf_available():
-        import tensorflow as tf
-
         tf.random.set_seed(seed)
...
@@ -423,6 +426,21 @@ class TrainerMemoryTracker:
         self.update_metrics(stage, metrics)
 
 
+def denumpify_detensorize(metrics):
+    """
+    Recursively calls `.item()` on the element of the dictionary passed
+    """
+    if isinstance(metrics, (list, tuple)):
+        return type(metrics)(denumpify_detensorize(m) for m in metrics)
+    elif isinstance(metrics, dict):
+        return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
+    elif isinstance(metrics, np.generic):
+        return metrics.item()
+    elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
+        return metrics.item()
+    return metrics
+
+
 class ShardedDDPOption(ExplicitEnum):
     SIMPLE = "simple"
     ZERO_DP_2 = "zero_dp_2"
...
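
As a quick usage sketch (not part of the commit), assuming a transformers checkout that includes this change: the new helper converts NumPy scalars and one-element tensors back into plain Python numbers, after which the metrics dict can be dumped to JSON. The metric names below are made up for illustration:

import json

import numpy as np
import torch

from transformers.trainer_utils import denumpify_detensorize

metrics = {
    "eval_loss": torch.tensor(0.35),                        # zero-d torch tensor
    "eval_accuracy": np.float32(0.91),                      # NumPy scalar
    "per_class_f1": [np.float64(0.88), np.float64(0.95)],   # nested containers are handled recursively
}

clean = denumpify_detensorize(metrics)  # every value is now a plain Python float
print(json.dumps(clean, indent=2))      # serializes without a TypeError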