"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "c895a7d13f74c66aee4c68aed75aaeddb7fbcf18"
Commit c813d85f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove pack_inputs/unpack_inputs in bert

PiperOrigin-RevId: 291082425
parent df15a276
...@@ -72,20 +72,6 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -72,20 +72,6 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
'vocab_size': vocab_size, 'vocab_size': vocab_size,
} }
def __call__(self,
lm_output,
sentence_output=None,
lm_label_ids=None,
lm_label_weights=None,
sentence_labels=None,
**kwargs):
inputs = tf_utils.pack_inputs([
lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels
])
return super(BertPretrainLossAndMetricLayer,
self).__call__(inputs, **kwargs)
def _add_metrics(self, lm_output, lm_labels, lm_label_weights, def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
lm_example_loss, sentence_output, sentence_labels, lm_example_loss, sentence_output, sentence_labels,
next_sentence_loss): next_sentence_loss):
...@@ -110,14 +96,10 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -110,14 +96,10 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
self.add_metric( self.add_metric(
next_sentence_loss, name='next_sentence_loss', aggregation='mean') next_sentence_loss, name='next_sentence_loss', aggregation='mean')
def call(self, inputs): def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels):
"""Implements call() for the layer.""" """Implements call() for the layer."""
unpacked_inputs = tf_utils.unpack_inputs(inputs) lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)
lm_output = unpacked_inputs[0]
sentence_output = unpacked_inputs[1]
lm_label_ids = unpacked_inputs[2]
lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
sentence_labels = unpacked_inputs[4]
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss( mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights) labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment