ModelZoo / ResNet50_tensorflow · Commits

Commit 786346f3, authored Dec 14, 2020 by Chen Chen; committed by A. Unique TensorFlower, Dec 14, 2020

Internal change

PiperOrigin-RevId: 347480509
Parent: a167bf93
Showing 4 changed files with 37 additions and 11 deletions:

  official/core/base_trainer.py                   +9  -2
  official/core/base_trainer_test.py              +21 -2
  official/nlp/tasks/question_answering.py        +6  -7
  official/nlp/tasks/question_answering_test.py   +1  -0
official/core/base_trainer.py

@@ -301,6 +301,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     def step_fn(inputs):
       logs = self.task.validation_step(
           inputs, model=self.model, metrics=self.validation_metrics)
+      if self.task.loss in logs:
         self._validation_loss.update_state(logs[self.task.loss])
       return logs
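Note that `self.task.loss` here is not a tensor but the dictionary key under which the task's `validation_step` reports its loss (the Model Garden's base `Task` defines it as the string "loss"), so the new guard is a plain membership test. A minimal sketch of the behavior, with the metric and log values assumed for illustration:

  import tensorflow as tf

  validation_loss = tf.keras.metrics.Mean('validation_loss')

  def track_loss(logs, loss_key='loss'):
    # Mirrors the guarded update above: only touch the metric when the
    # task actually reported a loss under `loss_key`.
    if loss_key in logs:
      validation_loss.update_state(logs[loss_key])

  track_loss({'loss': 0.42, 'counter': 1.0})  # metric updated
  track_loss({'counter': 1.0})                # no KeyError; metric untouched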
@@ -311,8 +312,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def eval_end(self, aggregated_logs=None):
     """Processes evaluation results."""
     logs = {}
-    for metric in self.validation_metrics + [self.validation_loss]:
+    for metric in self.validation_metrics:
       logs[metric.name] = metric.result()
+    if self.validation_loss.count.numpy() != 0:
+      logs[self.validation_loss.name] = self.validation_loss.result()
+    else:
+      # `self.validation_loss` metric was not updated, because the validation
+      # loss was not returned from the task's `validation_step` method.
+      logging.info("The task did not report validation loss.")
     if aggregated_logs:
       metrics = self.task.reduce_aggregated_logs(aggregated_logs)
       logs.update(metrics)
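The `count.numpy() != 0` check leans on the state variables that `tf.keras.metrics.Mean` keeps (`total` and `count`): `count` stays at zero until the first `update_state` call, so it doubles as a "was this metric ever updated?" flag. A quick sketch, assuming `self.validation_loss` is such a `Mean` metric:

  import tensorflow as tf

  validation_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)
  print(validation_loss.count.numpy())     # 0.0 -> eval_end only logs an info message

  validation_loss.update_state(0.42)
  print(validation_loss.count.numpy())     # 1.0 -> eval_end reports result()
  print(validation_loss.result().numpy())  # 0.42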
official/core/base_trainer_test.py

@@ -54,8 +54,8 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
           }
       })))

-  def create_test_trainer(self, config, model_dir=None):
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+  def create_test_trainer(self, config, model_dir=None, task=None):
+    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
     ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
     trainer = trainer_lib.Trainer(
         config,
@@ -79,6 +79,25 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
       trainer = self.create_test_trainer(self._config)
       logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
       self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+      self.assertIn('validation_loss', logs)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_trainer_validate_without_loss(self, distribution):
+
+    class MockTaskWithoutValidationLoss(mock_task.MockTask):
+
+      def validation_step(self, inputs, model, metrics=None):
+        # Disable validation loss.
+        logs = super().validation_step(inputs, model)
+        del logs[self.loss]
+        return logs
+
+    with distribution.scope():
+      task = MockTaskWithoutValidationLoss()
+      trainer = self.create_test_trainer(self._config, task=task)
+      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+      self.assertNotIn('validation_loss', logs)

   @combinations.generate(
       combinations.combine(
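The `MockTaskWithoutValidationLoss` pattern generalizes beyond the test: any task can opt out of loss reporting by dropping the loss key from its validation logs, and the trainer now degrades gracefully instead of raising a KeyError. A hedged sketch with a hypothetical subclass of the base `Task`:

  from official.core import base_task

  class NoValidationLossTask(base_task.Task):  # hypothetical task, for illustration
    def validation_step(self, inputs, model, metrics=None):
      logs = super().validation_step(inputs, model, metrics=metrics)
      # Omit the loss; eval_end will then skip 'validation_loss'.
      logs.pop(self.loss, None)
      return logs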
official/nlp/tasks/question_answering.py

@@ -212,7 +212,10 @@ class QuestionAnsweringTask(base_task.Task):
         input_context)

   def build_metrics(self, training=None):
-    del training
+    if not training:
+      # We cannot compute start/end_position_accuracy because start/end_position
+      # labels are not available in the validation dataset (b/173794928).
+      return []
     # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
     metrics = [
         tf.keras.metrics.SparseCategoricalAccuracy(
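In effect `build_metrics` now branches on the `training` flag instead of discarding it. A hedged illustration of the resulting behavior (task construction elided):

  train_metrics = task.build_metrics(training=True)   # [SparseCategoricalAccuracy(...), ...]
  eval_metrics = task.build_metrics(training=False)   # [] because start/end_position labels
                                                      # are unavailable at eval (b/173794928)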
@@ -244,8 +247,9 @@ class QuestionAnsweringTask(base_task.Task):
     unique_ids = features.pop('unique_ids')
     model_outputs = self.inference_step(features, model)
     start_logits, end_logits = model_outputs
+    # We cannot compute validation_loss here, because start/end_position
+    # labels are not available in the validation dataset (b/173794928).
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
         'start_logits': start_logits,
         'end_logits': end_logits,
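With the placeholder `self.loss: 0.0` entry removed, the logs returned by `validation_step` carry no loss key at all, which is exactly the case the new guard in base_trainer.py handles: `validation_loss` is never updated, and `eval_end` omits it. Roughly, the validation logs now look like:

  logs = {
      'unique_ids': unique_ids,      # pass-through example ids
      'start_logits': start_logits,  # model outputs for SQuAD post-processing
      'end_logits': end_logits,
  }
  assert task.loss not in logs       # so Trainer.eval_end skips 'validation_loss'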
@@ -293,8 +297,6 @@ class QuestionAnsweringTask(base_task.Task):
     if self.task_config.validation_data.version_2_with_negative:
       eval_metrics = squad_evaluate_v2_0.evaluate(
           pred_dataset, all_predictions, scores_diff)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['final_exact'],
           'exact_match_threshold': eval_metrics['final_exact_thresh'],
@@ -305,8 +307,6 @@ class QuestionAnsweringTask(base_task.Task):
       }
     else:
       eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
-      # Filter out useless metrics, such as start_position_accuracy that
-      # we did not actually compute.
       eval_metrics = {
           'exact_match': eval_metrics['exact_match'],
           'final_f1': eval_metrics['final_f1']
@@ -417,7 +417,6 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
     class_logits = model_outputs['class_logits']
     logs = {
-        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
         'unique_ids': unique_ids,
        'start_top_predictions': start_top_predictions,
         'end_top_predictions': end_top_predictions,
official/nlp/tasks/question_answering_test.py

@@ -250,6 +250,7 @@ class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     logs = task.aggregate_logs(step_outputs=logs)
     metrics = task.reduce_aggregated_logs(logs)
     self.assertIn("final_f1", metrics)
+    self.assertNotIn("loss", metrics)

   def test_task(self):
     config = question_answering.XLNetQuestionAnsweringConfig(