"...resnet50_tensorflow.git" did not exist on "f962ce389f73267b5b7549601a139edf04fd1b1e"
Commit 02f8d387 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328989927
parent bc4ccd2f
...@@ -67,15 +67,24 @@ class DualEncoder(tf.keras.Model): ...@@ -67,15 +67,24 @@ class DualEncoder(tf.keras.Model):
self.network = network self.network = network
left_word_ids = tf.keras.layers.Input( if output == 'logits':
shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids') left_word_ids = tf.keras.layers.Input(
left_mask = tf.keras.layers.Input( shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
shape=(max_seq_length,), dtype=tf.int32, name='left_mask') left_mask = tf.keras.layers.Input(
left_type_ids = tf.keras.layers.Input( shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids') left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
else:
        # Keep consistent with legacy BERT hub module input names.
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
left_inputs = [left_word_ids, left_mask, left_type_ids] left_inputs = [left_word_ids, left_mask, left_type_ids]
_, left_encoded = network(left_inputs) left_sequence_output, left_encoded = network(left_inputs)
if normalize: if normalize:
left_encoded = tf.keras.layers.Lambda( left_encoded = tf.keras.layers.Lambda(
...@@ -108,13 +117,19 @@ class DualEncoder(tf.keras.Model): ...@@ -108,13 +117,19 @@ class DualEncoder(tf.keras.Model):
elif output == 'predictions': elif output == 'predictions':
inputs = [left_word_ids, left_mask, left_type_ids] inputs = [left_word_ids, left_mask, left_type_ids]
outputs = left_encoded
# To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output".
outputs = [left_encoded, left_sequence_output]
else: else:
raise ValueError('output type %s is not supported' % output) raise ValueError('output type %s is not supported' % output)
super(DualEncoder, self).__init__( super(DualEncoder, self).__init__(
inputs=inputs, outputs=outputs, **kwargs) inputs=inputs, outputs=outputs, **kwargs)
# Set _self_setattr_tracking to True so it can be exported with assets.
self._self_setattr_tracking = True
def get_config(self): def get_config(self):
return self._config return self._config
......
...@@ -64,13 +64,17 @@ class DualEncoderTest(keras_parameterized.TestCase): ...@@ -64,13 +64,17 @@ class DualEncoderTest(keras_parameterized.TestCase):
left_encoded, _ = outputs left_encoded, _ = outputs
elif output == 'predictions': elif output == 'predictions':
left_encoded = dual_encoder_model([ left_encoded, left_sequence_output = dual_encoder_model([
left_word_ids, left_mask, left_type_ids]) left_word_ids, left_mask, left_type_ids])
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_encoding_shape = [None, 768] expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list()) self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
expected_sequence_shape = [None, sequence_length, 768]
self.assertAllEqual(expected_sequence_shape,
left_sequence_output.shape.as_list())
@parameterized.parameters((192, 'logits'), (768, 'predictions')) @parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder_tensor_call(self, hidden_size, output): def test_dual_encoder_tensor_call(self, hidden_size, output):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment