Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
40013f67
Commit
40013f67
authored
Aug 11, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 390235315
parent
d4bc6160
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
3 deletions
+14
-3
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+14
-3
No files found.
official/nlp/tasks/sentence_prediction.py
View file @
40013f67
...
...
@@ -13,10 +13,10 @@
# limitations under the License.
"""Sentence prediction (classification) task."""
import
dataclasses
from
typing
import
List
,
Union
,
Optional
from
absl
import
logging
import
dataclasses
import
numpy
as
np
import
orbit
from
scipy
import
stats
...
...
@@ -140,15 +140,26 @@ class SentencePredictionTask(base_task.Task):
del
training
if
self
.
task_config
.
model
.
num_classes
==
1
:
metrics
=
[
tf
.
keras
.
metrics
.
MeanSquaredError
()]
elif
self
.
task_config
.
model
.
num_classes
==
2
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'cls_accuracy'
),
tf
.
keras
.
metrics
.
AUC
(
name
=
'auc'
,
curve
=
'PR'
),
]
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'cls_accuracy'
)
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'cls_accuracy'
)
,
]
return
metrics
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
for
metric
in
metrics
:
metric
.
update_state
(
labels
[
self
.
label_field
],
model_outputs
)
if
metric
.
name
==
'auc'
:
# Convert the logit to probability and extract the probability of True..
metric
.
update_state
(
labels
[
self
.
label_field
],
tf
.
expand_dims
(
tf
.
nn
.
softmax
(
model_outputs
)[:,
1
],
axis
=
1
))
if
metric
.
name
==
'cls_accuracy'
:
metric
.
update_state
(
labels
[
self
.
label_field
],
model_outputs
)
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_outputs
):
compiled_metrics
.
update_state
(
labels
[
self
.
label_field
],
model_outputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment