Commit e6ffa057 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332092620
parent 4ab0b381
@@ -115,18 +115,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
       raise ValueError(
           "TransformerScaffold expects a three-dimensional input of "
           "shape [batch, sequence, width].")
-    batch_size, sequence_length, hidden_size = input_tensor_shape
-    if len(input_shape) == 2:
-      mask_tensor_shape = tf.TensorShape(input_shape[1])
-      expected_mask_tensor_shape = tf.TensorShape(
-          [batch_size, sequence_length, sequence_length])
-      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
-        raise ValueError("When passing a mask tensor to TransformerLayer, the "
-                         "mask tensor must be of shape [batch, "
-                         "sequence_length, sequence_length] (here %s). Got a "
-                         "mask tensor of shape %s." %
-                         (expected_mask_tensor_shape, mask_tensor_shape))
+    hidden_size = input_tensor_shape[-1]
     if hidden_size % self._num_heads != 0:
       raise ValueError(
           "The input size (%d) is not a multiple of the number of attention "
...
@@ -182,30 +182,6 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertNotEmpty(call_list)
     self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")

-  def test_layer_creation_with_incorrect_mask_fails(self):
-    sequence_length = 21
-    width = 80
-
-    call_list = []
-    attention_layer_cfg = {
-        'num_heads': 10,
-        'key_dim': 8,
-        'call_list': call_list,
-    }
-    test_layer = transformer_scaffold.TransformerScaffold(
-        attention_cls=ValidatedAttentionLayer,
-        attention_cfg=attention_layer_cfg,
-        num_attention_heads=10,
-        intermediate_size=2048,
-        intermediate_activation='relu')
-    # Create a 3-dimensional input (the first dimension is implicit).
-    data_tensor = tf.keras.Input(shape=(sequence_length, width))
-    # Create a 2-dimensional input (the first dimension is implicit).
-    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
-    with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
-      _ = test_layer([data_tensor, mask_tensor])
-
   def test_layer_invocation(self):
     sequence_length = 21
     width = 80
...
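
For context, a minimal sketch of what the deleted test exercised, rewritten against the new behavior. It reuses the identifiers from the removed test verbatim; the import path official.nlp.modeling.layers.transformer_scaffold and the availability of the ValidatedAttentionLayer helper (defined elsewhere in the same test file, not shown in this diff) are assumptions.

import tensorflow as tf

from official.nlp.modeling.layers import transformer_scaffold

sequence_length = 21
width = 80

# Constructor arguments mirror the removed test.
test_layer = transformer_scaffold.TransformerScaffold(
    attention_cls=ValidatedAttentionLayer,  # test-file helper, assumed in scope
    attention_cfg={
        'num_heads': 10,
        'key_dim': 8,
        'call_list': [],
    },
    num_attention_heads=10,
    intermediate_size=2048,
    intermediate_activation='relu')

data_tensor = tf.keras.Input(shape=(sequence_length, width))
# Deliberately mis-shaped mask: trailing dim is sequence_length - 3.
mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))

# Before this commit, the scaffold's build() raised ValueError ("When passing
# a mask tensor to TransformerLayer, ...") for this mask. After it, the
# scaffold forwards the mask without checking; any shape error now surfaces
# from the attention layer instead, if at all.
output = test_layer([data_tensor, mask_tensor])

The net effect of the commit: mask shape validation moves out of TransformerScaffold, so callers that relied on the scaffold's early ValueError will instead see whatever error (if any) the attention computation produces.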