Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based attention layer with learnable per dim scaling."""
import gin
import numpy as np
import tensorflow as tf
@gin.configurable
@tf.keras.utils.register_keras_serializable(package='Text')
class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
"""Learn scales for individual dims.
It can improve quality but might hurt training stability.
"""
def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key) # pytype: disable=attribute-error
self._scale_dim = self._key_dim
with tf.init_scope():
self.per_dim_scale = self.add_weight(
name='per_dim_scale',
shape=(self._scale_dim,),
initializer='zeros',
dtype=self.dtype,
trainable=True)
def _scale_query(self, query):
# 1.0/tf.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
# can avoid unnecessary XLA op fusion mess on TPU.
r_softplus_0 = 1.442695041
scale = tf.constant(
r_softplus_0 / np.sqrt(float(self._scale_dim)), dtype=query.dtype)
scale *= tf.nn.softplus(self.per_dim_scale)
return query * scale
def _compute_attention(self,
query,
key,
value,
attention_mask=None,
training=None):
query = self._scale_query(query)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training)
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value)
return attention_output, attention_scores
def call(
self,
query,
value,
key=None,
attention_mask=None,
return_attention_scores=False,
training=None,
):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N, H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training)
attention_output = self._output_dense(attention_output)
if return_attention_scores:
return attention_output, attention_scores
return attention_output
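
As a quick sanity check of the scaling above (an illustrative sketch, not part of the commit): `softplus(0) = ln 2`, so the hard-coded `r_softplus_0 = 1.442695041 = 1/ln 2` cancels it, and with the zero-initialized `per_dim_scale` weight the query is scaled by exactly `1/sqrt(key_dim)`, i.e. the layer starts out equivalent to standard scaled dot-product attention.

# Illustrative sketch (not part of the commit): with per_dim_scale at its
# zero initialization, the learned scale reduces to 1/sqrt(key_dim).
import numpy as np

key_dim = 64
r_softplus_0 = 1.442695041                   # 1.0 / softplus(0.0) = 1/ln(2)
softplus = lambda x: np.log1p(np.exp(x))     # numerically stable softplus

per_dim_scale = np.zeros(key_dim)            # the layer's 'zeros' initializer
scale = (r_softplus_0 / np.sqrt(key_dim)) * softplus(per_dim_scale)
assert np.allclose(scale, 1.0 / np.sqrt(key_dim))
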
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PerDimScaleAttention."""
import tensorflow as tf
from official.nlp.modeling.layers import per_dim_scale_attention as attention
class PerDimScaleAttentionTest(tf.test.TestCase):
def test_attention(self):
num_heads = 12
key_dim = 64
seq_length = 1024
batch_size = 2
test_layer = attention.PerDimScaleAttention(
num_heads=num_heads, key_dim=key_dim)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim * num_heads))
value = query
output = test_layer(query=query, value=value)
self.assertEqual(output.shape,
[batch_size, seq_length, key_dim * num_heads])
def test_config(self):
num_heads = 12
key_dim = 64
test_layer = attention.PerDimScaleAttention(
num_heads=num_heads, key_dim=key_dim)
print(test_layer.get_config())
new_layer = attention.PerDimScaleAttention.from_config(
test_layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -53,7 +53,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
seq_axis=1,
**kwargs):
super(PositionEmbedding, self).__init__(**kwargs)
super().__init__(**kwargs)
if max_length is None:
raise ValueError(
"`max_length` must be an Integer, not `None`."
......@@ -81,7 +81,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
shape=[weight_sequence_length, width],
initializer=self._initializer)
super(PositionEmbedding, self).build(input_shape)
super().build(input_shape)
def call(self, inputs):
input_shape = tf.shape(inputs)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet of shape `[num_heads, dim]`.
segment_encoding: Optional `Tensor` representing the segmentation encoding
as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the query
head when calculating the segment-based attention score used in XLNet of
shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
state or memory. If passed, this is also attended over as in Transformer
XL.
attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
"""
......@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
with tf.init_scope():
einsum_equation, _, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = tf.keras.layers.experimental.EinsumDense(
self._encoding_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
......@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
content_attention_bias: A trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
......@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_encoding: Optional `Tensor` representing the segmentation encoding
as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
......@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in
Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
content_attention_bias: A trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g.
This only has access to contextual information and position, but not
content. If not provided, then this is MultiHeadRelativeAttention with
query_stream: The query representation, commonly referred to as g. This
only has access to contextual information and position, but not content.
If not provided, then this is MultiHeadRelativeAttention with
self-attention.
relative_position_encoding: relative positional encoding for key and
value.
target_mapping: Optional `Tensor` representing the target mapping used
in partial prediction.
target_mapping: Optional `Tensor` representing the target mapping used in
partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_encoding: Optional `Tensor` representing the segmentation encoding
as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended
......@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_attention_mask: (default None) Optional mask that is added to
content attention logits. If state is not None, the mask source sequence
dimension should extend M.
query_attention_mask: (default None) Optional mask that is added to
query attention logits. If state is not None, the mask source sequence
query_attention_mask: (default None) Optional mask that is added to query
attention logits. If state is not None, the mask source sequence
dimension should extend M.
Returns:
......@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
query_attention_output = self._output_dense(query_attention_output)
return content_attention_output, query_attention_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,6 +22,8 @@ import string
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -221,7 +223,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(ReuseMultiHeadAttention, self).__init__(**kwargs)
super().__init__(**kwargs)
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
......@@ -299,7 +301,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
"key_shape": self._key_shape,
"value_shape": self._value_shape,
}
base_config = super(ReuseMultiHeadAttention, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
......@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -362,42 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
if self._reuse_heads < self._num_heads:
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
self._query_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._key_dim]),
output_shape=_get_output_shape(
output_rank - 1,
[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
self._key_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._key_dim]),
output_shape=_get_output_shape(
output_rank - 1,
[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = []
if self._reuse_heads > 0:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_reuse",
**common_kwargs))
self._value_dense.append(
tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_reuse",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
if self._reuse_heads < self._num_heads:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [
self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
**common_kwargs))
self._value_dense.append(
tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1,
[self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
......@@ -434,18 +453,20 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape = [self._query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
return tf.keras.layers.experimental.EinsumDense(
return tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product
customize attention computation to replace the default dot-product
attention.
Args:
......
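
A recurring change in this commit is wrapping `kernel_initializer`/`bias_initializer` in `tf_utils.clone_initializer(...)` before handing them to each `EinsumDense`. The sketch below illustrates the assumed behavior (each sublayer gets its own initializer instance with the same config instead of sharing one object); the implementation of `clone_initializer` itself is not shown in this diff.

# Illustrative sketch (assumed behavior of tf_utils.clone_initializer): the
# clone is a distinct object with an identical config, so sublayers built
# from the same user-supplied initializer do not share initializer state.
import tensorflow as tf
from official.modeling import tf_utils

init = tf.keras.initializers.GlorotUniform(seed=42)
clone = tf_utils.clone_initializer(init)
assert clone is not init                          # independent instance
assert clone.get_config() == init.get_config()    # same configuration
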
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,6 +14,8 @@
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import reuse_attention as attention
......@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes
def build(self, input_shape):
......@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
else:
self._attention_head_size = self._head_size
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes,
reuse_attention=self._reuse_attention,
use_relative_pe=self._use_relative_pe,
......@@ -184,11 +187,12 @@ class ReuseTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -201,12 +205,13 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -68,7 +68,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
_ = model.predict(input_data)
......@@ -89,7 +89,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
......@@ -104,7 +104,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
......@@ -121,7 +121,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.25)
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_with_relative_pe(self, transformer_cls):
test_layer = transformer_cls(
......@@ -131,7 +131,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
......@@ -149,7 +149,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls(
......@@ -159,7 +159,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
output_tensor, _ = test_layer(input_data)
......@@ -175,7 +175,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer(input_data)
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls(
......@@ -185,7 +185,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length))
......@@ -203,7 +203,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16')
......@@ -223,7 +223,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
input_data = (np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
......@@ -368,7 +368,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = 10 * np.random.random_sample(
input_data = np.random.random_sample(
(batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
......@@ -404,7 +404,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6
input_data = (10 * np.random.random_sample(
input_data = (np.random.random_sample(
(batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,10 +14,13 @@
"""Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes
from typing import Optional
from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
......@@ -33,8 +36,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
Args:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability within the attention layer.
output_range: the sequence output range, [0, output_range) by slicing the
......@@ -52,8 +57,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
inner_dim=768,
inner_activation=tf_utils.get_activation("gelu"),
dropout_rate=0.0,
attention_dropout_rate=0.0,
output_range=None,
......@@ -72,12 +77,19 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_dropout_rate = kwargs.pop("attention_dropout",
attention_dropout_rate)
dropout_rate = kwargs.pop("output_dropout", dropout_rate)
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs)
super(ReZeroTransformer, self).__init__(**kwargs)
super().__init__(**kwargs)
# Deprecation warning.
if output_range is not None:
logging.warning("`output_range` is available as an argument for `call()`. "
"Passing `output_range` as an `__init__` argument is deprecated.")
self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._inner_dim = inner_dim
self._inner_activation = inner_activation
self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate
self._output_range = output_range
......@@ -121,8 +133,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -133,6 +143,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......@@ -144,11 +156,13 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
output_shape=(None, self._inner_dim),
bias_axes="d",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
......@@ -156,13 +170,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
self._inner_activation_layer = tf.keras.layers.Activation(
self._inner_activation, dtype=policy)
self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......@@ -185,16 +201,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
trainable=True,
dtype=tf.float32)
super(ReZeroTransformer, self).build(input_shape)
super().build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"dropout_rate":
self._dropout_rate,
"attention_dropout_rate":
......@@ -220,7 +236,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
}
base_config = super(ReZeroTransformer, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def reset_rezero(self):
......@@ -228,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
if not self._share_rezero:
self._rezero_a_ffn.assign(0.)
def call(self, inputs):
def call(self, inputs, output_range: Optional[tf.Tensor] = None) -> tf.Tensor:
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
......@@ -241,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
target_tensor = input_tensor[:, 0:self._output_range, :]
if output_range is None:
output_range = self._output_range
if output_range:
target_tensor = input_tensor[:, 0:output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:output_range, :]
else:
target_tensor = input_tensor
......@@ -261,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
intermediate_output = self._inner_activation_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
......
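
A hedged usage sketch of the `output_range` change above (shapes and argument values are invented for illustration): `output_range` now controls the output slicing per call instead of being fixed at construction time.

# Illustrative usage sketch (invented shapes; mirrors the test further down):
import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer

layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=2, inner_dim=128, inner_activation='relu')
data = tf.zeros((2, 10, 16))                # [batch, seq_length, width]
full_output = layer(data)                   # -> shape [2, 10, 16]
first_token = layer(data, output_range=1)   # -> shape [2, 1, 16]
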
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_separate_qkv(self):
test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=2,
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for Mixture of Experts (MoE) routing.
For MoE routing, we need to split a set of tokens into subsets of tokens.
Later on, different subsets of tokens can potentially be routed to different experts.
"""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf.keras.layers.Layer):
"""Routing based on per-token importance value."""
def __init__(self,
vocab_size,
init_importance,
moving_average_beta=0.995,
**kwargs):
self._vocab_size = vocab_size
self._init_importance = init_importance
self._moving_average_beta = moving_average_beta
super().__init__(**kwargs)
def build(self, input_shape):
self._importance_embedding = self.add_weight(
name="importance_embed",
shape=(self._vocab_size,),
initializer=tf.keras.initializers.Constant(self._init_importance),
trainable=False)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"init_importance":
self._init_importance,
"moving_average_beta":
self._moving_average_beta,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def update_token_importance(self, token_ids, importance):
token_ids = tf.reshape(token_ids, shape=[-1])
importance = tf.reshape(importance, shape=[-1])
beta = self._moving_average_beta
old_importance = tf.gather(self._importance_embedding, token_ids)
self._importance_embedding.assign(tf.tensor_scatter_nd_update(
self._importance_embedding,
tf.expand_dims(token_ids, axis=1),
old_importance * beta + tf.cast(importance * (1.0 - beta),
dtype=tf.float32)))
def call(self, inputs):
return tf.gather(self._importance_embedding, inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf.keras.layers.Layer):
"""Select top-k + random-k tokens according to importance."""
def __init__(self,
top_k=None,
random_k=None,
**kwargs):
self._top_k = top_k
self._random_k = random_k
super().__init__(**kwargs)
def get_config(self):
config = {
"top_k":
self._top_k,
"random_k":
self._random_k,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if self._random_k is None:
# Pure top-k, no randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._top_k])
not_selected = tf.slice(pos, [0, self._top_k], [-1, -1])
elif self._top_k is None:
# Pure randomness, no top-k.
pos = tf.argsort(tf.random.uniform(shape=tf.shape(inputs)),
direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos, [0, self._random_k], [-1, -1])
else:
# Top-k plus randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected_top_k = tf.slice(pos, [0, 0], [-1, self._top_k])
pos_left = tf.slice(pos, [0, self._top_k], [-1, -1])
# Randomly shuffle pos_left
sort_index = tf.argsort(
tf.random.uniform(shape=tf.shape(pos_left)),
direction="DESCENDING")
pos_left = tf.gather(pos_left, sort_index, batch_dims=1, axis=1)
selected_rand = tf.slice(pos_left, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos_left, [0, self._random_k], [-1, -1])
selected = tf.concat([selected_top_k, selected_rand], axis=1)
# Return the indices of selected and not-selected tokens.
return selected, not_selected
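
To make the moving-average update in `update_token_importance` concrete, here is a small worked sketch (values chosen to match the test below): each updated token's stored importance moves toward the new observation by a factor of `1 - beta`.

# Illustrative sketch of the update rule: new = old * beta + obs * (1 - beta).
import numpy as np

beta = 0.995
old_importance = np.array([10.0, 10.0])    # stored values for tokens 0 and 1
observed = np.array([0.0, 0.0])            # newly observed importance

updated = old_importance * beta + observed * (1.0 - beta)
assert np.allclose(updated, [9.95, 9.95])  # matches the routing test below
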
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for routing."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import routing
class TokenImportanceTest(tf.test.TestCase, parameterized.TestCase):
def test_token_importance(self):
token_importance_embed = routing.TokenImportanceWithMovingAvg(
vocab_size=4,
init_importance=10.0,
moving_average_beta=0.995)
importance = token_importance_embed(np.array([[0, 1], [2, 3]]))
self.assertAllClose(importance, np.array([[10.0, 10.0], [10.0, 10.0]]))
token_importance_embed.update_token_importance(
token_ids=np.array([[0, 1]]),
importance=np.array([[0.0, 0.0]]))
importance = token_importance_embed(np.array([[0, 1], [2, 3]]))
self.assertAllClose(importance, np.array([[9.95, 9.95], [10.0, 10.0]]))
class TopKSelectionTest(tf.test.TestCase, parameterized.TestCase):
def test_top_k_selection(self):
token_selection = routing.SelectTopK(top_k=2)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected, np.array([[3, 2], [0, 1]]))
def test_random_k_selection(self):
token_selection = routing.SelectTopK(random_k=2)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected.shape, (2, 2))
def test_top_k_random_k(self):
token_selection = routing.SelectTopK(top_k=1, random_k=1)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected.shape, (2, 2))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,10 +13,38 @@
# limitations under the License.
"""Keras layer that creates a self-attention mask."""
from typing import Optional
import tensorflow as tf
def get_mask(inputs: tf.Tensor,
to_mask: tf.Tensor,
dtype: Optional[tf.DType] = None) -> tf.Tensor:
"""Gets a 3D self-attention mask.
Args:
inputs: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
dtype: the output Tensor dtype.
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape = tf.shape(inputs)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
dtype = inputs.dtype if dtype is None else dtype
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=dtype)
return tf.broadcast_to(to_mask, [batch_size, from_seq_length, to_seq_length])
@tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer):
"""Create 3D attention mask from a 2D tensor mask.
......@@ -33,26 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
if isinstance(inputs, list) and to_mask is None:
to_mask = inputs[1]
inputs = inputs[0]
from_shape = tf.shape(inputs)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=inputs.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
return get_mask(inputs, to_mask)
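
A small sketch of what the refactored `get_mask` computes, written out with plain TF ops and toy shapes (values invented for illustration): the 2D `to_mask` is broadcast across every `from` position.

# Illustrative sketch (toy shapes, not part of the commit): equivalent of
# get_mask(inputs, to_mask) using the same reshape-and-broadcast steps.
import tensorflow as tf

inputs = tf.zeros((2, 3, 8))                     # [batch, from_seq, width]
to_mask = tf.constant([[1, 1, 0, 0],
                       [1, 0, 0, 0]])            # [batch, to_seq]

to_mask_f = tf.cast(tf.reshape(to_mask, [2, 1, 4]), inputs.dtype)
mask = tf.broadcast_to(to_mask_f, [2, 3, 4])     # [batch, from_seq, to_seq]
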
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -74,21 +74,20 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError('`layer` must be a `tf.keras.layers.Layer`. '
'Observed `{}`'.format(layer))
super(SpectralNormalization, self).__init__(
super().__init__(
layer, name=wrapper_name, **kwargs)
def build(self, input_shape):
super(SpectralNormalization, self).build(input_shape)
super().build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
......@@ -194,10 +193,11 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
raise ValueError(
'layer must be a `tf.keras.layers.Conv2D` instance. You passed: {input}'
.format(input=layer))
super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
super().__init__(layer, **kwargs)
def build(self, input_shape):
self.layer.build(input_shape)
if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
......@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=self.in_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -233,13 +232,13 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=self.out_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
aggregation=self.aggregation)
super(SpectralNormalizationConv2D, self).build()
super().build()
def call(self, inputs):
u_update_op, v_update_op, w_update_op = self.update_weights()
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
spectral_norm_expected = self.norm_multiplier
self.assertAllClose(
spectral_norm_computed, spectral_norm_expected, atol=5e-2)
spectral_norm_computed, spectral_norm_expected, atol=1e-1)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
......
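
The K-Lipschitz property referenced in the truncated test comment above can be illustrated with a small NumPy sketch (invented matrix, not part of the commit): for a linear map, the Lipschitz constant is the largest singular value, which is exactly the quantity spectral normalization rescales.

# Illustrative sketch: for f(x) = x @ W, ||f(x1) - f(x2)|| <= sigma_max(W) *
# ||x1 - x2||, so bounding the spectral norm bounds the Lipschitz constant.
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 8))
sigma_max = np.linalg.svd(W, compute_uv=False)[0]   # spectral norm of W

x1, x2 = rng.normal(size=(2, 16))
lhs = np.linalg.norm((x1 - x2) @ W)
rhs = sigma_max * np.linalg.norm(x1 - x2)
assert lhs <= rhs + 1e-9
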
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,6 +20,8 @@ import string
import gin
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._post_softmax_weight = self.add_weight(
"post_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,14 +14,16 @@
"""Keras Layers for BERT-specific preprocessing."""
# pylint: disable=g-import-not-at-top
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Text, Union
from absl import logging
import tensorflow as tf
try:
# pytype: disable=import-error
import tensorflow_text as text
from tensorflow_text.python.ops import bert_tokenizer
# pytype: enable=import-error
except ImportError:
text = None
bert_tokenizer = None
......@@ -57,7 +59,7 @@ def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
class BertTokenizer(tf.keras.layers.Layer):
"""Wraps BertTokenizer with pre-defined vocab as a Keras Layer.
"""Wraps TF.Text's BertTokenizer with pre-defined vocab as a Keras Layer.
Attributes:
tokenize_with_offsets: If true, calls
......@@ -71,8 +73,9 @@ class BertTokenizer(tf.keras.layers.Layer):
def __init__(self, *,
vocab_file: str,
lower_case: bool,
lower_case: Optional[bool] = None,
tokenize_with_offsets: bool = False,
tokenizer_kwargs: Optional[Mapping[Text, Any]] = None,
**kwargs):
"""Initialize a `BertTokenizer` layer.
......@@ -81,15 +84,18 @@ class BertTokenizer(tf.keras.layers.Layer):
This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with
`text.BertTokenizer`.
lower_case: A Python boolean forwarded to `text.BertTokenizer`.
lower_case: Optional boolean forwarded to `text.BertTokenizer`.
If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which
the `vocab_file` was created.
the `vocab_file` was created. If passed, this overrides whatever value
may have been passed in `tokenizer_kwargs`.
tokenize_with_offsets: A Python boolean. If true, this layer calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`
instead of just tokens.
tokenizer_kwargs: Optional mapping with keyword arguments to forward to
`text.BertTokenizer`'s constructor.
**kwargs: Standard arguments to `Layer()`.
Raises:
......@@ -111,8 +117,11 @@ class BertTokenizer(tf.keras.layers.Layer):
self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file)
super().__init__(**kwargs)
self._bert_tokenizer = text.BertTokenizer(
self._vocab_table, lower_case=lower_case)
tokenizer_kwargs = dict(tokenizer_kwargs or {})
if lower_case is not None:
tokenizer_kwargs["lower_case"] = lower_case
self._bert_tokenizer = text.BertTokenizer(self._vocab_table,
**tokenizer_kwargs)
@property
def vocab_size(self):
......
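
A hedged usage sketch of the new `tokenizer_kwargs` pass-through (the vocab path is hypothetical, `tensorflow_text` must be installed, and the `layers.BertTokenizer` export path is assumed): extra constructor arguments now flow through to `text.BertTokenizer`, with an explicit `lower_case` taking precedence over anything in `tokenizer_kwargs`.

# Illustrative sketch (hypothetical vocab path; assumes the layer is exported
# as official.nlp.modeling.layers.BertTokenizer and tensorflow_text is present).
import tensorflow as tf
from official.nlp.modeling import layers

tokenizer = layers.BertTokenizer(
    vocab_file='/path/to/vocab.txt',                 # hypothetical path
    lower_case=True,                                 # overrides tokenizer_kwargs
    tokenizer_kwargs={'token_out_type': tf.int32})   # forwarded to text.BertTokenizer
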