Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
57c08e2f
Commit
57c08e2f
authored
Jun 15, 2020
by
A. Unique TensorFlower
Browse files
Make function argument names consistent in core.base_task.Task
PiperOrigin-RevId: 316513485
parent
ee3cc115
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
29 deletions
+28
-29
official/core/base_task.py
official/core/base_task.py
+11
-11
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+12
-12
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+5
-6
No files found.
official/core/base_task.py
View file @
57c08e2f
...
@@ -114,18 +114,18 @@ class Task(tf.Module):
...
@@ -114,18 +114,18 @@ class Task(tf.Module):
"""
"""
pass
pass
def
build_losses
(
self
,
feature
s
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
def
build_losses
(
self
,
label
s
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
"""Standard interface to compute losses.
"""Standard interface to compute losses.
Args:
Args:
feature
s: optional
feature/
label
s
tensors.
label
s: optional label tensors.
model_outputs: a nested structure of output tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
Returns:
The total loss tensor.
The total loss tensor.
"""
"""
del
model_outputs
,
feature
s
del
model_outputs
,
label
s
if
aux_losses
is
None
:
if
aux_losses
is
None
:
losses
=
[
tf
.
constant
(
0.0
,
dtype
=
tf
.
float32
)]
losses
=
[
tf
.
constant
(
0.0
,
dtype
=
tf
.
float32
)]
...
@@ -139,29 +139,29 @@ class Task(tf.Module):
...
@@ -139,29 +139,29 @@ class Task(tf.Module):
del
training
del
training
return
[]
return
[]
def
process_metrics
(
self
,
metrics
,
labels
,
outputs
):
def
process_metrics
(
self
,
metrics
,
labels
,
model_
outputs
):
"""Process and update metrics. Called when using custom training loop API.
"""Process and update metrics. Called when using custom training loop API.
Args:
Args:
metrics: a nested structure of metrics objects.
metrics: a nested structure of metrics objects.
The return of function self.build_metrics.
The return of function self.build_metrics.
labels: a tensor or a nested structure of tensors.
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
model_
outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
For example, output of the keras model built by self.build_model.
"""
"""
for
metric
in
metrics
:
for
metric
in
metrics
:
metric
.
update_state
(
labels
,
outputs
)
metric
.
update_state
(
labels
,
model_
outputs
)
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
outputs
):
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_
outputs
):
"""Process and update compiled_metrics. call when using compile/fit API.
"""Process and update compiled_metrics. call when using compile/fit API.
Args:
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
model_
outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
For example, output of the keras model built by self.build_model.
"""
"""
compiled_metrics
.
update_state
(
labels
,
outputs
)
compiled_metrics
.
update_state
(
labels
,
model_
outputs
)
def
train_step
(
self
,
def
train_step
(
self
,
inputs
,
inputs
,
...
@@ -187,7 +187,7 @@ class Task(tf.Module):
...
@@ -187,7 +187,7 @@ class Task(tf.Module):
outputs
=
model
(
features
,
training
=
True
)
outputs
=
model
(
features
,
training
=
True
)
# Computes per-replica loss.
# Computes per-replica loss.
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
feature
s
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
label
s
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# optimizer.
scaled_loss
=
loss
/
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
scaled_loss
=
loss
/
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
...
@@ -231,7 +231,7 @@ class Task(tf.Module):
...
@@ -231,7 +231,7 @@ class Task(tf.Module):
features
,
labels
=
inputs
,
inputs
features
,
labels
=
inputs
,
inputs
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
self
.
inference_step
(
features
,
model
)
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
feature
s
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
label
s
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
...
...
official/nlp/tasks/masked_lm.py
View file @
57c08e2f
...
@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
...
@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
return
bert
.
instantiate_from_cfg
(
self
.
task_config
.
network
)
return
bert
.
instantiate_from_cfg
(
self
.
task_config
.
network
)
def
build_losses
(
self
,
def
build_losses
(
self
,
feature
s
,
label
s
,
model_outputs
,
model_outputs
,
metrics
,
metrics
,
aux_losses
=
None
)
->
tf
.
Tensor
:
aux_losses
=
None
)
->
tf
.
Tensor
:
metrics
=
dict
([(
metric
.
name
,
metric
)
for
metric
in
metrics
])
metrics
=
dict
([(
metric
.
name
,
metric
)
for
metric
in
metrics
])
lm_output
=
tf
.
nn
.
log_softmax
(
model_outputs
[
'lm_output'
],
axis
=-
1
)
lm_output
=
tf
.
nn
.
log_softmax
(
model_outputs
[
'lm_output'
],
axis
=-
1
)
mlm_loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
mlm_loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
labels
=
feature
s
[
'masked_lm_ids'
],
labels
=
label
s
[
'masked_lm_ids'
],
predictions
=
lm_output
,
predictions
=
lm_output
,
weights
=
feature
s
[
'masked_lm_weights'
])
weights
=
label
s
[
'masked_lm_weights'
])
metrics
[
'lm_example_loss'
].
update_state
(
mlm_loss
)
metrics
[
'lm_example_loss'
].
update_state
(
mlm_loss
)
if
'next_sentence_labels'
in
feature
s
:
if
'next_sentence_labels'
in
label
s
:
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
if
policy
.
name
==
'mixed_bfloat16'
:
# b/158514794: bf16 is not stable.
if
policy
.
name
==
'mixed_bfloat16'
:
# b/158514794: bf16 is not stable.
policy
=
tf
.
float32
policy
=
tf
.
float32
predictions
=
tf
.
keras
.
layers
.
Activation
(
predictions
=
tf
.
keras
.
layers
.
Activation
(
tf
.
nn
.
log_softmax
,
dtype
=
policy
)(
model_outputs
[
'next_sentence'
])
tf
.
nn
.
log_softmax
,
dtype
=
policy
)(
model_outputs
[
'next_sentence'
])
sentence_labels
=
feature
s
[
'next_sentence_labels'
]
sentence_labels
=
label
s
[
'next_sentence_labels'
]
sentence_loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
sentence_loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
labels
=
sentence_labels
,
labels
=
sentence_labels
,
predictions
=
predictions
)
predictions
=
predictions
)
...
@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
...
@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
metrics
.
append
(
tf
.
keras
.
metrics
.
Mean
(
name
=
'next_sentence_loss'
))
metrics
.
append
(
tf
.
keras
.
metrics
.
Mean
(
name
=
'next_sentence_loss'
))
return
metrics
return
metrics
def
process_metrics
(
self
,
metrics
,
inputs
,
outputs
):
def
process_metrics
(
self
,
metrics
,
labels
,
model_
outputs
):
metrics
=
dict
([(
metric
.
name
,
metric
)
for
metric
in
metrics
])
metrics
=
dict
([(
metric
.
name
,
metric
)
for
metric
in
metrics
])
if
'masked_lm_accuracy'
in
metrics
:
if
'masked_lm_accuracy'
in
metrics
:
metrics
[
'masked_lm_accuracy'
].
update_state
(
input
s
[
'masked_lm_ids'
],
metrics
[
'masked_lm_accuracy'
].
update_state
(
label
s
[
'masked_lm_ids'
],
outputs
[
'lm_output'
],
model_
outputs
[
'lm_output'
],
input
s
[
'masked_lm_weights'
])
label
s
[
'masked_lm_weights'
])
if
'next_sentence_accuracy'
in
metrics
:
if
'next_sentence_accuracy'
in
metrics
:
metrics
[
'next_sentence_accuracy'
].
update_state
(
metrics
[
'next_sentence_accuracy'
].
update_state
(
input
s
[
'next_sentence_labels'
],
outputs
[
'next_sentence'
])
label
s
[
'next_sentence_labels'
],
model_
outputs
[
'next_sentence'
])
def
train_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
def
train_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
metrics
):
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
metrics
):
...
@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
...
@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
outputs
=
model
(
inputs
,
training
=
True
)
outputs
=
model
(
inputs
,
training
=
True
)
# Computes per-replica loss.
# Computes per-replica loss.
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
feature
s
=
inputs
,
label
s
=
inputs
,
model_outputs
=
outputs
,
model_outputs
=
outputs
,
metrics
=
metrics
,
metrics
=
metrics
,
aux_losses
=
model
.
losses
)
aux_losses
=
model
.
losses
)
...
@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
...
@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
"""
"""
outputs
=
self
.
inference_step
(
inputs
,
model
)
outputs
=
self
.
inference_step
(
inputs
,
model
)
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
feature
s
=
inputs
,
label
s
=
inputs
,
model_outputs
=
outputs
,
model_outputs
=
outputs
,
metrics
=
metrics
,
metrics
=
metrics
,
aux_losses
=
model
.
losses
)
aux_losses
=
model
.
losses
)
...
...
official/nlp/tasks/sentence_prediction.py
View file @
57c08e2f
...
@@ -79,8 +79,7 @@ class SentencePredictionTask(base_task.Task):
...
@@ -79,8 +79,7 @@ class SentencePredictionTask(base_task.Task):
else
:
else
:
return
bert
.
instantiate_from_cfg
(
self
.
task_config
.
network
)
return
bert
.
instantiate_from_cfg
(
self
.
task_config
.
network
)
def
build_losses
(
self
,
features
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
labels
=
features
loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
loss
=
loss_lib
.
weighted_sparse_categorical_crossentropy_loss
(
labels
=
labels
,
labels
=
labels
,
predictions
=
tf
.
nn
.
log_softmax
(
model_outputs
[
'sentence_prediction'
],
predictions
=
tf
.
nn
.
log_softmax
(
model_outputs
[
'sentence_prediction'
],
...
@@ -118,12 +117,12 @@ class SentencePredictionTask(base_task.Task):
...
@@ -118,12 +117,12 @@ class SentencePredictionTask(base_task.Task):
]
]
return
metrics
return
metrics
def
process_metrics
(
self
,
metrics
,
labels
,
outputs
):
def
process_metrics
(
self
,
metrics
,
labels
,
model_
outputs
):
for
metric
in
metrics
:
for
metric
in
metrics
:
metric
.
update_state
(
labels
,
outputs
[
'sentence_prediction'
])
metric
.
update_state
(
labels
,
model_
outputs
[
'sentence_prediction'
])
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
outputs
):
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_
outputs
):
compiled_metrics
.
update_state
(
labels
,
outputs
[
'sentence_prediction'
])
compiled_metrics
.
update_state
(
labels
,
model_
outputs
[
'sentence_prediction'
])
def
initialize
(
self
,
model
):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment