"convert/convert_nomicbert.go" did not exist on "77903ab8b4fb8075faad7bde5bde2eee3173e407"
Commit 5a2cf36f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into newavarecords

parents 258ddfc3 a829e648
@@ -21,6 +21,8 @@ from __future__ import print_function
 import tensorflow as tf
+from tensorflow.python.util import deprecation
 _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
       `(batch_size, units)`.
   """

+  @deprecation.deprecated(
+      None, "DenseEinsum is deprecated. Please use "
+      "tf.keras.experimental.EinsumDense layer instead.")
   def __init__(self,
                output_shape,
                num_summed_dimensions=1,
......
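For anyone migrating their own code off the now-deprecated `DenseEinsum`: a minimal sketch (not part of this commit) of the equivalent `tf.keras.layers.experimental.EinsumDense` projection, assuming TF 2.3+ where that layer exists; the sizes below are made up.

```python
import tensorflow as tf

# A DenseEinsum(output_shape=units) applied to a [batch, seq, hidden] input
# corresponds to EinsumDense with equation "abc,cd->abd":
#   a = batch, b = sequence, c = hidden size, d = output units.
hidden_size, units = 64, 256
layer = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, units),  # None keeps the sequence dimension dynamic
    bias_axes="d",               # one bias term per output unit, like Dense
    name="intermediate")

x = tf.random.normal([2, 8, hidden_size])  # [batch, seq, hidden]
print(layer(x).shape)                      # (2, 8, 256)
```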
@@ -26,7 +26,6 @@ import math
 import tensorflow as tf
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import masked_softmax
@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)

   def build(self, unused_input_shapes):
-    self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_query")
-    self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_key")
+    common_kwargs = dict(
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint)
+    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="query",
+        **common_kwargs)
+    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="key",
+        **common_kwargs)
     super(VotingAttention, self).build(unused_input_shapes)

   def call(self, encoder_outputs, doc_attention_mask):
@@ -113,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(attention.MultiHeadAttention):
   """Multi-channel Attention layer.

-  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
-  cross-attention target sequences.
+  Introduced in, [Generating Representative Headlines for News Stories
+  ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
+  target sequences.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
+      number of documents.
+    context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
+      is the number of attention heads. Combines multi-channel sources
+      context tensors according to the distribution among channels.
+    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """

-  def _build_attention(self, qkv_rank):
-    super(MultiChannelAttention, self)._build_attention(qkv_rank)
+  def build_attention(self, rank):
+    super(MultiChannelAttention, self).build_attention(rank)
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])

-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    doc_attention_probs = inputs[2]
+  def call(self,
+           query,
+           value,
+           key=None,
+           context_attention_weights=None,
+           attention_mask=None):
+    if not self._built_from_signature:
+      self._build_from_signature(query, value, key=key)
+    if key is None:
+      key = value

     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
     # A = num_docs (number of docs)
-    # F = `from_tensor` sequence length
-    # T = `to_tensor` sequence length
+    # F = target sequence length
+    # T = source sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
     # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    query_tensor = self._query_dense(query)
     # `key_tensor` = [B, A, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    key_tensor = self._key_dense(key)
     # `value_tensor` = [B, A, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    value_tensor = self._value_dense(value)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
@@ -159,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                  context_layer)
     attention_output = self._output_dense(attention_output)
     return attention_output
@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
......
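As a shape reference for the new per-head projections in `VotingAttention` above, here is a small standalone check (illustrative only; the dimension sizes are invented) of what the `"BAE,ENH->BANH"` equation produces.

```python
import tensorflow as tf

batch, num_docs, hidden = 3, 5, 16  # B, A, E
num_heads, head_size = 2, 8         # N, H

query_dense = tf.keras.layers.experimental.EinsumDense(
    "BAE,ENH->BANH",
    output_shape=(None, num_heads, head_size),  # (A, N, H); A stays dynamic
    bias_axes="NH")

encoder_outputs = tf.random.normal([batch, num_docs, hidden])
# One (num_heads, head_size) projection per document in the batch.
print(query_dense(encoder_outputs).shape)  # (3, 5, 2, 8)
```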
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         "hidden_size": self._hidden_size,
         "min_timescale": self._min_timescale,
         "max_timescale": self._max_timescale,
-        "length": self._length,
     }
     base_config = super(RelativePositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
......
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum

 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
......
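The feed-forward projections above swap `DenseEinsum` for a 2-D-kernel `EinsumDense`. As a sanity check (mine, not part of the commit): the `"abc,cd->abd"` equation with `bias_axes="d"` is the same position-wise affine map as a regular `Dense` layer once the weights match.

```python
import numpy as np
import tensorflow as tf

hidden, units = 32, 128
x = tf.random.normal([2, 10, hidden])

einsum_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, units), bias_axes="d")
dense = tf.keras.layers.Dense(units)

_ = einsum_dense(x), dense(x)                  # build both layers
dense.set_weights(einsum_dense.get_weights())  # copy kernel and bias across

np.testing.assert_allclose(
    einsum_dense(x).numpy(), dense(x).numpy(), atol=1e-5)
```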
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, qkv_rank):
     """Builds multi-head dot-product attention computations.

     This function overrides base class to create additional linear projection
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     Args:
       qkv_rank: the rank of query, key, value tensors after projection.
     """
-    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+    super(TalkingHeadsAttention, self).build_attention(qkv_rank)

     # Build an equation:
     # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
         dtype=self.dtype,
         trainable=True)

-  def _compute_attention(self,
-                         query_tensor,
-                         key_tensor,
-                         value_tensor,
-                         attention_mask=None):
+  def compute_attention(self,
+                        query_tensor,
+                        key_tensor,
+                        value_tensor,
+                        attention_mask=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function overrides base class to apply additional linear projection
......
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_attention_scores(self):
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(
+        query=query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)

     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   @parameterized.named_parameters(
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)

     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
......
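The test updates above replace list-style calls with explicit keyword arguments. The same calling convention is used by the Keras built-in `tf.keras.layers.MultiHeadAttention` (TF 2.4+), so a quick sketch with that layer, as a stand-in for the model-garden class (note it takes `key_dim` rather than `key_size`), shows the intended usage.

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)

query = tf.random.normal([3, 40, 80])   # [batch, target_len, dim]
value = tf.random.normal([3, 20, 80])   # [batch, source_len, dim]
mask = np.random.randint(2, size=(3, 40, 20)).astype("bool")

# Explicit keyword arguments, as in the updated tests above.
output = layer(query=query, value=value, attention_mask=mask)
print(output.shape)  # (3, 40, 80) -- projected back to the query feature size
```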
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers.util import tf_function_if_eager
@@ -106,21 +105,24 @@ class Transformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     # pylint: disable=protected-access
-    self._attention_layer.build([input_tensor_shape] * 3)
+    # Temporarily handling for checkpoint compatible changes.
+    self._attention_layer._build_from_signature(
+        query=input_tensor_shape, value=input_tensor_shape)
     self._attention_output_dense = self._attention_layer._output_dense
     # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
@@ -132,17 +134,12 @@ class Transformer(tf.keras.layers.Layer):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +148,12 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
@@ -211,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)
@@ -312,30 +305,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "The hidden size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size / self.num_attention_heads)
-    # Self attention.
-    self.self_attention = attention.CachedAttention(
-        num_heads=self.num_attention_heads,
-        key_size=self.attention_head_size,
-        dropout=self.attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
-    self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
+        bias_constraint=self._bias_constraint)
+    # Self attention.
+    self.self_attention = attention.CachedAttention(
+        num_heads=self.num_attention_heads,
+        key_size=self.attention_head_size,
+        dropout=self.attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
+    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
@@ -347,14 +337,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="attention/encdec")
+        name="attention/encdec",
+        **common_kwargs)
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
@@ -363,29 +347,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
             name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))

     # Feed-forward projection.
-    self.intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self.intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self.intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
-    self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
@@ -409,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
-    self_attention_inputs = [input_tensor, input_tensor]
     self_attention_output, cache = self.self_attention(
-        self_attention_inputs,
+        query=input_tensor,
+        value=input_tensor,
        attention_mask=self_attention_mask,
        cache=cache,
        decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)
-    cross_attn_inputs = [self_attention_output, memory]
+    cross_attn_inputs = dict(
+        query=self_attention_output,
+        value=memory,
+        attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+    attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(self_attention_output +
                                                         attention_output)
......
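The decoder's cross-attention call now collects its keyword arguments in a dict so the optional `context_attention_weights` entry can be added only in the multi-channel case. A minimal, self-contained sketch of that pattern, with a hypothetical `attention_fn` standing in for the real attention layer:

```python
def attention_fn(query, value, attention_mask=None,
                 context_attention_weights=None):
  # Hypothetical stand-in for self.encdec_attention(...), just to show the call.
  return dict(query=query, value=value, mask=attention_mask,
              weights=context_attention_weights)

def cross_attention(self_attention_output, memory, attention_mask,
                    doc_attention_probs=None, multi_channel=False):
  cross_attn_inputs = dict(
      query=self_attention_output,
      value=memory,
      attention_mask=attention_mask)
  if multi_channel:
    # Only the multi-channel attention variant accepts this extra keyword.
    cross_attn_inputs["context_attention_weights"] = doc_attention_probs
  return attention_fn(**cross_attn_inputs)
```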
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)

-    attention_inputs = [input_tensor, input_tensor]
-
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
......
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, attention_mask=None):
     self.list.append(True)
     return super(ValidatedAttentionLayer, self).call(
-        inputs, attention_mask=attention_mask)
+        query, value, attention_mask=attention_mask)

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
......
@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+    self.assertAllClose(new_output_tensor,
+                        output_tensor[:, 0:1, :],
+                        atol=5e-5,
+                        rtol=0.003)

   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
......
@@ -21,6 +21,7 @@ from __future__ import print_function
 import tensorflow as tf

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -36,6 +37,9 @@ class BertClassifier(tf.keras.Model):
   instantiates a classification network based on the passed `num_classes`
   argument. If `num_classes` is set to 1, a regression network is instantiated.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
@@ -43,23 +47,25 @@ class BertClassifier(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network.
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    dropout_rate: The dropout probability of the cls head.
+    use_encoder_pooler: Whether to use the pooler layer pre-defined inside
+      the encoder.
   """

   def __init__(self,
                network,
                num_classes,
                initializer='glorot_uniform',
-               output='logits',
                dropout_rate=0.1,
+               use_encoder_pooler=True,
                **kwargs):
     self._self_setattr_tracking = False
+    self._network = network
     self._config = {
         'network': network,
         'num_classes': num_classes,
         'initializer': initializer,
-        'output': output,
+        'use_encoder_pooler': use_encoder_pooler,
     }

     # We want to use the inputs of the passed network as the inputs to this
@@ -67,22 +73,36 @@ class BertClassifier(tf.keras.Model):
     # when we construct the Model object at the end of init.
     inputs = network.inputs

-    # Because we have a copy of inputs to create this Model object, we can
-    # invoke the Network object with its own input tensors to start the Model.
-    _, cls_output = network(inputs)
-    cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
-
-    self.classifier = networks.Classification(
-        input_width=cls_output.shape[-1],
-        num_classes=num_classes,
-        initializer=initializer,
-        output=output,
-        name='classification')
-    predictions = self.classifier(cls_output)
+    if use_encoder_pooler:
+      # Because we have a copy of inputs to create this Model object, we can
+      # invoke the Network object with its own input tensors to start the Model.
+      _, cls_output = network(inputs)
+      cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
+
+      self.classifier = networks.Classification(
+          input_width=cls_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          output='logits',
+          name='sentence_prediction')
+      predictions = self.classifier(cls_output)
+    else:
+      sequence_output, _ = network(inputs)
+      self.classifier = layers.ClassificationHead(
+          inner_dim=sequence_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          dropout_rate=dropout_rate,
+          name='sentence_prediction')
+      predictions = self.classifier(sequence_output)

     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)

+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
+
   def get_config(self):
     return self._config
......
@@ -42,8 +42,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network,
-        num_classes=num_classes)
+        test_network, num_classes=num_classes)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -89,7 +88,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=4, initializer='zeros', output='predictions')
+        test_network, num_classes=4, initializer='zeros')

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
......
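A hedged usage sketch of the new `use_encoder_pooler` switch; the `TransformerEncoder` arguments are copied from the pretrainer test later in this diff, and the rest is an assumption about how the classifier is typically driven, not code from this commit.

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

sequence_length = 16
test_network = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=sequence_length)

# Default: classify from the encoder's own pooled [CLS] output.
pooled_model = bert_classifier.BertClassifier(test_network, num_classes=3)

# New option: skip the encoder pooler and attach a ClassificationHead to the
# sequence output; its weights live under the 'sentence_prediction' name.
head_model = bert_classifier.BertClassifier(
    test_network, num_classes=3, use_encoder_pooler=False)

word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
logits = head_model([word_ids, mask, type_ids])
print(logits.shape)  # (None, 3)
```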
@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
   instantiates the masked language model and classification networks that are
   used to create the training objectives.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output.
@@ -147,11 +150,9 @@ class BertPretrainerV2(tf.keras.Model):
   (Experimental).

   Adds the masked language model head and optional classification heads upon the
-  transformer encoder. When num_masked_tokens == 0, there won't be MaskedLM
-  head.
+  transformer encoder.

   Arguments:
-    num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
     mlm_activation: The activation (if any) to use in the masked LM network. If
@@ -169,7 +170,6 @@ class BertPretrainerV2(tf.keras.Model):
   def __init__(
       self,
-      num_masked_tokens: int,
       encoder_network: tf.keras.Model,
       mlm_activation=None,
       mlm_initializer='glorot_uniform',
@@ -179,7 +179,6 @@ class BertPretrainerV2(tf.keras.Model):
     self._self_setattr_tracking = False
     self._config = {
         'encoder_network': encoder_network,
-        'num_masked_tokens': num_masked_tokens,
         'mlm_initializer': mlm_initializer,
         'classification_heads': classification_heads,
         'name': name,
@@ -195,19 +194,16 @@ class BertPretrainerV2(tf.keras.Model):
       raise ValueError('Classification heads should have unique names.')

     outputs = dict()
-    if num_masked_tokens > 0:
-      self.masked_lm = layers.MaskedLM(
-          embedding_table=self.encoder_network.get_embedding_table(),
-          activation=mlm_activation,
-          initializer=mlm_initializer,
-          name='cls/predictions')
-      masked_lm_positions = tf.keras.layers.Input(
-          shape=(num_masked_tokens,),
-          name='masked_lm_positions',
-          dtype=tf.int32)
-      inputs.append(masked_lm_positions)
-      outputs['lm_output'] = self.masked_lm(
-          sequence_output, masked_positions=masked_lm_positions)
+    self.masked_lm = layers.MaskedLM(
+        embedding_table=self.encoder_network.get_embedding_table(),
+        activation=mlm_activation,
+        initializer=mlm_initializer,
+        name='cls/predictions')
+    masked_lm_positions = tf.keras.layers.Input(
+        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
+    inputs.append(masked_lm_positions)
+    outputs['lm_output'] = self.masked_lm(
+        sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
@@ -217,7 +213,7 @@ class BertPretrainerV2(tf.keras.Model):
   @property
   def checkpoint_items(self):
     """Returns a dictionary of items to be additionally checkpointed."""
-    items = dict(encoder=self.encoder_network)
+    items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
     for head in self.classification_heads:
       for key, item in head.checkpoint_items.items():
         items['.'.join([head.name, key])] = item
......
@@ -118,10 +118,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)

     # Create a BERT trainer with the created network.
-    num_token_predictions = 2
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=num_token_predictions)
+        encoder_network=test_network)
+    num_token_predictions = 20

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -145,7 +144,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=2)
+        encoder_network=test_network)

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
......
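With `num_masked_tokens` gone, the number of masked positions becomes a per-batch dynamic dimension. A rough sketch of feeding the new `BertPretrainerV2` (encoder arguments and input ordering are assumptions based on the tests in this diff):

```python
import numpy as np
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

sequence_length, vocab_size = 16, 100
encoder = networks.TransformerEncoder(
    vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)

batch, num_token_predictions = 2, 20   # no longer fixed at model-build time
word_ids = np.random.randint(vocab_size, size=(batch, sequence_length),
                             dtype=np.int32)
mask = np.ones((batch, sequence_length), dtype=np.int32)
type_ids = np.zeros((batch, sequence_length), dtype=np.int32)
lm_positions = np.random.randint(sequence_length,
                                 size=(batch, num_token_predictions),
                                 dtype=np.int32)

outputs = pretrainer([word_ids, mask, type_ids, lm_positions])
print(outputs['lm_output'].shape)  # (2, 20, 100): logits per masked position
```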
@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).

-  The BertSpanLabeler allows a user to pass in a transformer stack, and
+  The BertSpanLabeler allows a user to pass in a transformer encoder, and
   instantiates a span labeling network based on a single dense layer.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
......
@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
   instantiates a token classification network based on the passed `num_classes`
   argument.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
......
@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
   model (at generator side) and classification networks (at discriminator side)
   that are used to create the training objectives.

+  *Note* that the model is constructed by Keras Subclass API, where layers are
+  defined inside __init__ and call() implements the computation.
+
   Arguments:
     generator_network: A transformer network for generator, this network should
       output a sequence output and an optional classification output.
@@ -48,7 +51,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
@@ -66,7 +68,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
@@ -80,7 +81,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
@@ -95,7 +95,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
@@ -108,10 +107,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
@@ -165,7 +169,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]

-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)

     outputs = {
@@ -214,6 +219,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
......
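The discriminator now runs its sequence output through a projection before the single-unit head. A toy sketch of just those two layers (sizes invented; the real code takes the width from the network's `_config_dict` and uses `mlm_activation` rather than relu):

```python
import tensorflow as tf

disc_hidden_size = 64
discriminator_projection = tf.keras.layers.Dense(
    units=disc_hidden_size, activation="relu",
    name="discriminator_projection_head")
discriminator_head = tf.keras.layers.Dense(units=1)

disc_sequence_output = tf.random.normal([2, 10, disc_hidden_size])
disc_logits = discriminator_head(
    discriminator_projection(disc_sequence_output))
disc_logits = tf.squeeze(disc_logits, axis=-1)
print(disc_logits.shape)  # (2, 10): one replaced-token score per position
```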
@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create a set of 2-dimensional data tensors to feed into the model.
@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create another BERT trainer via serialization and deserialization.
......
@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
   The default values for this object are taken from the ALBERT-Base
   implementation described in the paper.

+  *Note* that the network is constructed by Keras Functional API.
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     embedding_width: The width of the word embeddings. If the embedding width is
......