Unverified Commit c3c2386c authored by xinliupitt's avatar xinliupitt Committed by GitHub
Browse files

Add relative positional embedding to KerasBERT (#8617)

* root dir

* zone updated

* print mask

* preview emb

* tf print

* input only

* emb

* tf print

* emb after mask

* masked_softmax print

* print scores

* multi folder

* first pos emb

* check input shape

* add test temp

* import math

* two classes

* prints

* all get_pos replace

* make time scale private

* pos emb comments

* print input

* embedding_inputs

* tf shape

* dimention list

* tf_util

* print tf_util

* concise

* transformer pos change to layer

* keep length var

* length as input

* None as input

* print time signal

* print time signal

* remove print

* test input shape

* double check shape

* double check shape

* double check shape

* more test

* shape check

* shape check

* print 97 info

* print 97 info new

* test if sam

* assert same

* remove assert

* tf print same

* tf print diff

* output example

* output example

* output example

* formal test

* formal test length

* raise valurerror

* test valurerror

* double check

* comments

* remove prints

* rename relative

* delet naive test

* delete docs in xinliu branch

* code reformat

* import order

* indentation fix

* more files

* adjust char number

* disable not callable

* comment to length

* error of length unequal to input_shape

* root dir

* zone updated

* print mask

* preview emb

* tf print

* input only

* emb

* tf print

* emb after mask

* masked_softmax print

* print scores

* multi folder

* remove docs

* remove prints

* root dir

* zone updated

* print mask

* preview emb

* tf print

* input only

* emb

* tf print

* emb after mask

* masked_softmax print

* print scores

* multi folder

* remove docs

* apply revised 3 files

* rm prints
parent cd3c6c57
......@@ -19,6 +19,8 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import math
import tensorflow as tf
from official.modeling import tf_utils
......@@ -118,3 +120,81 @@ class PositionEmbedding(tf.keras.layers.Layer):
position_embeddings = self._position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)
@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf.keras.layers.Layer):
"""Creates a positional embedding.
This layer calculates the position encoding as a mix of sine and cosine
functions with geometrically increasing wavelengths. Defined and formulized in
"Attention is All You Need", section 3.5.
(https://arxiv.org/abs/1706.03762).
Arguments:
hidden_size: Size of the hidden layer.
min_timescale: Minimum scale that will be applied at each position
max_timescale: Maximum scale that will be applied at each position.
length: Number of positions. Should be specified if `inputs` is None at
`call(self, inputs)`
"""
def __init__(self,
hidden_size,
min_timescale=1.0,
max_timescale=1.0e4,
length=None,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
# We compute the positional encoding in float32 even if the model uses
# float16, as many of the ops used, like log and exp, are numerically
# unstable in float16.
if "dtype" not in kwargs:
kwargs["dtype"] = "float32"
super(RelativePositionEmbedding, self).__init__(**kwargs)
self._hidden_size = hidden_size
self._min_timescale = min_timescale
self._max_timescale = max_timescale
self._length = length
def get_config(self):
config = {
"hidden_size": self._hidden_size,
"min_timescale": self._min_timescale,
"max_timescale": self._max_timescale,
"length": self._length,
}
base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
"""Implements build() for the layer."""
super(RelativePositionEmbedding, self).build(input_shape)
def call(self, inputs):
"""Implements call() for the layer."""
length = self._length
if inputs is None and length is None:
raise ValueError(
"If inputs is None, `length` must be set in "
"RelativePositionEmbedding().")
if inputs is not None:
input_shape = tf_utils.get_shape_list(inputs)
if length is not None and length != input_shape[1]:
raise ValueError(
"If inputs is not None, `length` must equal to input_shape[1]."
)
length = input_shape[1]
position = tf.cast(tf.range(length), tf.float32)
num_timescales = self._hidden_size // 2
min_timescale, max_timescale = self._min_timescale, self._max_timescale
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(tf.cast(num_timescales, tf.float32) - 1))
inv_timescales = min_timescale * tf.exp(
tf.cast(tf.range(num_timescales), tf.float32) *
-log_timescale_increment)
scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales,
0)
position_embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)],
axis=1)
return position_embeddings
......@@ -36,7 +36,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
sequence_length = 21
width = 30
input_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(input_tensor)
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
......@@ -51,7 +51,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
sequence_length = 21
width = 30
input_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(input_tensor)
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
# When using static positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions save batch.
......@@ -67,7 +67,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
# When using dynamic positional embedding shapes, the output is expected
# to be the same as the input shape in all dimensions - but may be None if
......@@ -82,7 +82,7 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
width = 30
input_tensor = tf.keras.Input(shape=(None, width))
output_tensor = test_layer(input_tensor)
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
model = tf.keras.Model(input_tensor, output_tensor)
......@@ -98,6 +98,34 @@ class PositionEmbeddingLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([1, input_length, width], output_data.shape)
def test_relative_tensor_input(self):
hidden_size = 8
test_layer = position_embedding.RelativePositionEmbedding(
hidden_size=hidden_size)
# create a 3-dimensional input for test_layer to infer length as 1.
input_tensor = tf.constant([[[0] * hidden_size]])
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
# expected output is the theoretical result of the input based on
# sine cosine relative position embedding formula.
expected_output_tensor = tf.constant([[0, 0, 0, 0, 1, 1, 1, 1]])
self.assertAllEqual(output_tensor, expected_output_tensor)
def test_relative_length_input(self):
hidden_size = 8
# When we do not have tensor as input, we explicitly specify length
# value when initializing test_layer.
test_layer = position_embedding.RelativePositionEmbedding(
hidden_size=hidden_size, length=1)
input_tensor = None
output_tensor = test_layer(input_tensor) # pylint: disable=not-callable
# expected output is the theoretical result of the input based on
# sine cosine relative position embedding formula.
expected_output_tensor = tf.constant([[0, 0, 0, 0, 1, 1, 1, 1]])
self.assertAllEqual(output_tensor, expected_output_tensor)
if __name__ == "__main__":
tf.test.main()
......@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.nlp.modeling.layers import position_embedding
from official.nlp.transformer import attention_layer
from official.nlp.transformer import beam_search
from official.nlp.transformer import embedding_layer
......@@ -170,9 +171,9 @@ class Transformer(tf.keras.Model):
attention_bias = tf.cast(attention_bias, self.params["dtype"])
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params["hidden_size"])
pos_layer = position_embedding.RelativePositionEmbedding(
hidden_size=self.params["hidden_size"])
pos_encoding = pos_layer(embedded_inputs)
pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
encoder_inputs = embedded_inputs + pos_encoding
......@@ -209,8 +210,9 @@ class Transformer(tf.keras.Model):
[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params["hidden_size"])
pos_layer = position_embedding.RelativePositionEmbedding(
hidden_size=self.params["hidden_size"])
pos_encoding = pos_layer(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
decoder_inputs += pos_encoding
if training:
......@@ -233,8 +235,10 @@ class Transformer(tf.keras.Model):
def _get_symbols_to_logits_fn(self, max_decode_length, training):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = model_utils.get_position_encoding(
max_decode_length + 1, self.params["hidden_size"])
pos_layer = position_embedding.RelativePositionEmbedding(
hidden_size=self.params["hidden_size"],
length=max_decode_length + 1)
timing_signal = pos_layer(None)
timing_signal = tf.cast(timing_signal, self.params["dtype"])
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length, dtype=self.params["dtype"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment