Commit db6aca44 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Rename arguments of EncoderScaffold:

num_output_classes -> pooled_output_dim
classification_layer_initializer -> pooler_layer_initializer

PiperOrigin-RevId: 305550415
parent 91495d13
......@@ -120,9 +120,12 @@ def get_transformer_encoder(bert_config,
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
)
kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers,
num_output_classes=bert_config.hidden_size)
pooled_output_dim=bert_config.hidden_size,
)
# Relies on gin configuration to define the Transformer encoder arguments.
return transformer_encoder_cls(**kwargs)
......
......@@ -51,8 +51,8 @@ class EncoderScaffold(network.Network):
instantiated.
Arguments:
num_output_classes: The output size of the classification layer.
classification_layer_initializer: The initializer for the classification
pooled_output_dim: The dimension of pooled output.
pooler_layer_initializer: The initializer for the classification
layer.
embedding_cls: The class or instance to use to embed the input data. This
class or instance defines the inputs to this encoder. If embedding_cls is
......@@ -90,8 +90,8 @@ class EncoderScaffold(network.Network):
def __init__(
self,
num_output_classes,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
embedding_cls=None,
embedding_cfg=None,
......@@ -104,8 +104,8 @@ class EncoderScaffold(network.Network):
self._hidden_cls = hidden_cls
self._hidden_cfg = hidden_cfg
self._num_hidden_instances = num_hidden_instances
self._num_output_classes = num_output_classes
self._classification_layer_initializer = classification_layer_initializer
self._pooled_output_dim = pooled_output_dim
self._pooler_layer_initializer = pooler_layer_initializer
self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data
......@@ -184,9 +184,9 @@ class EncoderScaffold(network.Network):
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
)
cls_output = tf.keras.layers.Dense(
units=num_output_classes,
units=pooled_output_dim,
activation='tanh',
kernel_initializer=classification_layer_initializer,
kernel_initializer=pooler_layer_initializer,
name='cls_transform')(
first_token_tensor)
......@@ -197,10 +197,10 @@ class EncoderScaffold(network.Network):
config_dict = {
'num_hidden_instances':
self._num_hidden_instances,
'num_output_classes':
self._num_output_classes,
'classification_layer_initializer':
self._classification_layer_initializer,
'pooled_output_dim':
self._pooled_output_dim,
'pooler_layer_initializer':
self._pooler_layer_initializer,
'embedding_cls':
self._embedding_network,
'embedding_cfg':
......
......@@ -91,8 +91,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=num_hidden_instances,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg,
......@@ -147,8 +147,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
......@@ -201,8 +201,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
......@@ -254,8 +254,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
......@@ -293,8 +293,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg)
......@@ -352,8 +352,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cls=network,
......@@ -419,8 +419,8 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cls=network,
......@@ -509,8 +509,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cls=xformer,
embedding_cfg=embedding_cfg)
......@@ -579,8 +579,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
num_output_classes=hidden_size,
classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cls=xformer,
embedding_cfg=embedding_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment