Commit 8c408bbe, authored by Chen Chen and committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 311597242
parent 7cdb82e3
...@@ -54,9 +54,11 @@ class EncoderScaffold(tf.keras.Model): ...@@ -54,9 +54,11 @@ class EncoderScaffold(tf.keras.Model):
pooler_layer_initializer: The initializer for the classification pooler_layer_initializer: The initializer for the classification
layer. layer.
embedding_cls: The class or instance to use to embed the input data. This embedding_cls: The class or instance to use to embed the input data. This
class or instance defines the inputs to this encoder. If embedding_cls is class or instance defines the inputs to this encoder and outputs
not set, a default embedding network (from the original BERT paper) will (1) embeddings tensor with shape [batch_size, seq_length, hidden_size] and
be created. (2) attention masking with tensor [batch_size, seq_length, seq_length].
If embedding_cls is not set, a default embedding network
(from the original BERT paper) will be created.
embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
be instantiated. If embedding_cls is not set, a config dict must be be instantiated. If embedding_cls is not set, a config dict must be
passed to 'embedding_cfg' with the following values: passed to 'embedding_cfg' with the following values:
...@@ -121,7 +123,7 @@ class EncoderScaffold(tf.keras.Model): ...@@ -121,7 +123,7 @@ class EncoderScaffold(tf.keras.Model):
else: else:
self._embedding_network = embedding_cls self._embedding_network = embedding_cls
inputs = self._embedding_network.inputs inputs = self._embedding_network.inputs
embeddings, mask = self._embedding_network(inputs) embeddings, attention_mask = self._embedding_network(inputs)
else: else:
self._embedding_network = None self._embedding_network = None
word_ids = tf.keras.layers.Input( word_ids = tf.keras.layers.Input(
...@@ -174,7 +176,8 @@ class EncoderScaffold(tf.keras.Model): ...@@ -174,7 +176,8 @@ class EncoderScaffold(tf.keras.Model):
tf.keras.layers.Dropout( tf.keras.layers.Dropout(
rate=embedding_cfg['dropout_rate'])(embeddings)) rate=embedding_cfg['dropout_rate'])(embeddings))
attention_mask = layers.SelfAttentionMask()([embeddings, mask]) attention_mask = layers.SelfAttentionMask()([embeddings, mask])
data = embeddings data = embeddings
layer_output_data = [] layer_output_data = []
......
...@@ -211,8 +211,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -211,8 +211,6 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
"kernel_initializer": "kernel_initializer":
tf.keras.initializers.TruncatedNormal(stddev=0.02), tf.keras.initializers.TruncatedNormal(stddev=0.02),
} }
print(hidden_cfg)
print(embedding_cfg)
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
...@@ -347,7 +345,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -347,7 +345,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings") name="word_embeddings")
word_embeddings = embedding_layer(word_ids) word_embeddings = embedding_layer(word_ids)
network = tf.keras.Model([word_ids, mask], [word_embeddings, mask]) attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
hidden_cfg = { hidden_cfg = {
"num_attention_heads": "num_attention_heads":
...@@ -414,7 +414,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -414,7 +414,9 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings") name="word_embeddings")
word_embeddings = embedding_layer(word_ids) word_embeddings = embedding_layer(word_ids)
network = tf.keras.Model([word_ids, mask], [word_embeddings, mask]) attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
hidden_cfg = { hidden_cfg = {
"num_attention_heads": "num_attention_heads":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment