Commit 37e76715 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 475997523
parent 60568599
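
The diff below moves `output_range` from a construction-time argument of each `TransformerEncoderBlock` to a call-time argument, so the same block instances can serve both full-sequence and truncated-output calls. A minimal before/after sketch of the pattern, assuming the `official.nlp.modeling` API at this revision (the block construction arguments below are illustrative, not copied from the diff):

import tensorflow as tf
from official.nlp.modeling import layers

block = layers.TransformerEncoderBlock(
    num_attention_heads=2, inner_dim=32, inner_activation='relu')
x = tf.zeros([2, 8, 16])    # (batch, seq_len, width)
mask = tf.ones([2, 8, 8])   # (batch, from_seq, to_seq) attention mask

# Old style: output_range was fixed in the constructor of the last block.
# New style: it is decided per call, as the updated call() loop now does.
full = block([x, mask])                       # shape (2, 8, 16)
truncated = block([x, mask], output_range=1)  # shape (2, 1, 16), e.g. [CLS] only
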
@@ -196,9 +196,9 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         dict_outputs=True,
-        with_dense_inputs=True)
+        with_dense_inputs=True,
+        output_range=output_range)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -116,6 +116,8 @@ class BertEncoderV2(tf.keras.layers.Layer):
       attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
+    self._output_range = output_range
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -163,6 +165,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
     self._transformer_layers = []
     self._attention_mask_layer = layers.SelfAttentionMask(
         name='self_attention_mask')
+    self._num_layers = num_layers
     for i in range(num_layers):
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
@@ -172,7 +175,6 @@ class BertEncoderV2(tf.keras.layers.Layer):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=output_range if i == num_layers - 1 else None,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
@@ -257,8 +259,11 @@ class BertEncoderV2(tf.keras.layers.Layer):
     encoder_outputs = []
     attention_outputs = []
     x = embeddings
-    for layer in self._transformer_layers:
-      x = layer([x, attention_mask])
+    for i, layer in enumerate(self._transformer_layers):
+      transformer_output_range = None
+      if i == self._num_layers - 1:
+        transformer_output_range = self._output_range
+      x = layer([x, attention_mask], output_range=transformer_output_range)
       if self._config['return_attention_scores']:
         x, attention_scores = x
         attention_outputs.append(attention_scores)
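
For context on the new `call()` loop above: only the final block receives a non-`None` `output_range`, because truncating the query sequence any earlier would deprive the remaining layers of full-sequence context. A standalone sketch of that dispatch, under the same hedged API assumption as before (toy shapes, illustrative constructor arguments):

import tensorflow as tf
from official.nlp.modeling import layers

num_layers, output_range = 3, 1
blocks = [
    layers.TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=32, inner_activation='relu')
    for _ in range(num_layers)]
x = tf.zeros([2, 8, 16])
mask = tf.ones([2, 8, 8])
for i, block in enumerate(blocks):
  # Mirror of the diff: every block but the last sees the full sequence.
  rng = output_range if i == num_layers - 1 else None
  x = block([x, mask], output_range=rng)
print(x.shape)  # (2, 1, 16): only the final block narrowed the sequence
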
@@ -475,10 +480,9 @@ class BertEncoder(tf.keras.Model):
     encoder_outputs = []
     attention_outputs = []
     for i in range(num_layers):
-      if i == num_layers - 1 and output_range is not None:
+      transformer_output_range = None
+      if i == num_layers - 1:
         transformer_output_range = output_range
-      else:
-        transformer_output_range = None
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
           inner_dim=inner_dim,
@@ -487,11 +491,11 @@ class BertEncoder(tf.keras.Model):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=transformer_output_range,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       transformer_layers.append(layer)
-      data = layer([data, attention_mask])
+      data = layer([data, attention_mask],
+                   output_range=transformer_output_range)
       if return_attention_scores:
         data, attention_scores = data
         attention_outputs.append(attention_scores)
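
The v1 `BertEncoder` builds a functional Keras graph, so the call-time `output_range` above is captured once while the graph is traced; behavior is unchanged, only where the value is threaded. A hedged end-to-end sketch, assuming the `official.nlp.modeling.networks` constructor arguments shown in this diff:

import numpy as np
from official.nlp.modeling import networks

encoder = networks.BertEncoder(
    vocab_size=100, hidden_size=16, num_attention_heads=2, num_layers=3,
    output_range=1, dict_outputs=True)
word_ids = np.zeros((2, 8), dtype=np.int32)
mask = np.ones((2, 8), dtype=np.int32)
type_ids = np.zeros((2, 8), dtype=np.int32)
outputs = encoder([word_ids, mask, type_ids])
print(outputs['sequence_output'].shape)  # (2, 1, 16): last block truncated
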
@@ -600,3 +604,4 @@ class BertEncoder(tf.keras.Model):
       logging.warn(warn_string)
     return cls(**config)
@@ -545,8 +545,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types,
-        output_range=None)
+        type_vocab_size=num_types)
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
@@ -605,8 +604,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types,
-        output_range=None)
+        type_vocab_size=num_types)
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
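
Dropping `output_range=None` from the compatibility tests above is purely cosmetic: `None` is the default, so the shorter call constructs the same encoder. Illustration, under the same hedged API assumption as the sketches above:

from official.nlp.modeling import networks

# Equivalent constructions; output_range defaults to None,
# i.e. every block emits the full sequence.
enc_a = networks.BertEncoderV2(
    vocab_size=100, hidden_size=16, num_attention_heads=2, num_layers=3)
enc_b = networks.BertEncoderV2(
    vocab_size=100, hidden_size=16, num_attention_heads=2, num_layers=3,
    output_range=None)
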