Commit 2f9f2479 authored by George Karpenkov's avatar George Karpenkov Committed by A. Unique TensorFlower
Browse files

Test the Transformer layer with dynamic sequence

PiperOrigin-RevId: 297412889
parent ce83a9db
...@@ -166,6 +166,24 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -166,6 +166,24 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# The default output of a transformer layer should be the same as the input. # The default output of a transformer layer should be the same as the input.
self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list()) self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())
def test_dynamic_layer_sequence(self):
test_layer = transformer.Transformer(
num_attention_heads=10,
intermediate_size=2048,
intermediate_activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output_tensor)
input_length = 17
input_data = np.ones((1, input_length, width))
output_data = model.predict(input_data)
self.assertAllEqual([1, input_length, width], output_data.shape)
if __name__ == '__main__': if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.') assert tf.version.VERSION.startswith('2.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment