Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
eee5ca5f
Commit
eee5ca5f
authored
Aug 28, 2020
by
A. Unique TensorFlower
Browse files
Fix issue with processing partial batches for CoLa dataset.
PiperOrigin-RevId: 329012198
parent
6d259f7f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
2 deletions
+31
-2
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+2
-1
official/nlp/tasks/sentence_prediction_test.py
official/nlp/tasks/sentence_prediction_test.py
+29
-1
No files found.
official/nlp/tasks/sentence_prediction.py
View file @
eee5ca5f
...
...
@@ -159,7 +159,8 @@ class SentencePredictionTask(base_task.Task):
if
self
.
metric_type
==
'matthews_corrcoef'
:
logs
.
update
({
'sentence_prediction'
:
tf
.
expand_dims
(
tf
.
math
.
argmax
(
outputs
,
axis
=
1
),
axis
=
0
),
# Ensure one prediction along batch dimension.
tf
.
expand_dims
(
tf
.
math
.
argmax
(
outputs
,
axis
=
1
),
axis
=
1
),
'labels'
:
labels
,
})
...
...
official/nlp/tasks/sentence_prediction_test.py
View file @
eee5ca5f
...
...
@@ -86,7 +86,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
return
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
@
parameterized
.
named_parameters
(
(
"init_cls_pooler"
,
True
),
...
...
@@ -182,6 +182,34 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
aggregated
=
task
.
aggregate_logs
(
state
=
aggregated
,
step_outputs
=
outputs
)
self
.
assertIn
(
metric_type
,
task
.
reduce_aggregated_logs
(
aggregated
))
def
test_np_metrics_cola_partial_batch
(
self
):
train_data_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"train.tf_record"
)
num_examples
=
5
global_batch_size
=
8
seq_length
=
16
_create_fake_dataset
(
train_data_path
,
seq_length
=
seq_length
,
num_classes
=
2
,
num_examples
=
num_examples
)
train_data_config
=
(
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(
input_path
=
train_data_path
,
seq_length
=
seq_length
,
is_training
=
True
,
label_type
=
"int"
,
global_batch_size
=
global_batch_size
,
drop_remainder
=
False
,
include_example_id
=
True
))
config
=
sentence_prediction
.
SentencePredictionConfig
(
metric_type
=
"matthews_corrcoef"
,
model
=
self
.
get_model_config
(
2
),
train_data
=
train_data_config
)
outputs
=
self
.
_run_task
(
config
)
self
.
assertEqual
(
outputs
[
"sentence_prediction"
].
shape
.
as_list
(),
[
8
,
1
])
def
test_task_with_fit
(
self
):
config
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
self
.
get_model_config
(
2
),
train_data
=
self
.
_train_data_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment