Commit 7e370ba7 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 370159521
parent 76145d74
...@@ -257,11 +257,20 @@ class EncoderScaffold(tf.keras.Model): ...@@ -257,11 +257,20 @@ class EncoderScaffold(tf.keras.Model):
'pooler_layer_initializer': self._pooler_layer_initializer, 'pooler_layer_initializer': self._pooler_layer_initializer,
'embedding_cls': self._embedding_network, 'embedding_cls': self._embedding_network,
'embedding_cfg': self._embedding_cfg, 'embedding_cfg': self._embedding_cfg,
'hidden_cfg': self._hidden_cfg,
'layer_norm_before_pooling': self._layer_norm_before_pooling, 'layer_norm_before_pooling': self._layer_norm_before_pooling,
'return_all_layer_outputs': self._return_all_layer_outputs, 'return_all_layer_outputs': self._return_all_layer_outputs,
'dict_outputs': self._dict_outputs, 'dict_outputs': self._dict_outputs,
} }
if self._hidden_cfg:
config_dict['hidden_cfg'] = {}
for k, v in self._hidden_cfg.items():
# `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
# `TransformerScaffold`, its `attention_cls` argument can be a `class`.
if inspect.isclass(v):
config_dict['hidden_cfg'][k] = tf.keras.utils.get_registered_name(v)
else:
config_dict['hidden_cfg'][k] = v
if inspect.isclass(self._hidden_cls): if inspect.isclass(self._hidden_cls):
config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name( config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
self._hidden_cls) self._hidden_cls)
......
...@@ -31,9 +31,10 @@ from official.nlp.modeling.networks import encoder_scaffold ...@@ -31,9 +31,10 @@ from official.nlp.modeling.networks import encoder_scaffold
@tf.keras.utils.register_keras_serializable(package="TestOnly") @tf.keras.utils.register_keras_serializable(package="TestOnly")
class ValidatedTransformerLayer(layers.Transformer): class ValidatedTransformerLayer(layers.Transformer):
def __init__(self, call_list, **kwargs): def __init__(self, call_list, call_class=None, **kwargs):
super(ValidatedTransformerLayer, self).__init__(**kwargs) super(ValidatedTransformerLayer, self).__init__(**kwargs)
self.list = call_list self.list = call_list
self.call_class = call_class
def call(self, inputs): def call(self, inputs):
self.list.append(True) self.list.append(True)
...@@ -41,10 +42,16 @@ class ValidatedTransformerLayer(layers.Transformer): ...@@ -41,10 +42,16 @@ class ValidatedTransformerLayer(layers.Transformer):
def get_config(self): def get_config(self):
config = super(ValidatedTransformerLayer, self).get_config() config = super(ValidatedTransformerLayer, self).get_config()
config["call_list"] = [] config["call_list"] = self.list
config["call_class"] = tf.keras.utils.get_registered_name(self.call_class)
return config return config
@tf.keras.utils.register_keras_serializable(package="TestLayerOnly")
class TestLayer(tf.keras.layers.Layer):
pass
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover. # guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
...@@ -560,7 +567,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase): ...@@ -560,7 +567,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
self.assertNotEmpty(call_list) self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.") self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_serialize_deserialize(self): @parameterized.parameters(True, False)
def test_serialize_deserialize(self, use_hidden_cls_instance):
hidden_size = 32 hidden_size = 32
sequence_length = 21 sequence_length = 21
vocab_size = 57 vocab_size = 57
...@@ -591,21 +599,27 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase): ...@@ -591,21 +599,27 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
"kernel_initializer": "kernel_initializer":
tf.keras.initializers.TruncatedNormal(stddev=0.02), tf.keras.initializers.TruncatedNormal(stddev=0.02),
"call_list": "call_list":
call_list call_list,
"call_class":
TestLayer
} }
# Create a small EncoderScaffold for testing. This time, we pass an already- # Create a small EncoderScaffold for testing. This time, we pass an already-
# instantiated layer object. # instantiated layer object.
kwargs = dict(
xformer = ValidatedTransformerLayer(**hidden_cfg)
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3, num_hidden_instances=3,
pooled_output_dim=hidden_size, pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal( pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02), stddev=0.02),
hidden_cls=xformer,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg)
if use_hidden_cls_instance:
xformer = ValidatedTransformerLayer(**hidden_cfg)
test_network = encoder_scaffold.EncoderScaffold(
hidden_cls=xformer, **kwargs)
else:
test_network = encoder_scaffold.EncoderScaffold(
hidden_cls=ValidatedTransformerLayer, hidden_cfg=hidden_cfg, **kwargs)
# Create another network object from the first object's config. # Create another network object from the first object's config.
new_network = encoder_scaffold.EncoderScaffold.from_config( new_network = encoder_scaffold.EncoderScaffold.from_config(
test_network.get_config()) test_network.get_config())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment