Commit db6aca44 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Rename arguments of EncoderScaffold:

num_output_classes -> pooled_output_dim
classification_layer_initializer -> pooler_layer_initializer

PiperOrigin-RevId: 305550415
parent 91495d13
...@@ -120,9 +120,12 @@ def get_transformer_encoder(bert_config, ...@@ -120,9 +120,12 @@ def get_transformer_encoder(bert_config,
dropout_rate=bert_config.hidden_dropout_prob, dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob, attention_dropout_rate=bert_config.attention_probs_dropout_prob,
) )
kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg, kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers, num_hidden_instances=bert_config.num_hidden_layers,
num_output_classes=bert_config.hidden_size) pooled_output_dim=bert_config.hidden_size,
)
# Relies on gin configuration to define the Transformer encoder arguments. # Relies on gin configuration to define the Transformer encoder arguments.
return transformer_encoder_cls(**kwargs) return transformer_encoder_cls(**kwargs)
......
...@@ -51,8 +51,8 @@ class EncoderScaffold(network.Network): ...@@ -51,8 +51,8 @@ class EncoderScaffold(network.Network):
instantiated. instantiated.
Arguments: Arguments:
num_output_classes: The output size of the classification layer. pooled_output_dim: The dimension of pooled output.
classification_layer_initializer: The initializer for the classification pooler_layer_initializer: The initializer for the classification
layer. layer.
embedding_cls: The class or instance to use to embed the input data. This embedding_cls: The class or instance to use to embed the input data. This
class or instance defines the inputs to this encoder. If embedding_cls is class or instance defines the inputs to this encoder. If embedding_cls is
...@@ -90,8 +90,8 @@ class EncoderScaffold(network.Network): ...@@ -90,8 +90,8 @@ class EncoderScaffold(network.Network):
def __init__( def __init__(
self, self,
num_output_classes, pooled_output_dim,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
embedding_cls=None, embedding_cls=None,
embedding_cfg=None, embedding_cfg=None,
...@@ -104,8 +104,8 @@ class EncoderScaffold(network.Network): ...@@ -104,8 +104,8 @@ class EncoderScaffold(network.Network):
self._hidden_cls = hidden_cls self._hidden_cls = hidden_cls
self._hidden_cfg = hidden_cfg self._hidden_cfg = hidden_cfg
self._num_hidden_instances = num_hidden_instances self._num_hidden_instances = num_hidden_instances
self._num_output_classes = num_output_classes self._pooled_output_dim = pooled_output_dim
self._classification_layer_initializer = classification_layer_initializer self._pooler_layer_initializer = pooler_layer_initializer
self._embedding_cls = embedding_cls self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data self._embedding_data = embedding_data
...@@ -184,9 +184,9 @@ class EncoderScaffold(network.Network): ...@@ -184,9 +184,9 @@ class EncoderScaffold(network.Network):
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data) tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
) )
cls_output = tf.keras.layers.Dense( cls_output = tf.keras.layers.Dense(
units=num_output_classes, units=pooled_output_dim,
activation='tanh', activation='tanh',
kernel_initializer=classification_layer_initializer, kernel_initializer=pooler_layer_initializer,
name='cls_transform')( name='cls_transform')(
first_token_tensor) first_token_tensor)
...@@ -197,10 +197,10 @@ class EncoderScaffold(network.Network): ...@@ -197,10 +197,10 @@ class EncoderScaffold(network.Network):
config_dict = { config_dict = {
'num_hidden_instances': 'num_hidden_instances':
self._num_hidden_instances, self._num_hidden_instances,
'num_output_classes': 'pooled_output_dim':
self._num_output_classes, self._pooled_output_dim,
'classification_layer_initializer': 'pooler_layer_initializer':
self._classification_layer_initializer, self._pooler_layer_initializer,
'embedding_cls': 'embedding_cls':
self._embedding_network, self._embedding_network,
'embedding_cfg': 'embedding_cfg':
......
...@@ -91,8 +91,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -91,8 +91,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=num_hidden_instances, num_hidden_instances=num_hidden_instances,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cls=ValidatedTransformerLayer, hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
...@@ -147,8 +147,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -147,8 +147,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
...@@ -201,8 +201,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -201,8 +201,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
...@@ -254,8 +254,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -254,8 +254,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
...@@ -293,8 +293,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -293,8 +293,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
network = encoder_scaffold.EncoderScaffold( network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
...@@ -352,8 +352,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -352,8 +352,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cls=network, embedding_cls=network,
...@@ -419,8 +419,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase): ...@@ -419,8 +419,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing. # Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cls=network, embedding_cls=network,
...@@ -509,8 +509,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase): ...@@ -509,8 +509,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cls=xformer, hidden_cls=xformer,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
...@@ -579,8 +579,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase): ...@@ -579,8 +579,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
test_network = encoder_scaffold.EncoderScaffold( test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
num_output_classes=hidden_size, pooled_output_dim=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cls=xformer, hidden_cls=xformer,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment