Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based attention layer with learnable per dim scaling."""
import gin
import numpy as np
import tensorflow as tf
@gin.configurable
@tf.keras.utils.register_keras_serializable(package='Text')
class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
  """Multi-head attention with a learnable scale per query dimension.

  It can improve quality but might hurt training stability.
  """

  def _build_from_signature(self, query, value, key=None):
    """Builds the projection layers, then adds the per-dim scale weight."""
    super()._build_from_signature(query=query, value=value, key=key)  # pytype: disable=attribute-error
    scale_dim = self._key_dim
    self._scale_dim = scale_dim
    with tf.init_scope():
      self.per_dim_scale = self.add_weight(
          name='per_dim_scale',
          shape=(scale_dim,),
          dtype=self.dtype,
          initializer='zeros',
          trainable=True)

  def _scale_query(self, query):
    """Scales `query` by softplus(per_dim_scale), normalized per key dim."""
    # 1.0 / tf.nn.softplus(0.0) == 1.442695041. The number is hard coded so
    # that we avoid unnecessary XLA op-fusion mess on TPU.
    base = tf.constant(
        1.442695041 / np.sqrt(float(self._scale_dim)), dtype=query.dtype)
    return query * (base * tf.nn.softplus(self.per_dim_scale))

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         attention_mask=None,
                         training=None):
    """Dot-product attention using the per-dim-scaled query."""
    scaled_query = self._scale_query(query)
    scores = tf.einsum(self._dot_product_equation, key, scaled_query)
    probs = self._masked_softmax(scores, attention_mask)
    dropped = self._dropout_layer(probs, training=training)
    # `attention_output` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation, dropped, value)
    return attention_output, probs

  def call(
      self,
      query,
      value,
      key=None,
      attention_mask=None,
      return_attention_scores=False,
      training=None,
  ):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    key = value if key is None else key
    # Project to multi-head space: N = `num_attention_heads`,
    # H = `size_per_head`.
    query = self._query_dense(query)  # [B, T, N, H]
    key = self._key_dense(key)  # [B, S, N, H]
    value = self._value_dense(value)  # [B, S, N, H]
    outputs, scores = self._compute_attention(query, key, value,
                                              attention_mask, training)
    outputs = self._output_dense(outputs)
    if return_attention_scores:
      return outputs, scores
    return outputs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PerDimScaleAttention."""
import tensorflow as tf
from official.nlp.modeling.layers import per_dim_scale_attention as attention
class PerDimScaleAttentionTest(tf.test.TestCase):
  """Tests for the PerDimScaleAttention layer."""

  def test_attention(self):
    """Checks the output shape of a forward pass with self-attention."""
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.PerDimScaleAttention(
        num_heads=num_heads, key_dim=key_dim)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim * num_heads))
    value = query
    output = test_layer(query=query, value=value)
    self.assertEqual(output.shape,
                     [batch_size, seq_length, key_dim * num_heads])

  def test_config(self):
    """Checks that the layer round-trips through get_config/from_config."""
    num_heads = 12
    key_dim = 64
    test_layer = attention.PerDimScaleAttention(
        num_heads=num_heads, key_dim=key_dim)
    # Removed a leftover debug print of the config; the assertion below is
    # the actual check.
    new_layer = attention.PerDimScaleAttention.from_config(
        test_layer.get_config())
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -53,7 +53,7 @@ class PositionEmbedding(tf.keras.layers.Layer): ...@@ -53,7 +53,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
seq_axis=1, seq_axis=1,
**kwargs): **kwargs):
super(PositionEmbedding, self).__init__(**kwargs) super().__init__(**kwargs)
if max_length is None: if max_length is None:
raise ValueError( raise ValueError(
"`max_length` must be an Integer, not `None`." "`max_length` must be an Integer, not `None`."
...@@ -81,7 +81,7 @@ class PositionEmbedding(tf.keras.layers.Layer): ...@@ -81,7 +81,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
shape=[weight_sequence_length, width], shape=[weight_sequence_length, width],
initializer=self._initializer) initializer=self._initializer)
super(PositionEmbedding, self).build(input_shape) super().build(input_shape)
def call(self, inputs): def call(self, inputs):
input_shape = tf.shape(inputs) input_shape = tf.shape(inputs)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention): ...@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
`[B, L, dim]`. `[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`. XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation segment_encoding: Optional `Tensor` representing the segmentation encoding
encoding as used in XLNet of shape `[2, num_heads, dim]`. as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the segment_attention_bias: Optional trainable bias parameter added to the query
query had when calculating the segment-based attention score used in had when calculating the segment-based attention score used in XLNet of
XLNet of shape `[num_heads, dim]`. shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
state or memory. state or memory. If passed, this is also attended over as in Transformer
If passed, this is also attended over as in Transformer XL. XL.
attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
to certain positions. to certain positions.
""" """
...@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention): ...@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
with tf.init_scope(): with tf.init_scope():
einsum_equation, _, output_rank = _build_proj_equation( einsum_equation, _, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2) key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = tf.keras.layers.experimental.EinsumDense( self._encoding_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]), [self._num_heads, self._key_dim]),
...@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention): ...@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
Args: Args:
query: attention input. query: attention input.
value: attention input. value: attention input.
content_attention_bias: A trainable bias parameter added to the query content_attention_bias: A trainable bias parameter added to the query head
head when calculating the content-based attention score. when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score. head when calculating the position-based attention score.
key: attention input. key: attention input.
...@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention): ...@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
value. value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet. XLNet.
segment_encoding: Optional `Tensor` representing the segmentation segment_encoding: Optional `Tensor` representing the segmentation encoding
encoding as used in XLNet. as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in query had when calculating the segment-based attention score used in
XLNet. XLNet.
...@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention): ...@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_stream: The content representation, commonly referred to as h. content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in This serves a similar role to the standard hidden states in
Transformer-XL. Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query content_attention_bias: A trainable bias parameter added to the query head
head when calculating the content-based attention score. when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score. head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g. query_stream: The query representation, commonly referred to as g. This
This only has access to contextual information and position, but not only has access to contextual information and position, but not content.
content. If not provided, then this is MultiHeadRelativeAttention with If not provided, then this is MultiHeadRelativeAttention with
self-attention. self-attention.
relative_position_encoding: relative positional encoding for key and relative_position_encoding: relative positional encoding for key and
value. value.
target_mapping: Optional `Tensor` representing the target mapping used target_mapping: Optional `Tensor` representing the target mapping used in
in partial prediction. partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet. XLNet.
segment_encoding: Optional `Tensor` representing the segmentation segment_encoding: Optional `Tensor` representing the segmentation encoding
encoding as used in XLNet. as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score. query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended state: (default None) optional state. If passed, this is also attended
...@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention): ...@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_attention_mask: (default None) Optional mask that is added to content_attention_mask: (default None) Optional mask that is added to
content attention logits. If state is not None, the mask source sequence content attention logits. If state is not None, the mask source sequence
dimension should extend M. dimension should extend M.
query_attention_mask: (default None) Optional mask that is added to query_attention_mask: (default None) Optional mask that is added to query
query attention logits. If state is not None, the mask source sequence attention logits. If state is not None, the mask source sequence
dimension should extend M. dimension should extend M.
Returns: Returns:
...@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention): ...@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
query_attention_output = self._output_dense(query_attention_output) query_attention_output = self._output_dense(query_attention_output)
return content_attention_output, query_attention_output return content_attention_output, query_attention_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,6 +22,8 @@ import string ...@@ -22,6 +22,8 @@ import string
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase _CHR_IDX = string.ascii_lowercase
...@@ -221,7 +223,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer): ...@@ -221,7 +223,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
**kwargs): **kwargs):
super(ReuseMultiHeadAttention, self).__init__(**kwargs) super().__init__(**kwargs)
self._num_heads = num_heads self._num_heads = num_heads
self._key_dim = key_dim self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim self._value_dim = value_dim if value_dim else key_dim
...@@ -299,7 +301,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer): ...@@ -299,7 +301,7 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
"key_shape": self._key_shape, "key_shape": self._key_shape,
"value_shape": self._value_shape, "value_shape": self._value_shape,
} }
base_config = super(ReuseMultiHeadAttention, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@classmethod @classmethod
...@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer): ...@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._key_shape = tf.TensorShape(key) self._key_shape = tf.TensorShape(key)
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -362,42 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer): ...@@ -362,42 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
if self._reuse_heads < self._num_heads: if self._reuse_heads < self._num_heads:
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2) free_dims, bound_dims=1, output_dims=2)
self._query_dense = tf.keras.layers.experimental.EinsumDense( self._query_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [ output_shape=_get_output_shape(
self._num_heads - self._reuse_heads, self._key_dim]), output_rank - 1,
[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="query", name="query",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2) self._key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = tf.keras.layers.experimental.EinsumDense( self._key_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, [ output_shape=_get_output_shape(
self._num_heads - self._reuse_heads, self._key_dim]), output_rank - 1,
[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="key", name="key",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2) self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = [] self._value_dense = []
if self._reuse_heads > 0: if self._reuse_heads > 0:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense( self._value_dense.append(
einsum_equation, tf.keras.layers.EinsumDense(
output_shape=_get_output_shape( einsum_equation,
output_rank - 1, [self._reuse_heads, self._value_dim]), output_shape=_get_output_shape(
bias_axes=bias_axes if self._use_bias else None, output_rank - 1, [self._reuse_heads, self._value_dim]),
name="value_reuse", bias_axes=bias_axes if self._use_bias else None,
**common_kwargs)) name="value_reuse",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
if self._reuse_heads < self._num_heads: if self._reuse_heads < self._num_heads:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense( self._value_dense.append(
einsum_equation, tf.keras.layers.EinsumDense(
output_shape=_get_output_shape(output_rank - 1, [ einsum_equation,
self._num_heads - self._reuse_heads, self._value_dim]), output_shape=_get_output_shape(
bias_axes=bias_axes if self._use_bias else None, output_rank - 1,
name="value_new", [self._num_heads - self._reuse_heads, self._value_dim]),
**common_kwargs)) bias_axes=bias_axes if self._use_bias else None,
name="value_new",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
# Builds the attention computations for multi-head dot product attention. # Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once # These computations could be wrapped into the keras attention layer once
...@@ -434,18 +453,20 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer): ...@@ -434,18 +453,20 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape = [self._query_shape[-1]] output_shape = [self._query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape)) free_dims, bound_dims=2, output_dims=len(output_shape))
return tf.keras.layers.experimental.EinsumDense( return tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape), output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None, bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name, name=name,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
def _build_attention(self, rank): def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product customize attention computation to replace the default dot-product
attention. attention.
Args: Args:
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
"""Keras-based TransformerEncoder block layer.""" """Keras-based TransformerEncoder block layer."""
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import reuse_attention as attention from official.nlp.modeling.layers import reuse_attention as attention
...@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer): ...@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get( self._attention_initializer = tf.keras.initializers.get(
attention_initializer) attention_initializer)
else: else:
self._attention_initializer = self._kernel_initializer self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes self._attention_axes = attention_axes
def build(self, input_shape): def build(self, input_shape):
...@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer): ...@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
else: else:
self._attention_head_size = self._head_size self._attention_head_size = self._head_size
common_kwargs = dict( common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer): ...@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
dropout=self._attention_dropout, dropout=self._attention_dropout,
use_bias=self._use_bias, use_bias=self._use_bias,
kernel_initializer=self._attention_initializer, kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes, attention_axes=self._attention_axes,
reuse_attention=self._reuse_attention, reuse_attention=self._reuse_attention,
use_relative_pe=self._use_relative_pe, use_relative_pe=self._use_relative_pe,
...@@ -184,11 +187,12 @@ class ReuseTransformer(tf.keras.layers.Layer): ...@@ -184,11 +187,12 @@ class ReuseTransformer(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=(None, self._inner_dim), output_shape=(None, self._inner_dim),
bias_axes="d", bias_axes="d",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate", name="intermediate",
**common_kwargs) **common_kwargs)
policy = tf.keras.mixed_precision.global_policy() policy = tf.keras.mixed_precision.global_policy()
...@@ -201,12 +205,13 @@ class ReuseTransformer(tf.keras.layers.Layer): ...@@ -201,12 +205,13 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._inner_activation, dtype=policy) self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout( self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout) rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation, einsum_equation,
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
name="output", name="output",
kernel_initializer=self._kernel_initializer, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout) self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -68,7 +68,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -68,7 +68,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself # Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors. # (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
_ = model.predict(input_data) _ = model.predict(input_data)
...@@ -89,7 +89,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -89,7 +89,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself # Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors. # (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len), # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length) # which here is (batch, sequence_length, sequence_length)
...@@ -104,7 +104,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -104,7 +104,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80 width = 80
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length)) 2, size=(batch_size, sequence_length, sequence_length))
...@@ -121,7 +121,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -121,7 +121,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data]) new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.25) new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_with_relative_pe(self, transformer_cls): def test_layer_output_range_with_relative_pe(self, transformer_cls):
test_layer = transformer_cls( test_layer = transformer_cls(
...@@ -131,7 +131,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -131,7 +131,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80 width = 80
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length)) 2, size=(batch_size, sequence_length, sequence_length))
...@@ -149,7 +149,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -149,7 +149,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data]) new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_without_mask(self, transformer_cls): def test_layer_output_range_without_mask(self, transformer_cls):
test_layer = transformer_cls( test_layer = transformer_cls(
...@@ -159,7 +159,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -159,7 +159,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80 width = 80
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
output_tensor, _ = test_layer(input_data) output_tensor, _ = test_layer(input_data)
...@@ -175,7 +175,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -175,7 +175,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer(input_data) new_output_tensor, _ = new_layer(input_data)
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_output_range_with_pre_norm(self, transformer_cls): def test_layer_output_range_with_pre_norm(self, transformer_cls):
test_layer = transformer_cls( test_layer = transformer_cls(
...@@ -185,7 +185,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -185,7 +185,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
width = 80 width = 80
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, sequence_length, sequence_length)) 2, size=(batch_size, sequence_length, sequence_length))
...@@ -203,7 +203,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -203,7 +203,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor, _ = new_layer([input_data, mask_data]) new_output_tensor, _ = new_layer([input_data, mask_data])
self.assertAllClose( self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003) new_output_tensor, output_tensor[:, 0:1, :], atol=0.002, rtol=0.01)
def test_layer_invocation_with_float16_dtype(self, transformer_cls): def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.set_global_policy('mixed_float16') tf.keras.mixed_precision.set_global_policy('mixed_float16')
...@@ -223,7 +223,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -223,7 +223,7 @@ class ReuseTransformerLayerTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself # Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors. # (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6 batch_size = 6
input_data = (10 * np.random.random_sample( input_data = (np.random.random_sample(
(batch_size, sequence_length, width))) (batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len), # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length) # which here is (batch, sequence_length, sequence_length)
...@@ -368,7 +368,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -368,7 +368,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself # Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors. # (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6 batch_size = 6
input_data = 10 * np.random.random_sample( input_data = np.random.random_sample(
(batch_size, sequence_length, width)) (batch_size, sequence_length, width))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len), # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length) # which here is (batch, sequence_length, sequence_length)
...@@ -404,7 +404,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -404,7 +404,7 @@ class ReuseTransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
# Invoke the model on test data. We can't validate the output data itself # Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors. # (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 6 batch_size = 6
input_data = (10 * np.random.random_sample( input_data = (np.random.random_sample(
(batch_size, sequence_length, width))) (batch_size, sequence_length, width)))
# The attention mask should be of shape (batch, from_seq_len, to_seq_len), # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
# which here is (batch, sequence_length, sequence_length) # which here is (batch, sequence_length, sequence_length)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,10 +14,13 @@ ...@@ -14,10 +14,13 @@
"""Keras-based rezero-transformer block layer (Transformer with ReZero).""" """Keras-based rezero-transformer block layer (Transformer with ReZero)."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from typing import Optional
from absl import logging
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util from official.nlp.modeling.layers import util
...@@ -33,8 +36,10 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -33,8 +36,10 @@ class ReZeroTransformer(tf.keras.layers.Layer):
Args: Args:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer. inner_dim: The output dimension of the first Dense layer in a two-layer
intermediate_activation: Activation for the intermediate layer. feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
dropout_rate: Dropout probability for the post-attention and output dropout. dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer. attention_dropout_rate: Dropout probability for within the attention layer.
output_range: the sequence output range, [0, output_range) by slicing the output_range: the sequence output range, [0, output_range) by slicing the
...@@ -52,8 +57,8 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -52,8 +57,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
def __init__(self, def __init__(self,
num_attention_heads, num_attention_heads,
intermediate_size, inner_dim=768,
intermediate_activation, inner_activation=tf_utils.get_activation("gelu"),
dropout_rate=0.0, dropout_rate=0.0,
attention_dropout_rate=0.0, attention_dropout_rate=0.0,
output_range=None, output_range=None,
...@@ -72,12 +77,19 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -72,12 +77,19 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_dropout_rate = kwargs.pop("attention_dropout", attention_dropout_rate = kwargs.pop("attention_dropout",
attention_dropout_rate) attention_dropout_rate)
dropout_rate = kwargs.pop("output_dropout", dropout_rate) dropout_rate = kwargs.pop("output_dropout", dropout_rate)
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs) util.filter_kwargs(kwargs)
super(ReZeroTransformer, self).__init__(**kwargs) super().__init__(**kwargs)
# Deprecation warning.
if output_range is not None:
logging.warning("`output_range` is avaliable as an argument for `call()`."
"The `output_range` as __init__ argument is deprecated.")
self._num_heads = num_attention_heads self._num_heads = num_attention_heads
self._intermediate_size = intermediate_size self._inner_dim = inner_dim
self._intermediate_activation = intermediate_activation self._inner_activation = inner_activation
self._attention_dropout_rate = attention_dropout_rate self._attention_dropout_rate = attention_dropout_rate
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._output_range = output_range self._output_range = output_range
...@@ -121,8 +133,6 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -121,8 +133,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
...@@ -133,6 +143,8 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -133,6 +143,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
key_dim=self._attention_head_size, key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate, dropout=self._attention_dropout_rate,
name="self_attention", name="self_attention",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
...@@ -144,11 +156,13 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -144,11 +156,13 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=1e-12, epsilon=1e-12,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, self._intermediate_size), output_shape=(None, self._inner_dim),
bias_axes="d", bias_axes="d",
name="intermediate", name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
policy = tf.keras.mixed_precision.global_policy() policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
...@@ -156,13 +170,15 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -156,13 +170,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
# as well, so we use float32. # as well, so we use float32.
# TODO(b/154538392): Investigate this. # TODO(b/154538392): Investigate this.
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._inner_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._inner_activation, dtype=policy)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
output_shape=(None, hidden_size), output_shape=(None, hidden_size),
bias_axes="d", bias_axes="d",
name="output", name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs) **common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
...@@ -185,16 +201,16 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -185,16 +201,16 @@ class ReZeroTransformer(tf.keras.layers.Layer):
trainable=True, trainable=True,
dtype=tf.float32) dtype=tf.float32)
super(ReZeroTransformer, self).build(input_shape) super().build(input_shape)
def get_config(self): def get_config(self):
config = { config = {
"num_attention_heads": "num_attention_heads":
self._num_heads, self._num_heads,
"intermediate_size": "inner_dim":
self._intermediate_size, self._inner_dim,
"intermediate_activation": "inner_activation":
self._intermediate_activation, self._inner_activation,
"dropout_rate": "dropout_rate":
self._dropout_rate, self._dropout_rate,
"attention_dropout_rate": "attention_dropout_rate":
...@@ -220,7 +236,7 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -220,7 +236,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"bias_constraint": "bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint), tf.keras.constraints.serialize(self._bias_constraint),
} }
base_config = super(ReZeroTransformer, self).get_config() base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def reset_rezero(self): def reset_rezero(self):
...@@ -228,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -228,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
if not self._share_rezero: if not self._share_rezero:
self._rezero_a_ffn.assign(0.) self._rezero_a_ffn.assign(0.)
def call(self, inputs): def call(self, inputs, output_range: Optional[tf.Tensor] = None) -> tf.Tensor:
if isinstance(inputs, (list, tuple)): if isinstance(inputs, (list, tuple)):
if len(inputs) == 2: if len(inputs) == 2:
input_tensor, attention_mask = inputs input_tensor, attention_mask = inputs
...@@ -241,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -241,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
else: else:
input_tensor, key_value, attention_mask = (inputs, None, None) input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range: if output_range is None:
target_tensor = input_tensor[:, 0:self._output_range, :] output_range = self._output_range
if output_range:
target_tensor = input_tensor[:, 0:output_range, :]
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
...@@ -261,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -261,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_output = tf.cast(attention_output, tf.float32) attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output) intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer( intermediate_output = self._inner_activation_layer(intermediate_output)
intermediate_output)
layer_output = self._output_dense(intermediate_output) layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output) layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and # During mixed precision training, attention_output is from layer norm and
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase): ...@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :]) self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
output_tensor = test_layer([input_data, mask_data], output_range=1)
self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
def test_separate_qkv(self): def test_separate_qkv(self):
test_layer = rezero_transformer.ReZeroTransformer( test_layer = rezero_transformer.ReZeroTransformer(
num_attention_heads=2, num_attention_heads=2,
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for Mixture of Experts (MoE) routing.
For MoE routing, we need to separate a set of tokens to sets of tokens.
Later on, different sets of tokens can potentially go to different experts.
"""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf.keras.layers.Layer):
  """Routing based on per-token importance value.

  Keeps a non-trainable table with one scalar importance value per vocabulary
  id and exposes an exponential-moving-average (EMA) update for it. `call`
  is a plain embedding lookup into that table.
  """

  def __init__(self,
               vocab_size,
               init_importance,
               moving_average_beta=0.995,
               **kwargs):
    """Initializer.

    Args:
      vocab_size: Number of token ids tracked by the importance table.
      init_importance: Initial importance value assigned to every token.
      moving_average_beta: EMA decay factor; closer to 1.0 updates slower.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """
    self._vocab_size = vocab_size
    self._init_importance = init_importance
    self._moving_average_beta = moving_average_beta
    super().__init__(**kwargs)

  def build(self, input_shape):
    # Non-trainable: the table is updated manually via
    # `update_token_importance`, never by gradient descent.
    self._importance_embedding = self.add_weight(
        name="importance_embed",
        # Bug fix: `(self._vocab_size)` was a bare int (parens without a
        # comma are not a tuple); make the rank-1 shape explicit.
        shape=(self._vocab_size,),
        initializer=tf.keras.initializers.Constant(self._init_importance),
        trainable=False)
    super().build(input_shape)

  def get_config(self):
    """Returns the layer config so the layer can be re-created."""
    config = {
        "vocab_size":
            self._vocab_size,
        "init_importance":
            self._init_importance,
        "moving_average_beta":
            self._moving_average_beta,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def update_token_importance(self, token_ids, importance):
    """EMA-updates stored importance of `token_ids` toward `importance`.

    Args:
      token_ids: int Tensor/array of any shape; flattened before the update.
      importance: Tensor/array with the same number of elements as
        `token_ids`, holding the new observations.
    """
    token_ids = tf.reshape(token_ids, shape=[-1])
    importance = tf.reshape(importance, shape=[-1])
    beta = self._moving_average_beta
    old_importance = tf.gather(self._importance_embedding, token_ids)
    # new = old * beta + importance * (1 - beta); the cast keeps the float32
    # storage dtype even when `importance` arrives as float64 numpy data.
    # NOTE(review): duplicate ids within a single call resolve via
    # `tensor_scatter_nd_update`, whose order for duplicates is not
    # guaranteed — confirm callers never pass duplicates in one batch.
    self._importance_embedding.assign(tf.tensor_scatter_nd_update(
        self._importance_embedding,
        tf.expand_dims(token_ids, axis=1),
        old_importance * beta + tf.cast(importance * (1.0 - beta),
                                        dtype=tf.float32)))

  def call(self, inputs):
    """Looks up the current importance for each token id in `inputs`."""
    return tf.gather(self._importance_embedding, inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf.keras.layers.Layer):
  """Select top-k + random-k tokens according to importance."""

  def __init__(self,
               top_k=None,
               random_k=None,
               **kwargs):
    self._top_k = top_k
    self._random_k = random_k
    super().__init__(**kwargs)

  def get_config(self):
    """Returns the layer config for serialization."""
    own_config = {
        "top_k":
            self._top_k,
        "random_k":
            self._random_k,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(own_config.items()))

  def call(self, inputs):
    """Splits token positions into (selected, not_selected) index tensors."""
    if self._random_k is None:
      # Deterministic: keep the top_k highest-scoring positions per row.
      order = tf.argsort(inputs, direction="DESCENDING")
      chosen = order[:, :self._top_k]
      leftover = order[:, self._top_k:]
    elif self._top_k is None:
      # Pure chance: rank positions by uniform noise instead of scores.
      order = tf.argsort(
          tf.random.uniform(shape=tf.shape(inputs)), direction="DESCENDING")
      chosen = order[:, :self._random_k]
      leftover = order[:, self._random_k:]
    else:
      # Hybrid: top_k by score, then random_k drawn from the remainder.
      order = tf.argsort(inputs, direction="DESCENDING")
      best = order[:, :self._top_k]
      remainder = order[:, self._top_k:]
      # Shuffle the remainder per row by sorting uniform noise.
      shuffle_idx = tf.argsort(
          tf.random.uniform(shape=tf.shape(remainder)),
          direction="DESCENDING")
      remainder = tf.gather(remainder, shuffle_idx, batch_dims=1, axis=1)
      chosen = tf.concat([best, remainder[:, :self._random_k]], axis=1)
      leftover = remainder[:, self._random_k:]

    # Return the indices of selected and not-selected tokens.
    return chosen, leftover
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for routing."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import routing
class TokenImportanceTest(tf.test.TestCase, parameterized.TestCase):

  def test_token_importance(self):
    """Checks the EMA update: 10 * 0.995 + 0 * 0.005 == 9.95."""
    layer = routing.TokenImportanceWithMovingAvg(
        vocab_size=4,
        init_importance=10.0,
        moving_average_beta=0.995)
    token_ids = np.array([[0, 1], [2, 3]])
    # Before any update every token carries the initial importance.
    self.assertAllClose(
        layer(token_ids), np.array([[10.0, 10.0], [10.0, 10.0]]))
    layer.update_token_importance(
        token_ids=np.array([[0, 1]]),
        importance=np.array([[0.0, 0.0]]))
    # Tokens 0 and 1 decay toward 0.0; tokens 2 and 3 stay untouched.
    self.assertAllClose(
        layer(token_ids), np.array([[9.95, 9.95], [10.0, 10.0]]))
class TopKSelectionTest(tf.test.TestCase, parameterized.TestCase):

  def test_top_k_selection(self):
    layer = routing.SelectTopK(top_k=2)
    picked, _ = layer(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    # Indices of the two largest scores per row, best first.
    self.assertAllClose(picked, np.array([[3, 2], [0, 1]]))

  def test_random_k_selection(self):
    layer = routing.SelectTopK(random_k=2)
    picked, _ = layer(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    # Selection is random, so only the output shape is checkable.
    self.assertAllClose(picked.shape, (2, 2))

  def test_top_k_random_k(self):
    layer = routing.SelectTopK(top_k=1, random_k=1)
    picked, _ = layer(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
    # One deterministic pick plus one random pick per row.
    self.assertAllClose(picked.shape, (2, 2))
# Run all test cases when this file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,10 +13,38 @@ ...@@ -13,10 +13,38 @@
# limitations under the License. # limitations under the License.
"""Keras layer that creates a self-attention mask.""" """Keras layer that creates a self-attention mask."""
from typing import Optional
import tensorflow as tf import tensorflow as tf
def get_mask(inputs: tf.Tensor,
             to_mask: tf.Tensor,
             dtype: Optional[tf.DType] = None) -> tf.Tensor:
  """Builds a 3D self-attention mask from a 2D padding mask.

  Args:
    inputs: from_tensor: 2D or 3D Tensor of shape [batch_size,
      from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].
    dtype: the output Tensor dtype; defaults to `inputs.dtype`.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  if dtype is None:
    dtype = inputs.dtype
  input_shape = tf.shape(inputs)
  batch_size = input_shape[0]
  from_seq_length = input_shape[1]
  to_seq_length = tf.shape(to_mask)[1]
  # Insert a from_seq axis of size 1, then broadcast it out so every
  # "from" position sees the same per-"to" padding mask.
  mask_2d = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=dtype)
  return tf.broadcast_to(
      mask_2d, [batch_size, from_seq_length, to_seq_length])
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer): class SelfAttentionMask(tf.keras.layers.Layer):
"""Create 3D attention mask from a 2D tensor mask. """Create 3D attention mask from a 2D tensor mask.
...@@ -33,26 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer): ...@@ -33,26 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
if isinstance(inputs, list) and to_mask is None: if isinstance(inputs, list) and to_mask is None:
to_mask = inputs[1] to_mask = inputs[1]
inputs = inputs[0] inputs = inputs[0]
from_shape = tf.shape(inputs) return get_mask(inputs, to_mask)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=inputs.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -74,21 +74,20 @@ class SpectralNormalization(tf.keras.layers.Wrapper): ...@@ -74,21 +74,20 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
if not isinstance(layer, tf.keras.layers.Layer): if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError('`layer` must be a `tf.keras.layer.Layer`. ' raise ValueError('`layer` must be a `tf.keras.layer.Layer`. '
'Observed `{}`'.format(layer)) 'Observed `{}`'.format(layer))
super(SpectralNormalization, self).__init__( super().__init__(
layer, name=wrapper_name, **kwargs) layer, name=wrapper_name, **kwargs)
def build(self, input_shape): def build(self, input_shape):
super(SpectralNormalization, self).build(input_shape) super().build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype self._dtype = self.layer.kernel.dtype
self.w = self.layer.kernel self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list() self.w_shape = self.w.shape.as_list()
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight( self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])), shape=(1, np.prod(self.w_shape[:-1])),
initializer=self.uv_initializer, initializer=tf.initializers.random_normal(),
trainable=False, trainable=False,
name='v', name='v',
dtype=self.dtype, dtype=self.dtype,
...@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper): ...@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.u = self.add_weight( self.u = self.add_weight(
shape=(1, self.w_shape[-1]), shape=(1, self.w_shape[-1]),
initializer=self.uv_initializer, initializer=tf.initializers.random_normal(),
trainable=False, trainable=False,
name='u', name='u',
dtype=self.dtype, dtype=self.dtype,
...@@ -194,10 +193,11 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper): ...@@ -194,10 +193,11 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
raise ValueError( raise ValueError(
'layer must be a `tf.keras.layer.Conv2D` instance. You passed: {input}' 'layer must be a `tf.keras.layer.Conv2D` instance. You passed: {input}'
.format(input=layer)) .format(input=layer))
super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs) super().__init__(layer, **kwargs)
def build(self, input_shape): def build(self, input_shape):
self.layer.build(input_shape) if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype self._dtype = self.layer.kernel.dtype
...@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper): ...@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.in_shape = (uv_dim, in_height, in_width, in_channel) self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel) self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight( self.v = self.add_weight(
shape=self.in_shape, shape=self.in_shape,
initializer=self.uv_initializer, initializer=tf.initializers.random_normal(),
trainable=False, trainable=False,
name='v', name='v',
dtype=self.dtype, dtype=self.dtype,
...@@ -233,13 +232,13 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper): ...@@ -233,13 +232,13 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.u = self.add_weight( self.u = self.add_weight(
shape=self.out_shape, shape=self.out_shape,
initializer=self.uv_initializer, initializer=tf.initializers.random_normal(),
trainable=False, trainable=False,
name='u', name='u',
dtype=self.dtype, dtype=self.dtype,
aggregation=self.aggregation) aggregation=self.aggregation)
super(SpectralNormalizationConv2D, self).build() super().build()
def call(self, inputs): def call(self, inputs):
u_update_op, v_update_op, w_update_op = self.update_weights() u_update_op, v_update_op, w_update_op = self.update_weights()
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase): ...@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
spectral_norm_computed = _compute_spectral_norm(normalized_kernel) spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
spectral_norm_expected = self.norm_multiplier spectral_norm_expected = self.norm_multiplier
self.assertAllClose( self.assertAllClose(
spectral_norm_computed, spectral_norm_expected, atol=5e-2) spectral_norm_computed, spectral_norm_expected, atol=1e-1)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer # Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K # is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,6 +20,8 @@ import string ...@@ -20,6 +20,8 @@ import string
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase _CHR_IDX = string.ascii_lowercase
...@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention): ...@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._pre_softmax_weight = self.add_weight( self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight", "pre_softmax_weight",
shape=(self._num_heads, self._num_heads), shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer, initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer, regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint, constraint=self._kernel_constraint,
dtype=self.dtype, dtype=self.dtype,
...@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention): ...@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._post_softmax_weight = self.add_weight( self._post_softmax_weight = self.add_weight(
"post_softmax_weight", "post_softmax_weight",
shape=(self._num_heads, self._num_heads), shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer, initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer, regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint, constraint=self._kernel_constraint,
dtype=self.dtype, dtype=self.dtype,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,14 +14,16 @@ ...@@ -14,14 +14,16 @@
"""Keras Layers for BERT-specific preprocessing.""" """Keras Layers for BERT-specific preprocessing."""
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Mapping, Optional, Text, Union
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
try: try:
# pytype: disable=import-error
import tensorflow_text as text import tensorflow_text as text
from tensorflow_text.python.ops import bert_tokenizer from tensorflow_text.python.ops import bert_tokenizer
# pytype: enable=import-error
except ImportError: except ImportError:
text = None text = None
bert_tokenizer = None bert_tokenizer = None
...@@ -57,7 +59,7 @@ def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor, ...@@ -57,7 +59,7 @@ def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
class BertTokenizer(tf.keras.layers.Layer): class BertTokenizer(tf.keras.layers.Layer):
"""Wraps BertTokenizer with pre-defined vocab as a Keras Layer. """Wraps TF.Text's BertTokenizer with pre-defined vocab as a Keras Layer.
Attributes: Attributes:
tokenize_with_offsets: If true, calls tokenize_with_offsets: If true, calls
...@@ -71,8 +73,9 @@ class BertTokenizer(tf.keras.layers.Layer): ...@@ -71,8 +73,9 @@ class BertTokenizer(tf.keras.layers.Layer):
def __init__(self, *, def __init__(self, *,
vocab_file: str, vocab_file: str,
lower_case: bool, lower_case: Optional[bool] = None,
tokenize_with_offsets: bool = False, tokenize_with_offsets: bool = False,
tokenizer_kwargs: Optional[Mapping[Text, Any]] = None,
**kwargs): **kwargs):
"""Initialize a `BertTokenizer` layer. """Initialize a `BertTokenizer` layer.
...@@ -81,15 +84,18 @@ class BertTokenizer(tf.keras.layers.Layer): ...@@ -81,15 +84,18 @@ class BertTokenizer(tf.keras.layers.Layer):
This is a text file with newline-separated wordpiece tokens. This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with This layer initializes a lookup table from it that gets used with
`text.BertTokenizer`. `text.BertTokenizer`.
lower_case: A Python boolean forwarded to `text.BertTokenizer`. lower_case: Optional boolean forwarded to `text.BertTokenizer`.
If true, input text is converted to lower case (where applicable) If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which before tokenization. This must be set to match the way in which
the `vocab_file` was created. the `vocab_file` was created. If passed, this overrides whatever value
may have been passed in `tokenizer_kwargs`.
tokenize_with_offsets: A Python boolean. If true, this layer calls tokenize_with_offsets: A Python boolean. If true, this layer calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain `text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of `text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)` `(tokens, start_offsets, limit_offsets)`
insead of just tokens. insead of just tokens.
tokenizer_kwargs: Optional mapping with keyword arguments to forward to
`text.BertTokenizer`'s constructor.
**kwargs: Standard arguments to `Layer()`. **kwargs: Standard arguments to `Layer()`.
Raises: Raises:
...@@ -111,8 +117,11 @@ class BertTokenizer(tf.keras.layers.Layer): ...@@ -111,8 +117,11 @@ class BertTokenizer(tf.keras.layers.Layer):
self._special_tokens_dict = self._create_special_tokens_dict( self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file) self._vocab_table, vocab_file)
super().__init__(**kwargs) super().__init__(**kwargs)
self._bert_tokenizer = text.BertTokenizer( tokenizer_kwargs = dict(tokenizer_kwargs or {})
self._vocab_table, lower_case=lower_case) if lower_case is not None:
tokenizer_kwargs["lower_case"] = lower_case
self._bert_tokenizer = text.BertTokenizer(self._vocab_table,
**tokenizer_kwargs)
@property @property
def vocab_size(self): def vocab_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment