Commit 37e76715 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 475997523
parent 60568599
@@ -196,9 +196,9 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         dict_outputs=True,
-        with_dense_inputs=True)
+        with_dense_inputs=True,
+        output_range=output_range)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
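Note: the test hunk above only reorders keyword arguments; output_range is still accepted by the BertEncoderV2 constructor. A minimal smoke-test sketch of the same construction, assuming the usual TF-NLP semantics where output_range truncates the final layer's sequence output to the first N positions (shapes and values below are illustrative, not part of this commit):

import numpy as np
from official.nlp.modeling import networks

# Hypothetical smoke test: output_range remains a constructor argument;
# only its position in the kwarg list changed.
encoder = networks.BertEncoderV2(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=2,
    num_layers=3,
    type_vocab_size=16,
    output_range=1)

data = dict(
    input_word_ids=np.random.randint(100, size=(2, 8)),
    input_mask=np.ones((2, 8), dtype=np.int32),
    input_type_ids=np.zeros((2, 8), dtype=np.int32))
outputs = encoder(data)
# With output_range=1, only the first token's representation is produced
# by the last transformer block (assumed semantics): (batch, 1, hidden).
print(outputs['sequence_output'].shape)  # expected (2, 1, 32)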
...
@@ -116,6 +116,8 @@ class BertEncoderV2(tf.keras.layers.Layer):
       attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
+    self._output_range = output_range
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -163,6 +165,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
     self._transformer_layers = []
     self._attention_mask_layer = layers.SelfAttentionMask(
         name='self_attention_mask')
+    self._num_layers = num_layers
     for i in range(num_layers):
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
@@ -172,7 +175,6 @@ class BertEncoderV2(tf.keras.layers.Layer):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=output_range if i == num_layers - 1 else None,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
@@ -257,8 +259,11 @@ class BertEncoderV2(tf.keras.layers.Layer):
     encoder_outputs = []
     attention_outputs = []
     x = embeddings
-    for layer in self._transformer_layers:
-      x = layer([x, attention_mask])
+    for i, layer in enumerate(self._transformer_layers):
+      transformer_output_range = None
+      if i == self._num_layers - 1:
+        transformer_output_range = self._output_range
+      x = layer([x, attention_mask], output_range=transformer_output_range)
       if self._config['return_attention_scores']:
         x, attention_scores = x
         attention_outputs.append(attention_scores)
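Note: output_range is forwarded only to the last block because earlier blocks must still emit the full sequence (later attention layers need every key/value position), while the final block can truncate just its query. A minimal sketch of that idea, assuming TransformerEncoderBlock slices the first output_range tokens as the query; the helper name and slicing here are illustrative, not the library's internals:

import tensorflow as tf

def final_block_attention(x, attention_mask, mha, output_range=None):
  # Illustrative only: truncate the *query* to the first output_range
  # tokens; keys/values still cover the whole sequence, so the kept
  # queries can attend to every position.
  if output_range is not None:
    query = x[:, :output_range, :]                        # (batch, R, hidden)
    attention_mask = attention_mask[:, :output_range, :]  # (batch, R, seq)
  else:
    query = x
  return mha(query=query, value=x, attention_mask=attention_mask)

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
x = tf.random.normal([2, 8, 32])
mask = tf.ones([2, 8, 8])
print(final_block_attention(x, mask, mha, output_range=1).shape)  # (2, 1, 32)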
@@ -475,10 +480,9 @@ class BertEncoder(tf.keras.Model):
     encoder_outputs = []
     attention_outputs = []
     for i in range(num_layers):
-      if i == num_layers - 1 and output_range is not None:
+      transformer_output_range = None
+      if i == num_layers - 1:
         transformer_output_range = output_range
-      else:
-        transformer_output_range = None
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
           inner_dim=inner_dim,
@@ -487,11 +491,11 @@ class BertEncoder(tf.keras.Model):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=transformer_output_range,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       transformer_layers.append(layer)
-      data = layer([data, attention_mask])
+      data = layer([data, attention_mask],
+                   output_range=transformer_output_range)
       if return_attention_scores:
         data, attention_scores = data
         attention_outputs.append(attention_scores)
@@ -600,3 +604,4 @@ class BertEncoder(tf.keras.Model):
       logging.warn(warn_string)
     return cls(**config)
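Note: since None is now passed explicitly at call time for every layer except the last, the old output_range is not None guard becomes redundant, and the constructor-level argument still lives in the serialized config. A hedged round-trip sketch, assuming output_range remains a key in get_config() (argument values are illustrative):

from official.nlp.modeling import networks

# Hypothetical config round-trip: where output_range is *applied* changed
# (call time instead of construction time), but serialization should not.
encoder = networks.BertEncoder(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=2,
    num_layers=3,
    output_range=1)
restored = networks.BertEncoder.from_config(encoder.get_config())
assert restored.get_config()['output_range'] == 1  # assumed config key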
...@@ -545,8 +545,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase): ...@@ -545,8 +545,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=2, num_attention_heads=2,
num_layers=3, num_layers=3,
type_vocab_size=num_types, type_vocab_size=num_types)
output_range=None)
word_id_data = np.random.randint( word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length)) vocab_size, size=(batch_size, sequence_length))
@@ -605,8 +604,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types,
-        output_range=None)
+        type_vocab_size=num_types)
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
...