Commit 79b6de8e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 408977858
parent d872cee2
@@ -193,6 +193,7 @@ class ReuseEncoderConfig(hyperparams.Config):
reuse_attention: int = -1
use_relative_pe: bool = False
pe_max_seq_length: int = 512
max_reuse_layer_idx: int = 6
@dataclasses.dataclass
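For context, a minimal sketch of how the new field could be set on the config. The import path (official.nlp.configs.encoders) is an assumption; only the field names and defaults shown in this diff are taken from the change itself.

from official.nlp.configs import encoders

# Hypothetical usage: with max_reuse_layer_idx=6, layers whose idx exceeds 6
# recompute attention instead of reusing scores from an earlier layer.
encoder_cfg = encoders.ReuseEncoderConfig(
    reuse_attention=-1,
    use_relative_pe=False,
    pe_max_seq_length=512,
    max_reuse_layer_idx=6)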
@@ -519,7 +520,8 @@ def build_encoder(config: EncoderConfig,
stddev=encoder_cfg.initializer_range),
reuse_attention=encoder_cfg.reuse_attention,
use_relative_pe=encoder_cfg.use_relative_pe,
pe_max_seq_length=encoder_cfg.pe_max_seq_length)
pe_max_seq_length=encoder_cfg.pe_max_seq_length,
max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.ReuseTransformer,
......
@@ -50,6 +50,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
use_relative_pe=False,
pe_max_seq_length=512,
layer_idx=None,
max_reuse_layer_idx=None,
**kwargs):
"""Initializes `ReuseTransformer`.
@@ -90,6 +91,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
use_relative_pe: whether to use relative position bias.
pe_max_seq_length: used to set the size of the relative position encodings.
layer_idx: the idx of this layer.
max_reuse_layer_idx: if passed, layers with an idx greater than this value
  will not reuse attention scores from previous layers.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
@@ -118,9 +121,11 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._use_relative_pe = use_relative_pe
self._pe_max_seq_length = pe_max_seq_length
self._layer_idx = layer_idx
# Special handling for the first layer.
# Consider taking a list to config each layer by layer index.
if self._layer_idx is not None and self._layer_idx == 0:
self._max_reuse_layer_idx = max_reuse_layer_idx
# Overwrite reuse for the first layer and for layers with idx greater than
# max_reuse_layer_idx.
if self._layer_idx is not None and (
self._layer_idx == 0 or (self._max_reuse_layer_idx is not None and
self._max_reuse_layer_idx < self._layer_idx)):
self._reuse_attention = 0
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
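A small, framework-free sketch of the override above; effective_reuse is a hypothetical helper that simply mirrors the condition on layer_idx and max_reuse_layer_idx, useful for sanity-checking which layers end up reusing attention.

def effective_reuse(layer_idx, reuse_attention, max_reuse_layer_idx=None):
  """Mirrors the per-layer override: layer 0 and layers past the cap never reuse."""
  if layer_idx is not None and (
      layer_idx == 0 or (max_reuse_layer_idx is not None and
                         max_reuse_layer_idx < layer_idx)):
    return 0
  return reuse_attention

# With max_reuse_layer_idx=6, layer 0 and layers 7+ fall back to 0 (no reuse),
# while layers 1..6 keep the configured reuse_attention value.
print([effective_reuse(i, -1, 6) for i in range(9)])
# -> [0, -1, -1, -1, -1, -1, -1, 0, 0]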
@@ -233,6 +238,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._reuse_attention,
"use_relative_pe": self._use_relative_pe,
"pe_max_seq_length": self._pe_max_seq_length,
"max_reuse_layer_idx": self._max_reuse_layer_idx,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
......
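A hedged usage sketch of the layer-level argument. The required constructor arguments (num_attention_heads, inner_dim, inner_activation) and the layers import path are assumptions based on typical Model Garden transformer blocks, not part of this diff; only layer_idx, max_reuse_layer_idx, reuse_attention, and the get_config key come from the change.

from official.nlp.modeling import layers

# Assumed constructor arguments; only the reuse-related kwargs are from this diff.
layer = layers.ReuseTransformer(
    num_attention_heads=8,
    inner_dim=3072,
    inner_activation='gelu',
    reuse_attention=-1,
    layer_idx=7,
    max_reuse_layer_idx=6)

# Because layer_idx (7) exceeds max_reuse_layer_idx (6), the layer overrides its
# internal reuse setting to 0 (no reuse), per the __init__ logic above.
cfg = layer.get_config()
assert cfg['max_reuse_layer_idx'] == 6
rebuilt = layers.ReuseTransformer.from_config(cfg)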