Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d5c5cb41
Commit
d5c5cb41
authored
Jul 13, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 321100219
parent
ff27fb50
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
9 deletions
+75
-9
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+6
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+33
-8
official/nlp/tasks/sentence_prediction_test.py
official/nlp/tasks/sentence_prediction_test.py
+36
-0
No files found.
official/nlp/data/sentence_prediction_dataloader.py
View file @
d5c5cb41
...
...
@@ -23,6 +23,9 @@ from official.modeling.hyperparams import config_definitions as cfg
from
official.nlp.data
import
data_loader_factory
LABEL_TYPES_MAP
=
{
'int'
:
tf
.
int64
,
'float'
:
tf
.
float32
}
@
dataclasses
.
dataclass
class
SentencePredictionDataConfig
(
cfg
.
DataConfig
):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
...
...
@@ -30,6 +33,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
global_batch_size
:
int
=
32
is_training
:
bool
=
True
seq_length
:
int
=
128
label_type
:
str
=
'int'
@
data_loader_factory
.
register_data_loader_cls
(
SentencePredictionDataConfig
)
...
...
@@ -42,11 +46,12 @@ class SentencePredictionDataLoader:
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
label_type
=
LABEL_TYPES_MAP
[
self
.
_params
.
label_type
]
name_to_features
=
{
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'label_ids'
:
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
),
'label_ids'
:
tf
.
io
.
FixedLenFeature
([],
label_type
),
}
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
...
...
official/nlp/tasks/sentence_prediction.py
View file @
d5c5cb41
...
...
@@ -31,6 +31,10 @@ from official.nlp.modeling import models
from
official.nlp.tasks
import
utils
METRIC_TYPES
=
frozenset
(
[
'accuracy'
,
'matthews_corrcoef'
,
'pearson_spearman_corr'
])
@
dataclasses
.
dataclass
class
ModelConfig
(
base_config
.
Config
):
"""A classifier/regressor configuration."""
...
...
@@ -68,6 +72,9 @@ class SentencePredictionTask(base_task.Task):
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
else
:
self
.
_hub_module
=
None
if
params
.
metric_type
not
in
METRIC_TYPES
:
raise
ValueError
(
'Invalid metric_type: {}'
.
format
(
params
.
metric_type
))
self
.
metric_type
=
params
.
metric_type
def
build_model
(
self
):
...
...
@@ -77,7 +84,7 @@ class SentencePredictionTask(base_task.Task):
encoder_network
=
encoders
.
instantiate_encoder_from_cfg
(
self
.
task_config
.
model
.
encoder
)
# Currently, we only support
s
bert-style sentence prediction finetuning.
# Currently, we only support bert-style sentence prediction finetuning.
return
models
.
BertClassifier
(
network
=
encoder_network
,
num_classes
=
self
.
task_config
.
model
.
num_classes
,
...
...
@@ -86,8 +93,11 @@ class SentencePredictionTask(base_task.Task):
use_encoder_pooler
=
self
.
task_config
.
model
.
use_encoder_pooler
)
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
tf
.
cast
(
model_outputs
,
tf
.
float32
),
from_logits
=
True
)
if
self
.
task_config
.
model
.
num_classes
==
1
:
loss
=
tf
.
keras
.
losses
.
mean_squared_error
(
labels
,
model_outputs
)
else
:
loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
tf
.
cast
(
model_outputs
,
tf
.
float32
),
from_logits
=
True
)
if
aux_losses
:
loss
+=
tf
.
add_n
(
aux_losses
)
...
...
@@ -103,8 +113,12 @@ class SentencePredictionTask(base_task.Task):
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
y
=
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
int32
)
return
(
x
,
y
)
if
self
.
task_config
.
model
.
num_classes
==
1
:
y
=
tf
.
zeros
((
1
,),
dtype
=
tf
.
float32
)
else
:
y
=
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
int32
)
return
x
,
y
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
...
...
@@ -116,7 +130,11 @@ class SentencePredictionTask(base_task.Task):
def
build_metrics
(
self
,
training
=
None
):
del
training
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'cls_accuracy'
)]
if
self
.
task_config
.
model
.
num_classes
==
1
:
metrics
=
[
tf
.
keras
.
metrics
.
MeanSquaredError
()]
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'cls_accuracy'
)]
return
metrics
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
...
...
@@ -154,6 +172,7 @@ class SentencePredictionTask(base_task.Task):
return
None
if
state
is
None
:
state
=
{
'sentence_prediction'
:
[],
'labels'
:
[]}
# TODO(b/160712818): Add support for concatenating partial batches.
state
[
'sentence_prediction'
].
append
(
np
.
concatenate
([
v
.
numpy
()
for
v
in
step_outputs
[
'sentence_prediction'
]],
axis
=
0
))
...
...
@@ -162,15 +181,21 @@ class SentencePredictionTask(base_task.Task):
return
state
def
reduce_aggregated_logs
(
self
,
aggregated_logs
):
if
self
.
metric_type
==
'matthews_corrcoef'
:
if
self
.
metric_type
==
'accuracy'
:
return
None
elif
self
.
metric_type
==
'matthews_corrcoef'
:
preds
=
np
.
concatenate
(
aggregated_logs
[
'sentence_prediction'
],
axis
=
0
)
preds
=
np
.
reshape
(
preds
,
-
1
)
labels
=
np
.
concatenate
(
aggregated_logs
[
'labels'
],
axis
=
0
)
labels
=
np
.
reshape
(
labels
,
-
1
)
return
{
self
.
metric_type
:
sklearn_metrics
.
matthews_corrcoef
(
preds
,
labels
)
}
if
self
.
metric_type
==
'pearson_spearman_corr'
:
el
if
self
.
metric_type
==
'pearson_spearman_corr'
:
preds
=
np
.
concatenate
(
aggregated_logs
[
'sentence_prediction'
],
axis
=
0
)
preds
=
np
.
reshape
(
preds
,
-
1
)
labels
=
np
.
concatenate
(
aggregated_logs
[
'labels'
],
axis
=
0
)
labels
=
np
.
reshape
(
labels
,
-
1
)
pearson_corr
=
stats
.
pearsonr
(
preds
,
labels
)[
0
]
spearman_corr
=
stats
.
spearmanr
(
preds
,
labels
)[
0
]
corr_metric
=
(
pearson_corr
+
spearman_corr
)
/
2
...
...
official/nlp/tasks/sentence_prediction_test.py
View file @
d5c5cb41
...
...
@@ -85,6 +85,42 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
ckpt
.
save
(
config
.
init_checkpoint
)
task
.
initialize
(
model
)
@
parameterized
.
named_parameters
(
{
"testcase_name"
:
"regression"
,
"num_classes"
:
1
,
},
{
"testcase_name"
:
"classification"
,
"num_classes"
:
2
,
},
)
def
test_metrics_and_losses
(
self
,
num_classes
):
config
=
sentence_prediction
.
SentencePredictionConfig
(
init_checkpoint
=
self
.
get_temp_dir
(),
model
=
self
.
get_model_config
(
num_classes
),
train_data
=
self
.
_train_data_config
)
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
if
num_classes
==
1
:
self
.
assertIsInstance
(
metrics
[
0
],
tf
.
keras
.
metrics
.
MeanSquaredError
)
else
:
self
.
assertIsInstance
(
metrics
[
0
],
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
)
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
logs
=
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
loss
=
logs
[
"loss"
].
numpy
()
if
num_classes
==
1
:
self
.
assertAlmostEqual
(
loss
,
42.77483
,
places
=
3
)
else
:
self
.
assertAlmostEqual
(
loss
,
3.57627e-6
,
places
=
3
)
@
parameterized
.
parameters
((
"matthews_corrcoef"
,
2
),
(
"pearson_spearman_corr"
,
1
))
def
test_np_metrics
(
self
,
metric_type
,
num_classes
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment