Commit 0b098c60 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 392901555
parent 5b5073d2
......@@ -74,9 +74,12 @@ class EncoderScaffold(tf.keras.Model):
standard pretraining.
num_hidden_instances: The number of times to instantiate and/or invoke the
hidden_cls.
hidden_cls: The class or instance to encode the input data. If `hidden_cls`
is not set, a KerasBERT transformer layer will be used as the encoder
class.
hidden_cls: Three types of input are supported: (1) class (2) instance
(3) list of classes or instances, to encode the input data. If
`hidden_cls` is not set, a KerasBERT transformer layer will be used as the
encoder class. If `hidden_cls` is a list of classes or instances, these
classes (instances) are sequentially instantiated (invoked) on top of
embedding layer. Mixing classes and instances in the list is allowed.
hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
instantiated. If hidden_cls is not set, a config dict must be passed to
`hidden_cfg` with the following values:
......@@ -192,15 +195,26 @@ class EncoderScaffold(tf.keras.Model):
layer_output_data = []
hidden_layers = []
hidden_cfg = hidden_cfg if hidden_cfg else {}
if isinstance(hidden_cls, list) and len(hidden_cls) != num_hidden_instances:
raise RuntimeError(
('When input hidden_cls to EncoderScaffold %s is a list, it must '
'contain classes or instances with size specified by '
'num_hidden_instances, got %d vs %d.') % self.name, len(hidden_cls),
num_hidden_instances)
for i in range(num_hidden_instances):
if inspect.isclass(hidden_cls):
if isinstance(hidden_cls, list):
cur_hidden_cls = hidden_cls[i]
else:
cur_hidden_cls = hidden_cls
if inspect.isclass(cur_hidden_cls):
if hidden_cfg and 'attention_cfg' in hidden_cfg and (
layer_idx_as_attention_seed):
hidden_cfg = copy.deepcopy(hidden_cfg)
hidden_cfg['attention_cfg']['seed'] = i
layer = hidden_cls(**hidden_cfg)
layer = cur_hidden_cls(**hidden_cfg)
else:
layer = hidden_cls
layer = cur_hidden_cls
data = layer([data, attention_mask])
layer_output_data.append(data)
hidden_layers.append(layer)
......@@ -347,6 +361,15 @@ class EncoderScaffold(tf.keras.Model):
else:
return self._embedding_data
@property
def embedding_network(self):
if self._embedding_network is None:
raise RuntimeError(
('The EncoderScaffold %s does not have a reference '
'to the embedding network. This is required when you '
'pass a custom embedding network to the scaffold.') % self.name)
return self._embedding_network
@property
def hidden_layers(self):
"""List of hidden layers in the encoder."""
......
......@@ -605,6 +605,98 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
def test_hidden_cls_list(self):
hidden_size = 32
sequence_length = 10
vocab_size = 57
embedding_network = Embeddings(vocab_size, hidden_size)
call_list = []
hidden_cfg = {
"num_attention_heads":
2,
"intermediate_size":
3072,
"intermediate_activation":
activations.gelu,
"dropout_rate":
0.1,
"attention_dropout_rate":
0.1,
"kernel_initializer":
tf.keras.initializers.TruncatedNormal(stddev=0.02),
"call_list":
call_list
}
mask_call_list = []
mask_cfg = {
"call_list": mask_call_list
}
# Create a small EncoderScaffold for testing. This time, we pass an already-
# instantiated layer object.
xformer = ValidatedTransformerLayer(**hidden_cfg)
xmask = ValidatedMaskLayer(**mask_cfg)
test_network_a = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cls=xformer,
mask_cls=xmask,
embedding_cls=embedding_network)
# Create a network b with same embedding and hidden layers as network a.
test_network_b = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
mask_cls=xmask,
embedding_cls=test_network_a.embedding_network,
hidden_cls=test_network_a.hidden_layers)
# Create a network c with same embedding but fewer hidden layers compared to
# network a and b.
hidden_layers = test_network_a.hidden_layers
hidden_layers.pop()
test_network_c = encoder_scaffold.EncoderScaffold(
num_hidden_instances=2,
pooled_output_dim=hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
mask_cls=xmask,
embedding_cls=test_network_a.embedding_network,
hidden_cls=hidden_layers)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
# Create model based off of network a:
data_a, pooled_a = test_network_a([word_ids, mask])
model_a = tf.keras.Model([word_ids, mask], [data_a, pooled_a])
# Create model based off of network b:
data_b, pooled_b = test_network_b([word_ids, mask])
model_b = tf.keras.Model([word_ids, mask], [data_b, pooled_b])
# Create model based off of network b:
data_c, pooled_c = test_network_c([word_ids, mask])
model_c = tf.keras.Model([word_ids, mask], [data_c, pooled_c])
batch_size = 3
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
output_a, _ = model_a.predict([word_id_data, mask_data])
output_b, _ = model_b.predict([word_id_data, mask_data])
output_c, _ = model_c.predict([word_id_data, mask_data])
# Outputs from model a and b should be the same since they share the same
# embedding and hidden layers.
self.assertAllEqual(output_a, output_b)
# Outputs from model a and c shouldn't be the same since they share the same
# embedding layer but different number of hidden layers.
self.assertNotAllEqual(output_a, output_c)
@parameterized.parameters(True, False)
def test_serialize_deserialize(self, use_hidden_cls_instance):
hidden_size = 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment