Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1357ce19
Commit
1357ce19
authored
Jun 22, 2020
by
Jeremiah Harmsen
Committed by
A. Unique TensorFlower
Jun 22, 2020
Browse files
Internal change
PiperOrigin-RevId: 317638173
parent
8aa44501
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
12 deletions
+16
-12
official/nlp/bert/bert_models.py
official/nlp/bert/bert_models.py
+16
-12
No files found.
official/nlp/bert/bert_models.py
View file @
1357ce19
...
@@ -25,7 +25,6 @@ import tensorflow_hub as hub
...
@@ -25,7 +25,6 @@ import tensorflow_hub as hub
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.nlp.albert
import
configs
as
albert_configs
from
official.nlp.albert
import
configs
as
albert_configs
from
official.nlp.bert
import
configs
from
official.nlp.bert
import
configs
from
official.nlp.modeling
import
losses
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
networks
from
official.nlp.modeling
import
networks
...
@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
...
@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
next_sentence_loss
,
name
=
'next_sentence_loss'
,
aggregation
=
'mean'
)
next_sentence_loss
,
name
=
'next_sentence_loss'
,
aggregation
=
'mean'
)
def
call
(
self
,
def
call
(
self
,
lm_output
,
lm_output
_logits
,
sentence_output
,
sentence_output
_logits
,
lm_label_ids
,
lm_label_ids
,
lm_label_weights
,
lm_label_weights
,
sentence_labels
=
None
):
sentence_labels
=
None
):
"""Implements call() for the layer."""
"""Implements call() for the layer."""
lm_label_weights
=
tf
.
cast
(
lm_label_weights
,
tf
.
float32
)
lm_label_weights
=
tf
.
cast
(
lm_label_weights
,
tf
.
float32
)
lm_output
=
tf
.
cast
(
lm_output
,
tf
.
float32
)
lm_output
_logits
=
tf
.
cast
(
lm_output
_logits
,
tf
.
float32
)
mask_label_loss
=
losses
.
weighted_sparse_categorical_crossentropy_loss
(
lm_prediction_losses
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
=
lm_label_ids
,
predictions
=
lm_output
,
weights
=
lm_label_weights
)
lm_label_ids
,
lm_output_logits
,
from_logits
=
True
)
lm_numerator_loss
=
tf
.
reduce_sum
(
lm_prediction_losses
*
lm_label_weights
)
lm_denominator_loss
=
tf
.
reduce_sum
(
lm_label_weights
)
mask_label_loss
=
tf
.
math
.
divide_no_nan
(
lm_numerator_loss
,
lm_denominator_loss
)
if
sentence_labels
is
not
None
:
if
sentence_labels
is
not
None
:
sentence_output
=
tf
.
cast
(
sentence_output
,
tf
.
float32
)
sentence_output_logits
=
tf
.
cast
(
sentence_output_logits
,
tf
.
float32
)
sentence_loss
=
losses
.
weighted_sparse_categorical_crossentropy_loss
(
sentence_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
=
sentence_labels
,
predictions
=
sentence_output
)
sentence_labels
,
sentence_output_logits
,
from_logits
=
True
)
sentence_loss
=
tf
.
reduce_mean
(
sentence_loss
)
loss
=
mask_label_loss
+
sentence_loss
loss
=
mask_label_loss
+
sentence_loss
else
:
else
:
sentence_loss
=
None
sentence_loss
=
None
...
@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
...
@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
# TODO(hongkuny): Avoids the hack and switches add_loss.
# TODO(hongkuny): Avoids the hack and switches add_loss.
final_loss
=
tf
.
fill
(
batch_shape
,
loss
)
final_loss
=
tf
.
fill
(
batch_shape
,
loss
)
self
.
_add_metrics
(
lm_output
,
lm_label_ids
,
lm_label_weights
,
self
.
_add_metrics
(
lm_output
_logits
,
lm_label_ids
,
lm_label_weights
,
mask_label_loss
,
sentence_output
,
sentence_labels
,
mask_label_loss
,
sentence_output
_logits
,
sentence_labels
,
sentence_loss
)
sentence_loss
)
return
final_loss
return
final_loss
...
@@ -228,7 +232,7 @@ def pretrain_model(bert_config,
...
@@ -228,7 +232,7 @@ def pretrain_model(bert_config,
activation
=
tf_utils
.
get_activation
(
bert_config
.
hidden_act
),
activation
=
tf_utils
.
get_activation
(
bert_config
.
hidden_act
),
num_token_predictions
=
max_predictions_per_seq
,
num_token_predictions
=
max_predictions_per_seq
,
initializer
=
initializer
,
initializer
=
initializer
,
output
=
'
prediction
s'
)
output
=
'
logit
s'
)
outputs
=
pretrainer_model
(
outputs
=
pretrainer_model
(
[
input_word_ids
,
input_mask
,
input_type_ids
,
masked_lm_positions
])
[
input_word_ids
,
input_mask
,
input_type_ids
,
masked_lm_positions
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment