Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
40fdb50c
Commit
40fdb50c
authored
Jun 12, 2020
by
Maxim Neumann
Committed by
A. Unique TensorFlower
Jun 12, 2020
Browse files
Allow to provide custom metrics to run_classifier.
PiperOrigin-RevId: 316134518
parent
2dfd1e63
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
8 deletions
+17
-8
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+17
-8
No files found.
official/nlp/bert/run_classifier.py
View file @
40fdb50c
...
@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
...
@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
train_input_fn
,
train_input_fn
,
eval_input_fn
,
eval_input_fn
,
training_callbacks
=
True
,
training_callbacks
=
True
,
custom_callbacks
=
None
):
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run BERT classifier training using low-level API."""
"""Run BERT classifier training using low-level API."""
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
num_classes
=
input_meta_data
.
get
(
'num_labels'
,
1
)
num_classes
=
input_meta_data
.
get
(
'num_labels'
,
1
)
...
@@ -159,7 +160,9 @@ def run_bert_classifier(strategy,
...
@@ -159,7 +160,9 @@ def run_bert_classifier(strategy,
# Defines evaluation metrics function, which will create metrics in the
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
# correct device and strategy scope.
if
is_regression
:
if
custom_metrics
:
metric_fn
=
custom_metrics
elif
is_regression
:
metric_fn
=
functools
.
partial
(
metric_fn
=
functools
.
partial
(
tf
.
keras
.
metrics
.
MeanSquaredError
,
tf
.
keras
.
metrics
.
MeanSquaredError
,
'mean_squared_error'
,
'mean_squared_error'
,
...
@@ -216,10 +219,12 @@ def run_keras_compile_fit(model_dir,
...
@@ -216,10 +219,12 @@ def run_keras_compile_fit(model_dir,
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
sub_model
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
sub_model
)
checkpoint
.
restore
(
init_checkpoint
).
assert_existing_objects_matched
()
checkpoint
.
restore
(
init_checkpoint
).
assert_existing_objects_matched
()
if
not
isinstance
(
metric_fn
,
(
list
,
tuple
)):
metric_fn
=
[
metric_fn
]
bert_model
.
compile
(
bert_model
.
compile
(
optimizer
=
optimizer
,
optimizer
=
optimizer
,
loss
=
loss_fn
,
loss
=
loss_fn
,
metrics
=
[
metric_fn
()
],
metrics
=
[
fn
()
for
fn
in
metric_fn
],
experimental_steps_per_execution
=
steps_per_loop
)
experimental_steps_per_execution
=
steps_per_loop
)
summary_dir
=
os
.
path
.
join
(
model_dir
,
'summaries'
)
summary_dir
=
os
.
path
.
join
(
model_dir
,
'summaries'
)
...
@@ -350,7 +355,8 @@ def run_bert(strategy,
...
@@ -350,7 +355,8 @@ def run_bert(strategy,
train_input_fn
=
None
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
eval_input_fn
=
None
,
init_checkpoint
=
None
,
init_checkpoint
=
None
,
custom_callbacks
=
None
):
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run BERT training."""
"""Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU.
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
...
@@ -391,7 +397,8 @@ def run_bert(strategy,
...
@@ -391,7 +397,8 @@ def run_bert(strategy,
init_checkpoint
or
FLAGS
.
init_checkpoint
,
init_checkpoint
or
FLAGS
.
init_checkpoint
,
train_input_fn
,
train_input_fn
,
eval_input_fn
,
eval_input_fn
,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
,
custom_metrics
=
custom_metrics
)
if
FLAGS
.
model_export_path
:
if
FLAGS
.
model_export_path
:
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
...
@@ -399,11 +406,12 @@ def run_bert(strategy,
...
@@ -399,11 +406,12 @@ def run_bert(strategy,
return
trained_model
return
trained_model
def
custom_main
(
custom_callbacks
=
None
):
def
custom_main
(
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run classification or regression.
"""Run classification or regression.
Args:
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_metrics: list of metrics passed to the training loop.
"""
"""
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_param
)
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_param
)
...
@@ -474,11 +482,12 @@ def custom_main(custom_callbacks=None):
...
@@ -474,11 +482,12 @@ def custom_main(custom_callbacks=None):
bert_config
,
bert_config
,
train_input_fn
,
train_input_fn
,
eval_input_fn
,
eval_input_fn
,
custom_callbacks
=
custom_callbacks
)
custom_callbacks
=
custom_callbacks
,
custom_metrics
=
custom_metrics
)
def
main
(
_
):
def
main
(
_
):
custom_main
(
custom_callbacks
=
None
)
custom_main
(
custom_callbacks
=
None
,
custom_metrics
=
None
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment