Commit 790e49e5 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into run_superglue

parents 8ab018b0 5bb827c3
@@ -25,10 +25,10 @@ def _large_compatible_negative(tensor_type):
  in this module (-1e9) cannot be represented using `tf.float16`.

  Args:
    tensor_type: A dtype to determine the type.

  Returns:
    A large negative number.
  """
  if tensor_type == tf.float16:
    return tf.float16.min
...
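The clamp above exists because -1e9 overflows float16, whose most negative finite value is about -65504. A quick check of that fact (my own illustration, not part of this diff):

import tensorflow as tf

print(tf.cast(tf.constant(-1e9), tf.float16))  # -inf: -1e9 is not representable
print(tf.float16.min)                          # -65504.0, used instead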
@@ -44,7 +44,7 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
  Args:
    normalization_type: String. The type of normalization_type, only
      `no_norm` and `layer_norm` are supported.
    name: Name for the norm layer.

  Returns:
@@ -89,7 +89,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
    output_embed_size: Embedding size for the final embedding output.
    max_sequence_length: Maximum length of input sequence.
    normalization_type: String. The type of normalization_type, only
      `no_norm` and `layer_norm` are supported.
    initializer: The initializer to use for the embedding weights and
      linear projection weights.
    dropout_rate: Dropout rate.
@@ -208,10 +208,10 @@ class MobileBertTransformer(tf.keras.layers.Layer):
    key_query_shared_bottleneck: Whether to share linear transformation for
      keys and queries.
    num_feedforward_networks: Number of stacked feed-forward networks.
    normalization_type: The type of normalization_type, only `no_norm` and
      `layer_norm` are supported. `no_norm` represents the element-wise
      linear transformation for the student model, as suggested by the
      original MobileBERT paper. `layer_norm` is used for the teacher model.
    initializer: The initializer to use for the embedding weights and
      linear projection weights.
    **kwargs: keyword arguments.
@@ -346,14 +346,16 @@ class MobileBertTransformer(tf.keras.layers.Layer):
  """Implements the forward pass.

  Args:
    input_tensor: Float tensor of shape
      `(batch_size, seq_length, hidden_size)`.
    attention_mask: (optional) int32 tensor of shape
      `(batch_size, seq_length, seq_length)`, with 1 for positions that can
      be attended to and 0 in positions that should not be.
    return_attention_scores: Whether to return the attention scores.

  Returns:
    layer_output: Float tensor of shape
      `(batch_size, seq_length, hidden_size)`.
    attention_scores (Optional): Only when return_attention_scores is True.

  Raises:
@@ -450,8 +452,8 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
    activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this layer. Can be either `logits` or
      `predictions`.
    **kwargs: keyword arguments.
  """
  super(MobileBertMaskedLM, self).__init__(**kwargs)
@@ -527,16 +529,16 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
  Args:
    sequence_tensor: Sequence output of `BertModel` layer of shape
      `(batch_size, seq_length, num_hidden)` where `num_hidden` is number of
      hidden units of `BertModel` layer.
    positions: Position ids of tokens in the sequence to mask for pretraining,
      with dimension `(batch_size, num_predictions)` where
      `num_predictions` is the maximum number of tokens to mask out and
      predict per sequence.

  Returns:
    Masked out sequence tensor of shape
      `(batch_size * num_predictions, num_hidden)`.
  """
  sequence_shape = tf.shape(sequence_tensor)
  batch_size, seq_length = sequence_shape[0], sequence_shape[1]
...
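The gather documented above works by flattening the batch so the masked positions can be selected with a single `tf.gather`. A minimal sketch of that indexing trick (my own illustration of the idea, not the exact code in this commit):

import tensorflow as tf

def gather_indexes_sketch(sequence_tensor, positions):
  """Gathers the vectors at `positions` over a minibatch (illustrative only)."""
  sequence_shape = tf.shape(sequence_tensor)
  batch_size, seq_length = sequence_shape[0], sequence_shape[1]
  # Offset of each example's first token in the flattened (B * S, H) tensor.
  flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence = tf.reshape(sequence_tensor, [-1, sequence_shape[2]])
  # Result shape: (batch_size * num_predictions, num_hidden).
  return tf.gather(flat_sequence, flat_positions)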
@@ -26,8 +26,8 @@ class VotingAttention(tf.keras.layers.Layer):
  """Voting Attention layer.

  Args:
    num_heads: The number of attention heads.
    head_size: Per-head hidden size.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
@@ -115,7 +115,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
      context tensors according to the distribution among channels.
    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
      `value` for both `key` and `value`, which is the most common case.
    attention_mask: A boolean mask of shape `[B, T, S]`, that prevents attention
      to certain positions.
  """
...
@@ -77,7 +77,7 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
      dimension of `inputs`.

  Returns:
    A tensor in shape of `(length, hidden_size)`.
  """
  if inputs is None and length is None:
    raise ValueError("If inputs is None, `length` must be set in "
@@ -114,7 +114,7 @@ def _relative_position_bucket(relative_position,
  the distance in tokens from the attending position to the attended-to
  position.

  If `bidirectional=False`, then positive relative positions are invalid.
  We use smaller buckets for small absolute relative_position and larger
  buckets for larger absolute relative_positions.
@@ -127,13 +127,13 @@ def _relative_position_bucket(relative_position,
  than the model has been trained on.

  Args:
    relative_position: An int32 Tensor
    bidirectional: A boolean - whether the attention is bidirectional
    num_buckets: An integer
    max_distance: An integer

  Returns:
    A Tensor with the same shape as relative_position, containing int32
    values in the range [0, num_buckets)
  """
  ret = 0
...
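For readers unfamiliar with the bucketing described above: with the defaults, roughly half of the buckets index small offsets exactly and the remaining buckets grow logarithmically up to `max_distance`. A simplified, unidirectional sketch of that mapping for a single non-negative distance (my own illustration; it ignores the bidirectional sign handling in the real function):

import math

def relative_position_bucket_sketch(relative_position, num_buckets=32,
                                    max_distance=128):
  """Illustrative bucket id for one non-negative relative distance."""
  max_exact = num_buckets // 2
  n = abs(relative_position)
  if n < max_exact:
    return n  # Small distances each keep their own bucket.
  # Larger distances share logarithmically spaced buckets up to max_distance.
  log_ratio = math.log(n / max_exact) / math.log(max_distance / max_exact)
  return min(num_buckets - 1,
             max_exact + int(log_ratio * (num_buckets - max_exact)))

# Offsets below 16 map to themselves; larger offsets are compressed.
print(relative_position_bucket_sketch(3), relative_position_bucket_sketch(100))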
@@ -103,10 +103,10 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
    segment_attention_bias: Optional trainable bias parameter added to the
      query head when calculating the segment-based attention score used in
      XLNet of shape `[num_heads, dim]`.
    state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
      state or memory. If passed, this is also attended over as in
      Transformer XL.
    attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
      to certain positions.
  """
...
@@ -21,15 +21,15 @@ from official.nlp.keras_nlp import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(layers.SelfAttentionMask):
  """Creates 3D attention mask from a 2D tensor mask.

  **Warning: Please use the `keras_nlp.layers.SelfAttentionMask`.**

    inputs[0]: from_tensor: 2D or 3D Tensor of shape
      `(batch_size, from_seq_length, ...)`.
    inputs[1]: to_mask: int32 Tensor of shape `(batch_size, to_seq_length)`.

  Returns:
    Float Tensor of shape `(batch_size, from_seq_length, to_seq_length)`.
  """

  def call(self, inputs):
...
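The 3D mask described above is essentially the 2D padding mask broadcast over every query position. A minimal sketch of that broadcast (illustrative names, not the `keras_nlp` implementation):

import tensorflow as tf

def self_attention_mask_sketch(from_tensor, to_mask):
  """Builds a (batch, from_seq, to_seq) mask from a (batch, to_seq) mask."""
  from_seq_length = tf.shape(from_tensor)[1]
  to_mask = tf.cast(to_mask, tf.float32)           # (batch, to_seq)
  to_mask = to_mask[:, tf.newaxis, :]              # (batch, 1, to_seq)
  # Every query position sees the same key/value padding mask.
  return tf.tile(to_mask, tf.stack([1, from_seq_length, 1]))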
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers.
## References:
[1] Yuichi Yoshida, Takeru Miyato. Spectral Norm Regularization for Improving
the Generalizability of Deep Learning.
_arXiv preprint arXiv:1705.10941_, 2017. https://arxiv.org/abs/1705.10941
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida.
Spectral normalization for generative adversarial networks.
In _International Conference on Learning Representations_, 2018.
[3] Henry Gouk, Eibe Frank, Bernhard Pfahringer, Michael Cree.
Regularisation of neural networks by enforcing lipschitz continuity.
_arXiv preprint arXiv:1804.04368_, 2018. https://arxiv.org/abs/1804.04368
"""
import numpy as np
import tensorflow as tf
class SpectralNormalization(tf.keras.layers.Wrapper):
  """Implements spectral normalization for Dense layer."""

  def __init__(self,
               layer,
               iteration=1,
               norm_multiplier=0.95,
               training=True,
               aggregation=tf.VariableAggregation.MEAN,
               inhere_layer_name=False,
               **kwargs):
    """Initializer.

    Args:
      layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization to.
      iteration: (int) The number of power iteration to perform to estimate
        weight matrix's singular value.
      norm_multiplier: (float) Multiplicative constant to threshold the
        normalization. Usually under normalization, the singular value will
        converge to this value.
      training: (bool) Whether to perform power iteration to update the
        singular value estimate.
      aggregation: (tf.VariableAggregation) Indicates how a distributed
        variable will be aggregated. Accepted values are constants defined in
        the class tf.VariableAggregation.
      inhere_layer_name: (bool) Whether to inhere the name of the input layer.
      **kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
    """
    self.iteration = iteration
    self.do_power_iteration = training
    self.aggregation = aggregation
    self.norm_multiplier = norm_multiplier

    # Set layer name.
    wrapper_name = kwargs.pop('name', None)
    if inhere_layer_name:
      wrapper_name = layer.name

    if not isinstance(layer, tf.keras.layers.Layer):
      raise ValueError('`layer` must be a `tf.keras.layer.Layer`. '
                       'Observed `{}`'.format(layer))
    super(SpectralNormalization, self).__init__(
        layer, name=wrapper_name, **kwargs)

  def build(self, input_shape):
    super(SpectralNormalization, self).build(input_shape)
    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()
    self.uv_initializer = tf.initializers.random_normal()

    self.v = self.add_weight(
        shape=(1, np.prod(self.w_shape[:-1])),
        initializer=self.uv_initializer,
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=(1, self.w_shape[-1]),
        initializer=self.uv_initializer,
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.update_weights()

  def call(self, inputs, *, training=None):
    training = self.do_power_iteration if training is None else training
    u_update_op, v_update_op, w_update_op = self.update_weights(
        training=training)
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()

    # Register update ops.
    self.add_update(u_update_op)
    self.add_update(v_update_op)
    self.add_update(w_update_op)
    self.add_update(w_restore_op)
    return output

  def update_weights(self, *, training=True):
    w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
    u_hat = self.u
    v_hat = self.v

    if training:
      for _ in range(self.iteration):
        v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_reshaped)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))

    sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    # Bound spectral norm to be not larger than self.norm_multiplier.
    w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda:  # pylint:disable=g-long-lambda
                     (self.norm_multiplier / sigma) * self.w, lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)
    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
  """Implements spectral normalization for Conv2D layer based on [3]."""

  def __init__(self,
               layer,
               iteration=1,
               norm_multiplier=0.95,
               training=True,
               aggregation=tf.VariableAggregation.MEAN,
               legacy_mode=False,
               **kwargs):
    """Initializer.

    Args:
      layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization to.
      iteration: (int) The number of power iteration to perform to estimate
        weight matrix's singular value.
      norm_multiplier: (float) Multiplicative constant to threshold the
        normalization. Usually under normalization, the singular value will
        converge to this value.
      training: (bool) Whether to perform power iteration to update the
        singular value estimate.
      aggregation: (tf.VariableAggregation) Indicates how a distributed
        variable will be aggregated. Accepted values are constants defined in
        the class tf.VariableAggregation.
      legacy_mode: (bool) Whether to use the legacy implementation where the
        dimension of the u and v vectors are set to the batch size. It should
        not be enabled unless for backward compatibility reasons.
      **kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
    """
    self.iteration = iteration
    self.do_power_iteration = training
    self.aggregation = aggregation
    self.norm_multiplier = norm_multiplier
    self.legacy_mode = legacy_mode

    # Set layer attributes.
    layer._name += '_spec_norm'

    if not isinstance(layer, tf.keras.layers.Conv2D):
      raise ValueError(
          'layer must be a `tf.keras.layer.Conv2D` instance. You passed: {input}'
          .format(input=layer))
    super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)

  def build(self, input_shape):
    self.layer.build(input_shape)
    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    # Shape (kernel_size_1, kernel_size_2, in_channel, out_channel).
    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()
    self.strides = self.layer.strides

    # Set the dimensions of u and v vectors.
    batch_size = input_shape[0]
    uv_dim = batch_size if self.legacy_mode else 1

    # Resolve shapes.
    in_height = input_shape[1]
    in_width = input_shape[2]
    in_channel = self.w_shape[2]

    out_height = in_height // self.strides[0]
    out_width = in_width // self.strides[1]
    out_channel = self.w_shape[3]

    self.in_shape = (uv_dim, in_height, in_width, in_channel)
    self.out_shape = (uv_dim, out_height, out_width, out_channel)
    self.uv_initializer = tf.initializers.random_normal()

    self.v = self.add_weight(
        shape=self.in_shape,
        initializer=self.uv_initializer,
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=self.out_shape,
        initializer=self.uv_initializer,
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    super(SpectralNormalizationConv2D, self).build()

  def call(self, inputs):
    u_update_op, v_update_op, w_update_op = self.update_weights()
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()

    # Register update ops.
    self.add_update(u_update_op)
    self.add_update(v_update_op)
    self.add_update(w_update_op)
    self.add_update(w_restore_op)

    return output

  def update_weights(self):
    """Computes power iteration for convolutional filters based on [3]."""
    # Initialize u, v vectors.
    u_hat = self.u
    v_hat = self.v

    if self.do_power_iteration:
      for _ in range(self.iteration):
        # Updates v.
        v_ = tf.nn.conv2d_transpose(
            u_hat,
            self.w,
            output_shape=self.in_shape,
            strides=self.strides,
            padding='SAME')
        v_hat = tf.nn.l2_normalize(tf.reshape(v_, [1, -1]))
        v_hat = tf.reshape(v_hat, v_.shape)

        # Updates u.
        u_ = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
        u_hat = tf.nn.l2_normalize(tf.reshape(u_, [1, -1]))
        u_hat = tf.reshape(u_hat, u_.shape)

    v_w_hat = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')

    sigma = tf.matmul(tf.reshape(v_w_hat, [1, -1]), tf.reshape(u_hat, [-1, 1]))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda:  # pylint:disable=g-long-lambda
                     (self.norm_multiplier / sigma) * self.w, lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)
    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
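For context, a small usage sketch of the wrapper defined above (hyperparameters are illustrative; the test file below exercises the same API):

import tensorflow as tf
from official.nlp.modeling.layers import spectral_normalization

# Wrap a Dense layer so its kernel's spectral norm is bounded near 0.95.
dense = tf.keras.layers.Dense(10)
sn_dense = spectral_normalization.SpectralNormalization(
    dense, iteration=1, norm_multiplier=0.95)

inputs = tf.random.uniform((8, 16))
outputs = sn_dense(inputs)  # Power iteration runs and the kernel is rescaled.
print(outputs.shape)        # (8, 10)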
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for normalization layers.
## References:
[1] Hanie Sedghi, Vineet Gupta, Philip M. Long.
The Singular Values of Convolutional Layers.
In _International Conference on Learning Representations_, 2019.
"""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import spectral_normalization
DenseLayer = tf.keras.layers.Dense(10)
Conv2DLayer = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='valid')
def _compute_spectral_norm(weight):
  if weight.ndim > 2:
    # Computes Conv2D via FFT transform as in [1].
    weight = np.fft.fft2(weight, weight.shape[1:3], axes=[0, 1])
  return np.max(np.linalg.svd(weight, compute_uv=False))


class NormalizationTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(NormalizationTest, self).setUp()
    self.num_iterations = 1000
    self.norm_multiplier = 0.95

  @parameterized.named_parameters(
      ('Dense',
       (None, 10), DenseLayer, spectral_normalization.SpectralNormalization),
      ('Conv2D', (None, 32, 32, 3), Conv2DLayer,
       spectral_normalization.SpectralNormalizationConv2D))
  def test_spec_norm_magnitude(self, input_shape, layer, norm_wrapper):
    """Tests if the weights spectral norm converges to norm_multiplier."""
    layer.build(input_shape)
    sn_layer = norm_wrapper(
        layer,
        iteration=self.num_iterations,
        norm_multiplier=self.norm_multiplier)

    # Perform normalization.
    sn_layer.build(input_shape)
    sn_layer.update_weights()
    normalized_kernel = sn_layer.layer.kernel.numpy()

    spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
    spectral_norm_expected = self.norm_multiplier
    self.assertAllClose(
        spectral_norm_computed, spectral_norm_expected, atol=5e-2)

    # Test that the normalized layer is K-Lipschitz. In particular, if the
    # layer is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2,
    # where K is the norm multiplier.
    new_input_shape = (16,) + input_shape[1:]
    new_input = tf.random.uniform(new_input_shape)
    delta_vec = tf.random.uniform(new_input_shape)
    output1 = sn_layer(new_input)
    output2 = sn_layer(new_input + delta_vec)

    delta_input = tf.norm(tf.reshape(delta_vec, (-1,))).numpy()
    delta_output = tf.norm(tf.reshape(output2 - output1, (-1,))).numpy()
    self.assertLessEqual(delta_output, self.norm_multiplier * delta_input)


if __name__ == '__main__':
  tf.test.main()
@@ -63,7 +63,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
    that will be applied on attention scores before and after softmax.

    Args:
      qkv_rank: The rank of query, key, value tensors after projection.
    """
    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
...
@@ -100,10 +100,10 @@ class BertTokenizer(tf.keras.layers.Layer):
    tokenize_with_offsets: If true, calls
      `text.BertTokenizer.tokenize_with_offsets()` instead of plain
      `text.BertTokenizer.tokenize()` and outputs a triple of
      `(tokens, start_offsets, limit_offsets)`.
    raw_table_access: An object with methods `.lookup(keys)` and `.size()`
      that operate on the raw lookup table of tokens. It can be used to
      look up special token symbols like `[MASK]`.
  """
  def __init__(self, *,
@@ -121,16 +121,16 @@ class BertTokenizer(tf.keras.layers.Layer):
    lower_case: A Python boolean forwarded to `text.BertTokenizer`.
      If true, input text is converted to lower case (where applicable)
      before tokenization. This must be set to match the way in which
      the `vocab_file` was created.
    tokenize_with_offsets: A Python boolean. If true, this layer calls
      `text.BertTokenizer.tokenize_with_offsets()` instead of plain
      `text.BertTokenizer.tokenize()` and outputs a triple of
      `(tokens, start_offsets, limit_offsets)`
      instead of just tokens.
    **kwargs: Standard arguments to `Layer()`.

  Raises:
    ImportError: If importing `tensorflow_text` failed.
  """
  _check_if_tf_text_installed()
@@ -167,17 +167,18 @@ class BertTokenizer(tf.keras.layers.Layer):
    """Calls `text.BertTokenizer` on inputs.

    Args:
      inputs: A string Tensor of shape `(batch_size,)`.

    Returns:
      One or three of `RaggedTensors` if `tokenize_with_offsets` is False or
      True, respectively. These are
      tokens: A `RaggedTensor` of shape
        `[batch_size, (words), (pieces_per_word)]`
        and type int32. `tokens[i,j,k]` contains the k-th wordpiece of the
        j-th word in the i-th input.
      start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
        RaggedTensors of type int64 with the same indices as tokens.
        Element `[i,j,k]` contains the byte offset at the start, or past the
        end, resp., for the k-th wordpiece of the j-th word in the i-th input.
    """
    # Prepare to reshape the result to work around broken shape inference.
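The ragged shape `[batch_size, (words), (pieces_per_word)]` above simply means that the word and wordpiece axes vary per example. A small hand-built illustration of indexing such a result (not produced by the tokenizer itself):

import tensorflow as tf

# Two inputs; the first has two words (1 and 2 pieces), the second has one word.
tokens = tf.ragged.constant([[[101], [7592, 2088]], [[102]]], dtype=tf.int32)
print(tokens.shape)             # (2, None, None)
print(tokens[0, 1, 1].numpy())  # 2088: 2nd wordpiece of the 2nd word of input 0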
@@ -201,12 +202,7 @@ class BertTokenizer(tf.keras.layers.Layer):
  def get_config(self):
    # Skip in tf.saved_model.save(); fail if called directly.
    raise NotImplementedError("TODO(b/170480226): implement")

  def get_special_tokens_dict(self):
    """Returns dict of token ids, keyed by standard names for their purpose.
@@ -268,13 +264,13 @@ class BertTokenizer(tf.keras.layers.Layer):
class SentencepieceTokenizer(tf.keras.layers.Layer):
  """Wraps `tf_text.SentencepieceTokenizer` as a Keras Layer.

  Attributes:
    tokenize_with_offsets: If true, calls
      `SentencepieceTokenizer.tokenize_with_offsets()`
      instead of plain `.tokenize()` and outputs a triple of
      `(tokens, start_offsets, limit_offsets)`.
  """
  def __init__(self,
@@ -300,9 +296,9 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
      store the actual proto (not a filename passed here).
    model_serialized_proto: The sentencepiece model serialized proto string.
    tokenize_with_offsets: A Python boolean. If true, this layer calls
      `SentencepieceTokenizer.tokenize_with_offsets()` instead of
      plain `.tokenize()` and outputs a triple of
      `(tokens, start_offsets, limit_offsets)` instead of just tokens.
      Note that when `strip_diacritics` is set to True, returning
      offsets is not supported now.
    nbest_size: A scalar for sampling:
@@ -320,7 +316,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
      `tokenize_with_offsets`. NOTE: New models are encouraged to put this
      into custom normalization rules for the Sentencepiece model itself to
      avoid this extra step and the limitation regarding offsets.
    **kwargs: standard arguments to `Layer()`.

  Raises:
    ImportError: if importing tensorflow_text failed.
@@ -360,19 +356,19 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
    return self._tokenizer.vocab_size()

  def call(self, inputs: tf.Tensor):
    """Calls `text.SentencepieceTokenizer` on inputs.

    Args:
      inputs: A string Tensor of shape `(batch_size,)`.

    Returns:
      One or three of RaggedTensors if tokenize_with_offsets is False or True,
      respectively. These are
      tokens: A RaggedTensor of shape `[batch_size, (pieces)]` and type `int32`.
        `tokens[i,j]` contains the j-th piece in the i-th input.
      start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
        RaggedTensors of type `int64` with the same indices as tokens.
        Element `[i,j]` contains the byte offset at the start, or past the
        end, resp., for the j-th piece in the i-th input.
    """
    if self._strip_diacritics:
@@ -403,19 +399,8 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
    return _reshape(tokens)

  def get_config(self):
    # Skip in tf.saved_model.save(); fail if called directly.
    raise NotImplementedError("TODO(b/170480226): implement")

  def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose. """Returns dict of token ids, keyed by standard names for their purpose.
...@@ -492,7 +477,7 @@ class BertPackInputs(tf.keras.layers.Layer): ...@@ -492,7 +477,7 @@ class BertPackInputs(tf.keras.layers.Layer):
special_tokens_dict=None, special_tokens_dict=None,
truncator="round_robin", truncator="round_robin",
**kwargs): **kwargs):
"""Initializes with a target seq_length, relevant token ids and truncator. """Initializes with a target `seq_length`, relevant token ids and truncator.
Args: Args:
seq_length: The desired output length. Must not exceed the max_seq_length seq_length: The desired output length. Must not exceed the max_seq_length
...@@ -505,13 +490,13 @@ class BertPackInputs(tf.keras.layers.Layer): ...@@ -505,13 +490,13 @@ class BertPackInputs(tf.keras.layers.Layer):
unused positions after the last segment in the sequence unused positions after the last segment in the sequence
(called "[PAD]" for BERT). (called "[PAD]" for BERT).
special_tokens_dict: Optionally, a dict from Python strings to Python special_tokens_dict: Optionally, a dict from Python strings to Python
integers that contains values for start_of_sequence_id, integers that contains values for `start_of_sequence_id`,
end_of_segment_id and padding_id. (Further values in the dict are `end_of_segment_id` and `padding_id`. (Further values in the dict are
silenty ignored.) If this is passed, separate *_id arguments must be silenty ignored.) If this is passed, separate *_id arguments must be
omitted. omitted.
truncator: The algorithm to truncate a list of batched segments to fit a truncator: The algorithm to truncate a list of batched segments to fit a
per-example length limit. The value can be either "round_robin" or per-example length limit. The value can be either `round_robin` or
"waterfall": `waterfall`:
(1) For "round_robin" algorithm, available space is assigned (1) For "round_robin" algorithm, available space is assigned
one token at a time in a round-robin fashion to the inputs that still one token at a time in a round-robin fashion to the inputs that still
need some, until the limit is reached. It currently only supports need some, until the limit is reached. It currently only supports
...@@ -521,10 +506,10 @@ class BertPackInputs(tf.keras.layers.Layer): ...@@ -521,10 +506,10 @@ class BertPackInputs(tf.keras.layers.Layer):
left-to-right manner and fills up the buckets until we run out of left-to-right manner and fills up the buckets until we run out of
budget. It support arbitrary number of segments. budget. It support arbitrary number of segments.
**kwargs: standard arguments to Layer(). **kwargs: standard arguments to `Layer()`.
Raises: Raises:
ImportError: if importing tensorflow_text failed. ImportError: if importing `tensorflow_text` failed.
""" """
_check_if_tf_text_installed() _check_if_tf_text_installed()
super().__init__(**kwargs) super().__init__(**kwargs)
......
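The two truncation strategies above differ in how the per-example token budget is spent. A tiny, self-contained sketch of the round-robin idea on plain Python lists (my own illustration; the real layer operates on RaggedTensors via tensorflow_text):

def round_robin_truncate_sketch(segments, limit):
  """Keeps tokens from each segment one at a time until `limit` is reached."""
  budgets = [0] * len(segments)
  remaining = limit
  while remaining > 0:
    progressed = False
    for i, seg in enumerate(segments):
      if remaining > 0 and budgets[i] < len(seg):
        budgets[i] += 1
        remaining -= 1
        progressed = True
    if not progressed:  # All segments fully consumed before the limit.
      break
  return [seg[:b] for seg, b in zip(segments, budgets)]

# Example: two segments, room for 5 tokens in total.
print(round_robin_truncate_sketch([["a", "b", "c", "d"], ["x", "y"]], 5))
# [['a', 'b', 'c'], ['x', 'y']]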
@@ -37,8 +37,8 @@ class TNExpandCondense(Layer):
  Note the input shape and output shape will be identical.

  Args:
    proj_multiplier: Positive integer, multiple of `input_shape[-1]` to project
      up to. Must be one of `[2, 4, 6, 8]`.
    use_bias: Boolean, whether the layer uses a bias vector.
    activation: Activation function to use between Expand and Condense. If you
      don't specify anything, no activation is applied
...
@@ -232,7 +232,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Encoder-decoder attention.
    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
@@ -250,7 +251,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
        tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Feed-forward projection.
    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -273,7 +275,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
        **common_kwargs)
    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1,
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
...
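The added `dtype="float32"` arguments keep layer normalization in full precision even when the surrounding model runs under a `mixed_float16` policy. A minimal sketch of that pattern, assuming a TF 2.4+ runtime where the non-experimental mixed-precision API is available:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Most compute runs in float16 under the policy, but this norm stays float32.
norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12, dtype='float32')
x = tf.random.uniform((2, 4, 8))
print(norm(x).dtype)  # float32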
@@ -112,8 +112,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    input_tensor_shape = input_shape[0] if (
        len(input_shape) == 2) else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor_shape)
    if len(input_tensor_shape.as_list()) != 3:
      raise ValueError(
          "TransformerScaffold expects a three-dimensional input of "
@@ -170,6 +171,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
    else:
      self._feedforward_block = None

    # self._dropout_rate controls dropout rates at two places:
    # after attention, and after FFN.
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses contains common loss computation used in NLP (subject to change)."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
# Models

Models are combinations of `tf.keras` layers and models that can be trained.
Several pre-built canned models are provided to train encoder networks.
These models are intended as both convenience functions and canonical examples.

* [`BertClassifier`](bert_classifier.py) implements a simple classification
  model containing a single classification head using the Classification network.
...
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Models are combinations of `tf.keras` layers and models that can be trained.

Several pre-built canned models are provided to train encoder networks.
These models are intended as both convenience functions and canonical examples.
"""
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import *
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
...
@@ -50,8 +50,8 @@ class BertPretrainer(tf.keras.Model):
      None, no activation will be used.
    initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either `logits` or
      `predictions`.
  """

  def __init__(self,
...
@@ -37,11 +37,11 @@ class BertSpanLabeler(tf.keras.Model):
  Args:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
      table via a `get_embedding_table` method.
    initializer: The initializer (if any) to use in the span labeling network.
      Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either `logits` or
      `predictions`.
  """

  def __init__(self,
...
@@ -36,12 +36,15 @@ class BertTokenClassifier(tf.keras.Model):
  Args:
    network: A transformer network. This network should output a sequence output
      and a classification output. Furthermore, it should expose its embedding
      table via a `get_embedding_table` method.
    num_classes: Number of classes to predict from the classification network.
    initializer: The initializer (if any) to use in the classification networks.
      Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either `logits` or
      `predictions`.
    dropout_rate: The dropout probability of the token classification head.
    output_encoder_outputs: Whether to include intermediate sequence output
      in the final output.
  """

  def __init__(self,
@@ -50,6 +53,7 @@ class BertTokenClassifier(tf.keras.Model):
               initializer='glorot_uniform',
               output='logits',
               dropout_rate=0.1,
               output_encoder_outputs=False,
               **kwargs):

    # We want to use the inputs of the passed network as the inputs to this
@@ -74,14 +78,19 @@ class BertTokenClassifier(tf.keras.Model):
        name='predictions/transform/logits')
    logits = classifier(sequence_output)
    if output == 'logits':
      output_tensors = {'logits': logits}
    elif output == 'predictions':
      output_tensors = {
          'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
      }
    else:
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)

    if output_encoder_outputs:
      output_tensors['encoder_outputs'] = sequence_output

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
@@ -98,6 +107,7 @@ class BertTokenClassifier(tf.keras.Model):
        'num_classes': num_classes,
        'initializer': initializer,
        'output': output,
        'output_encoder_outputs': output_encoder_outputs
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
...
@@ -27,22 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
@keras_parameterized.run_all_keras_modes
class BertTokenClassifierTest(keras_parameterized.TestCase):

  @parameterized.parameters((True, True), (False, False))
  def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
    """Validate that the Keras object can be created."""
    # Build a transformer network to use within the BERT trainer.
    vocab_size = 100
    sequence_length = 512
    hidden_size = 768
    test_network = networks.BertEncoder(
        vocab_size=vocab_size,
        num_layers=2,
        max_sequence_length=sequence_length,
        dict_outputs=dict_outputs,
        hidden_size=hidden_size)

    # Create a BERT trainer with the created network.
    num_classes = 3
    bert_trainer_model = bert_token_classifier.BertTokenClassifier(
        test_network,
        num_classes=num_classes,
        output_encoder_outputs=output_encoder_outputs)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -50,12 +54,18 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

    # Invoke the trainer model on the inputs. This causes the layer to be built.
    outputs = bert_trainer_model([word_ids, mask, type_ids])
    if output_encoder_outputs:
      logits = outputs['logits']
      encoder_outputs = outputs['encoder_outputs']
      self.assertAllEqual(encoder_outputs.shape.as_list(),
                          [None, sequence_length, hidden_size])
    else:
      logits = outputs['logits']

    # Validate that the outputs are of the expected shape.
    expected_classification_shape = [None, sequence_length, num_classes]
    self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

  def test_bert_trainer_tensor_call(self):
    """Validate that the Keras object can be invoked."""
...