ModelZoo / ResNet50_tensorflow

Commit b3e4fefd, authored Jun 19, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Jun 19, 2020

Support numpy-based metrics through Orbit.

PiperOrigin-RevId: 317432167
Parent: b708fd68

Showing 3 changed files with 102 additions and 10 deletions:

  official/core/base_task.py                       +8   -0
  official/nlp/tasks/sentence_prediction.py        +57  -1
  official/nlp/tasks/sentence_prediction_test.py   +37  -9
official/core/base_task.py

@@ -247,6 +247,14 @@ class Task(tf.Module):
     """Performs the forward step."""
     return model(inputs, training=False)
 
+  def aggregate_logs(self, state, step_logs):
+    """Optional aggregation over logs returned from a validation step."""
+    pass
+
+  def reduce_aggregated_logs(self, aggregated_logs):
+    """Optional reduce of aggregated logs over validation steps."""
+    return {}
+
 
 
 _REGISTERED_TASK_CLS = {}
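The two new hooks give any Task subclass a place to compute metrics outside the TensorFlow graph: aggregate_logs is called after each validation step with that step's outputs, and reduce_aggregated_logs is called once at the end to produce the final metric dictionary. A rough sketch of how a subclass might use them (the MyNumpyMetricTask class, its output keys, and the metric name are hypothetical, not part of this commit):

import numpy as np

from official.core import base_task


class MyNumpyMetricTask(base_task.Task):
  """Hypothetical task that computes a numpy-based metric at eval time."""

  def aggregate_logs(self, state, step_logs):
    # First step: start with an empty accumulator.
    if state is None:
      state = {'preds': [], 'labels': []}
    # step_logs is the dict returned by this task's validation_step;
    # .numpy() pulls the tensors to host memory for numpy/scipy/sklearn.
    state['preds'].append(step_logs['preds'].numpy())
    state['labels'].append(step_logs['labels'].numpy())
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    # Called once after all validation steps; returns the metric dict.
    preds = np.concatenate(aggregated_logs['preds'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    return {'exact_match': float(np.mean(preds == labels))}

The sentence-prediction task below follows exactly this pattern, with sklearn and scipy supplying the metric implementations.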
official/nlp/tasks/sentence_prediction.py

@@ -14,8 +14,11 @@
 # limitations under the License.
 # ==============================================================================
 """Sentence prediction (classification) task."""
 import logging
 from absl import logging
 import dataclasses
+import numpy as np
+from scipy import stats
+from sklearn import metrics as sklearn_metrics
 import tensorflow as tf
 import tensorflow_hub as hub
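The added imports are the host-side metric implementations the task calls once predictions and labels have been pulled out of TensorFlow. For reference, a minimal standalone use of the three functions involved (the arrays are made-up examples):

import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics

# Classification case: Matthews correlation coefficient.
preds = np.array([1, 0, 1, 1, 0])
labels = np.array([1, 0, 0, 1, 0])
print(sklearn_metrics.matthews_corrcoef(preds, labels))  # ~0.67

# Regression case: average of Pearson and Spearman correlations.
scores = np.array([0.1, 0.4, 0.35, 0.8])
targets = np.array([0.0, 0.5, 0.3, 0.9])
pearson = stats.pearsonr(scores, targets)[0]    # each returns (corr, p-value)
spearman = stats.spearmanr(scores, targets)[0]
print((pearson + spearman) / 2)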
@@ -33,6 +36,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
   # be specified.
   init_checkpoint: str = ''
   hub_module_url: str = ''
+  metric_type: str = 'accuracy'
   network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
       num_masked_tokens=0,  # No masked language modeling head.
       cls_heads=[
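The new metric_type field makes the metric a pure configuration choice; the values wired up in this commit are 'accuracy' (the default, computed by the compiled Keras metrics), 'matthews_corrcoef', and 'pearson_spearman_corr'. A minimal sketch of selecting the numpy-based path, assuming the bert configs module used by these tasks and a dummy data config borrowed from the unit test:

from official.nlp.configs import bert
from official.nlp.tasks import sentence_prediction

# Dummy data config mirroring the one used in the unit test.
train_data_config = bert.SentencePredictionDataConfig(
    input_path='dummy', seq_length=128, global_batch_size=1)

# 'matthews_corrcoef' (or 'pearson_spearman_corr') routes evaluation through
# the numpy-based path added here instead of the compiled Keras metrics.
config = sentence_prediction.SentencePredictionConfig(
    metric_type='matthews_corrcoef',
    train_data=train_data_config)
task = sentence_prediction.SentencePredictionTask(config)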
@@ -59,6 +63,7 @@ class SentencePredictionTask(base_task.Task):
       self._hub_module = hub.load(params.hub_module_url)
     else:
       self._hub_module = None
+    self.metric_type = params.metric_type
 
   def build_model(self):
     if self._hub_module:
@@ -123,6 +128,57 @@
   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
 
+  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+    if self.metric_type == 'accuracy':
+      return super(SentencePredictionTask, self).validation_step(
+          inputs, model, metrics)
+    features, labels = inputs
+    outputs = self.inference_step(features, model)
+    loss = self.build_losses(
+        labels=labels, model_outputs=outputs, aux_losses=model.losses)
+    if self.metric_type == 'matthews_corrcoef':
+      return {
+          self.loss: loss,
+          'sentence_prediction':
+              tf.expand_dims(
+                  tf.math.argmax(outputs['sentence_prediction'], axis=1),
+                  axis=0),
+          'labels': labels,
+      }
+    if self.metric_type == 'pearson_spearman_corr':
+      return {
+          self.loss: loss,
+          'sentence_prediction': outputs['sentence_prediction'],
+          'labels': labels,
+      }
+
+  def aggregate_logs(self, state=None, step_outputs=None):
+    if state is None:
+      state = {'sentence_prediction': [], 'labels': []}
+    state['sentence_prediction'].append(
+        np.concatenate(
+            [v.numpy() for v in step_outputs['sentence_prediction']], axis=0))
+    state['labels'].append(
+        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
+    return state
+
+  def reduce_aggregated_logs(self, aggregated_logs):
+    if self.metric_type == 'matthews_corrcoef':
+      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      return {
+          self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
+      }
+    if self.metric_type == 'pearson_spearman_corr':
+      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+      labels = np.concatenate(aggregated_logs['labels'], axis=0)
+      pearson_corr = stats.pearsonr(preds, labels)[0]
+      spearman_corr = stats.spearmanr(preds, labels)[0]
+      corr_metric = (pearson_corr + spearman_corr) / 2
+      return {self.metric_type: corr_metric}
+
   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
     ckpt_dir_or_file = self.task_config.init_checkpoint
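Taken together, the three overrides split evaluation into a device-side step and a host-side reduction: validation_step returns tensors, aggregate_logs converts them to growing lists of numpy arrays, and reduce_aggregated_logs collapses those arrays into the final scalar. The real evaluation loop is driven by the training library (Orbit, per the commit message); its shape can be approximated by the hand-written sketch below, where the task, model, dataset, and step count are placeholders:

import functools

import tensorflow as tf


def run_numpy_metric_eval(task, model, dataset, num_steps):
  """Hand-rolled approximation of the Orbit-driven evaluation flow."""
  strategy = tf.distribute.get_strategy()
  iterator = iter(dataset)
  state = None
  for _ in range(num_steps):
    # Device side: one validation step per replica.
    per_replica_outputs = strategy.run(
        functools.partial(task.validation_step, model=model),
        args=(next(iterator),))
    # Unwrap per-replica values so .numpy() can be called on them.
    step_outputs = tf.nest.map_structure(
        strategy.experimental_local_results, per_replica_outputs)
    # Host side: accumulate numpy arrays across steps.
    state = task.aggregate_logs(state=state, step_outputs=step_outputs)
  # Final host-side reduction, e.g. {'matthews_corrcoef': <float>}.
  return task.reduce_aggregated_logs(state)

This is the same sequence the unit test below exercises with a single batch.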
official/nlp/tasks/sentence_prediction_test.py

@@ -16,6 +16,8 @@
 """Tests for official.nlp.tasks.sentence_prediction."""
+import functools
 import os
 
+from absl.testing import parameterized
 import tensorflow as tf
 
 from official.nlp.bert import configs
@@ -25,20 +27,24 @@ from official.nlp.configs import encoders
 from official.nlp.tasks import sentence_prediction
 
 
-class SentencePredictionTaskTest(tf.test.TestCase):
+class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(SentencePredictionTaskTest, self).setUp()
-    self._network_config = bert.BertPretrainerConfig(
+    self._train_data_config = bert.SentencePredictionDataConfig(
+        input_path="dummy", seq_length=128, global_batch_size=1)
+
+  def get_network_config(self, num_classes):
+    return bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(
             vocab_size=30522, num_layers=1),
         num_masked_tokens=0,
         cls_heads=[
             bert.ClsHeadConfig(
-                inner_dim=10, num_classes=3, name="sentence_prediction")
+                inner_dim=10,
+                num_classes=num_classes,
+                name="sentence_prediction")
         ])
-    self._train_data_config = bert.SentencePredictionDataConfig(
-        input_path="dummy", seq_length=128, global_batch_size=1)
 
   def _run_task(self, config):
     task = sentence_prediction.SentencePredictionTask(config)
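Inheriting from parameterized.TestCase lets a single test method run once per parameter tuple, which is how the new test_np_metrics below covers both numpy-based metrics. A tiny self-contained illustration of the mechanism (the class and assertions are illustrative only):

from absl.testing import parameterized
import tensorflow as tf


class DemoTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('matthews_corrcoef', 2),
                            ('pearson_spearman_corr', 1))
  def test_metric_choice(self, metric_type, num_classes):
    # Expanded into two test cases, one per parameter tuple.
    self.assertIn(metric_type,
                  ('matthews_corrcoef', 'pearson_spearman_corr'))
    self.assertGreater(num_classes, 0)


if __name__ == '__main__':
  tf.test.main()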
@@ -57,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
   def test_task(self):
     config = sentence_prediction.SentencePredictionConfig(
         init_checkpoint=self.get_temp_dir(),
-        network=self._network_config,
+        network=self.get_network_config(2),
         train_data=self._train_data_config)
     task = sentence_prediction.SentencePredictionTask(config)
     model = task.build_model()
@@ -84,12 +90,34 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     ckpt.save(config.init_checkpoint)
     task.initialize(model)
 
-  def test_task_with_fit(self):
+  @parameterized.parameters(("matthews_corrcoef", 2),
+                            ("pearson_spearman_corr", 1))
+  def test_np_metrics(self, metric_type, num_classes):
     config = sentence_prediction.SentencePredictionConfig(
-        network=self._network_config,
+        metric_type=metric_type,
+        init_checkpoint=self.get_temp_dir(),
+        network=self.get_network_config(num_classes),
         train_data=self._train_data_config)
     task = sentence_prediction.SentencePredictionTask(config)
     model = task.build_model()
+    dataset = task.build_inputs(config.train_data)
+    iterator = iter(dataset)
+    strategy = tf.distribute.get_strategy()
+    distributed_outputs = strategy.run(
+        functools.partial(task.validation_step, model=model),
+        args=(next(iterator),))
+    outputs = tf.nest.map_structure(strategy.experimental_local_results,
+                                    distributed_outputs)
+    aggregated = task.aggregate_logs(step_outputs=outputs)
+    aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
+    self.assertIn(metric_type, task.reduce_aggregated_logs(aggregated))
+
+  def test_task_with_fit(self):
+    config = sentence_prediction.SentencePredictionConfig(
+        network=self.get_network_config(2),
+        train_data=self._train_data_config)
+    task = sentence_prediction.SentencePredictionTask(config)
+    model = task.build_model()
     model = task.compile_model(
         model,
         optimizer=tf.keras.optimizers.SGD(lr=0.1),
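The test exercises the numpy-metric path under the default (single-replica) distribution strategy: strategy.run executes validation_step, and strategy.experimental_local_results unwraps the per-replica output into a tuple of plain tensors, which is why aggregate_logs iterates over step_outputs['sentence_prediction'] before concatenating. A small standalone illustration of that unwrapping under the default strategy:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default strategy: one replica


def step_fn():
  # Stand-in for task.validation_step: a dict of per-replica tensors.
  return {'sentence_prediction': tf.constant([1, 0, 1])}


per_replica = strategy.run(step_fn)
# experimental_local_results unwraps strategy.run output into a tuple with
# one tensor per local replica (a 1-tuple under the default strategy),
# which is the structure aggregate_logs iterates over and concatenates.
local = tf.nest.map_structure(strategy.experimental_local_results, per_replica)
print(len(local['sentence_prediction']))  # -> 1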
@@ -126,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     hub_module_url = self._export_bert_tfhub()
     config = sentence_prediction.SentencePredictionConfig(
         hub_module_url=hub_module_url,
-        network=self._network_config,
+        network=self.get_network_config(2),
         train_data=self._train_data_config)
     self._run_task(config)