Commit fd72a90e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Adds transformer_layers properties

PiperOrigin-RevId: 293081282
parent 458ea726
......@@ -143,6 +143,7 @@ class TransformerEncoder(network.Network):
if float_dtype == 'float16':
embeddings = tf.cast(embeddings, tf.float16)
self._transformer_layers = []
data = embeddings
attention_mask = layers.SelfAttentionMask()([data, mask])
for i in range(num_layers):
......@@ -155,6 +156,7 @@ class TransformerEncoder(network.Network):
kernel_initializer=initializer,
dtype=float_dtype,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
data = layer([data, attention_mask])
first_token_tensor = (
......@@ -178,6 +180,11 @@ class TransformerEncoder(network.Network):
def get_config(self):
return self._config_dict
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment