".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "428710c2ec057058a482e7723b0b46db3052660a"
Commit 29aff36a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328989927
parent a94e2737
......@@ -67,15 +67,24 @@ class DualEncoder(tf.keras.Model):
self.network = network
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
if output == 'logits':
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
else:
# Keep the consistant with legacy BERT hub module input names.
left_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
left_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
left_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
left_inputs = [left_word_ids, left_mask, left_type_ids]
_, left_encoded = network(left_inputs)
left_sequence_output, left_encoded = network(left_inputs)
if normalize:
left_encoded = tf.keras.layers.Lambda(
......@@ -108,13 +117,19 @@ class DualEncoder(tf.keras.Model):
elif output == 'predictions':
inputs = [left_word_ids, left_mask, left_type_ids]
outputs = left_encoded
# To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output".
outputs = [left_encoded, left_sequence_output]
else:
raise ValueError('output type %s is not supported' % output)
super(DualEncoder, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
# Set _self_setattr_tracking to True so it can be exported with assets.
self._self_setattr_tracking = True
def get_config(self):
return self._config
......
......@@ -64,13 +64,17 @@ class DualEncoderTest(keras_parameterized.TestCase):
left_encoded, _ = outputs
elif output == 'predictions':
left_encoded = dual_encoder_model([
left_encoded, left_sequence_output = dual_encoder_model([
left_word_ids, left_mask, left_type_ids])
# Validate that the outputs are of the expected shape.
expected_encoding_shape = [None, 768]
self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())
expected_sequence_shape = [None, sequence_length, 768]
self.assertAllEqual(expected_sequence_shape,
left_sequence_output.shape.as_list())
@parameterized.parameters((192, 'logits'), (768, 'predictions'))
def test_dual_encoder_tensor_call(self, hidden_size, output):
"""Validate that the Keras object can be invoked."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment