Commit dab5ff9d authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 463166393
parent e096ecf4
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
"""Keras-based attention layer with learnable per dim scaling.""" """Keras-based attention layer with learnable per dim scaling."""
import gin
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@gin.configurable
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention): class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
"""Learn scales for individual dims. """Learn scales for individual dims.
...@@ -27,6 +29,7 @@ class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention): ...@@ -27,6 +29,7 @@ class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
def _build_from_signature(self, query, value, key=None): def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key) # pytype: disable=attribute-error super()._build_from_signature(query=query, value=value, key=key) # pytype: disable=attribute-error
self._scale_dim = self._key_dim self._scale_dim = self._key_dim
with tf.init_scope():
self.per_dim_scale = self.add_weight( self.per_dim_scale = self.add_weight(
name='per_dim_scale', name='per_dim_scale',
shape=(self._scale_dim,), shape=(self._scale_dim,),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment