Commit 79b6de8e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 408977858
parent d872cee2
@@ -193,6 +193,7 @@ class ReuseEncoderConfig(hyperparams.Config):
   reuse_attention: int = -1
   use_relative_pe: bool = False
   pe_max_seq_length: int = 512
+  max_reuse_layer_idx: int = 6
 
 
 @dataclasses.dataclass
@@ -519,7 +520,8 @@ def build_encoder(config: EncoderConfig,
             stddev=encoder_cfg.initializer_range),
         reuse_attention=encoder_cfg.reuse_attention,
         use_relative_pe=encoder_cfg.use_relative_pe,
-        pe_max_seq_length=encoder_cfg.pe_max_seq_length)
+        pe_max_seq_length=encoder_cfg.pe_max_seq_length,
+        max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
     kwargs = dict(
         embedding_cfg=embedding_cfg,
         hidden_cls=layers.ReuseTransformer,
...
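For context, a usage sketch of the new config field (not part of this commit): it assumes the TF Model Garden import path official.nlp.configs.encoders and uses only fields visible in the diff above.

# Illustrative sketch only. The import path is assumed; only fields shown in
# the diff are set here.
from official.nlp.configs import encoders

encoder_cfg = encoders.ReuseEncoderConfig(
    reuse_attention=-1,
    use_relative_pe=False,
    pe_max_seq_length=512,
    # New field from this change: layers with index greater than 6 stop
    # reusing attention scores from earlier layers.
    max_reuse_layer_idx=6,
)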
@@ -50,6 +50,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
                use_relative_pe=False,
                pe_max_seq_length=512,
                layer_idx=None,
+               max_reuse_layer_idx=None,
                **kwargs):
     """Initializes `ReuseTransformer`.
 
@@ -90,6 +91,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
       use_relative_pe: whether to use relative position bias.
       pe_max_seq_length: used to set the size of the relative position encodings.
       layer_idx: the idx of this layer.
+      max_reuse_layer_idx: layer idx (if passed) greater than this value will
+        not reuse attention scores from previous layers.
       **kwargs: keyword arguments.
     """
     super().__init__(**kwargs)
@@ -118,9 +121,11 @@ class ReuseTransformer(tf.keras.layers.Layer):
     self._use_relative_pe = use_relative_pe
     self._pe_max_seq_length = pe_max_seq_length
     self._layer_idx = layer_idx
-    # Special handling for the first layer.
-    # Consider taking a list to config each layer by layer index.
-    if self._layer_idx is not None and self._layer_idx == 0:
+    self._max_reuse_layer_idx = max_reuse_layer_idx
+    # Overwrite for the first layer and layers greater than max_reuse_layer_idx.
+    if self._layer_idx is not None and (
+        self._layer_idx == 0 or (self._max_reuse_layer_idx is not None and
+                                 self._max_reuse_layer_idx < self._layer_idx)):
       self._reuse_attention = 0
     if attention_initializer:
       self._attention_initializer = tf.keras.initializers.get(
@@ -233,6 +238,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
             self._reuse_attention,
         "use_relative_pe": self._use_relative_pe,
         "pe_max_seq_length": self._pe_max_seq_length,
+        "max_reuse_layer_idx": self._max_reuse_layer_idx,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
...
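The per-layer override above can be summarized by a small standalone helper. This is a sketch, not code from this repository; effective_reuse_attention is a hypothetical name used only for illustration.

def effective_reuse_attention(reuse_attention, layer_idx, max_reuse_layer_idx=None):
  """Mirrors the override in ReuseTransformer.__init__ shown in the diff above.

  The first layer, and any layer whose index exceeds max_reuse_layer_idx,
  falls back to 0 (compute attention scores rather than reuse them); all
  other layers keep the configured reuse_attention value.
  """
  if layer_idx is not None and (
      layer_idx == 0 or (max_reuse_layer_idx is not None and
                         max_reuse_layer_idx < layer_idx)):
    return 0
  return reuse_attention


# With max_reuse_layer_idx=6 and reuse_attention=-1: layer 0 and layers 7+
# recompute attention, while layers 1-6 reuse scores from the previous layer.
assert effective_reuse_attention(-1, 0, 6) == 0
assert effective_reuse_attention(-1, 3, 6) == -1
assert effective_reuse_attention(-1, 7, 6) == 0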