Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
@@ -95,9 +95,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
     input_data = np.random.rand(2, input_length, width) + 2.0
     output_data = model.predict(input_data)
-    input_data_normed = (
-        input_data - np.mean(input_data, axis=-1, keepdims=True)) / (
-            np.std(input_data, axis=-1, keepdims=True))
+    input_data_normed = (input_data -
+                         np.mean(input_data, axis=-1, keepdims=True)) / (
+                             np.std(input_data, axis=-1, keepdims=True))
     self.assertAllClose(input_data_normed, output_data)
...
@@ -46,28 +46,25 @@ class TransformerScaffold(tf.keras.layers.Layer):
     attention_cfg: The config with which to instantiate `attention_cls`. Ignored
       if attention_cls is a layer instance or None. If `attention_cls` is a
       class, but `attention_cfg` is None, following kwargs will be used to
-      instantiate the attention instance:
-      {
+      instantiate the attention instance: {
           "num_heads": num_attention_heads,
           "key_size": int(hidden_size // num_attention_heads),
           "dropout": attention_dropout_rate,
-          "name": "self_attention"
-      }, where `hidden_size` is the input tensor's last dimension.
+          "name": "self_attention" }, where `hidden_size` is the input tensor's
+            last dimension.
     feedforward_cls: A class to instantiate feedforward layer, or a layer
-      instance. If None, will use the standard feedforward layer as described
-      in "Attention Is All You Need" paper. If not None, the instantiated
-      feedforward layer is expected to take the output of attention as input
-      and its output is this transformer layer's output.
+      instance. If None, will use the standard feedforward layer as described in
+      "Attention Is All You Need" paper. If not None, the instantiated
+      feedforward layer is expected to take the output of attention as input and
+      its output is this transformer layer's output.
     feedforward_cfg: The config with which to instantiate `feedforward_cls`.
-      Ignored if feedforward_cls is a layer instance or is None.
-      If `feedforward_cls` is a class, but `feedforward_cfg` is None, following
-      kwargs will be used to instantiate the feedforward instance:
-      {
+      Ignored if feedforward_cls is a layer instance or is None. If
+      `feedforward_cls` is a class, but `feedforward_cfg` is None, following
+      kwargs will be used to instantiate the feedforward instance: {
          "intermediate_size": intermediate_size,
          "intermediate_activation": intermediate_activation,
          "dropout": dropout_rate,
-         "name": "feedforward"
-      }.
+         "name": "feedforward" }.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
     kernel_initializer: Initializer for dense layer kernels.
@@ -190,7 +187,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
       # It is probably safe in mixed_float16, but we haven't validated this yet.
      self._attention_layer_norm = (
          tf.keras.layers.LayerNormalization(
-             name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
+             name="self_attention_layer_norm",
+             axis=-1,
+             epsilon=1e-12,
              dtype=tf.float32))
     if self._feedforward_block is None:
...
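Side note on the docstring change above: the default attention and feedforward configs it describes can be spelled out directly. The following is a minimal, illustrative sketch that only assembles those dicts the way the docstring states; the concrete sizes in the example call are made up, not library defaults.

def default_scaffold_cfgs(hidden_size,
                          num_attention_heads,
                          intermediate_size,
                          intermediate_activation,
                          dropout_rate,
                          attention_dropout_rate):
  """Builds the default attention/feedforward kwargs named in the docstring."""
  attention_cfg = {
      "num_heads": num_attention_heads,
      "key_size": int(hidden_size // num_attention_heads),
      "dropout": attention_dropout_rate,
      "name": "self_attention",
  }
  feedforward_cfg = {
      "intermediate_size": intermediate_size,
      "intermediate_activation": intermediate_activation,
      "dropout": dropout_rate,
      "name": "feedforward",
  }
  return attention_cfg, feedforward_cfg

# Example: a 512-wide layer with 8 heads gives key_size 512 // 8 == 64.
attn_cfg, ff_cfg = default_scaffold_cfgs(
    hidden_size=512,
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)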
@@ -233,8 +233,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         norm_first=True,
         norm_epsilon=1e-6,
         intermediate_dropout=0.1,
-        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
-                                                                   maxval=1.))
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -254,8 +254,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         norm_first=True,
         norm_epsilon=1e-6,
         intermediate_dropout=0.1,
-        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
-                                                                   maxval=1.))
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     encoder_block_config = encoder_block.get_config()
     new_encoder_block = transformer.Transformer.from_config(
         encoder_block_config)
@@ -308,8 +308,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         norm_first=True,
         norm_epsilon=1e-6,
         intermediate_dropout=0.1,
-        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
-                                                                   maxval=1.))
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -329,8 +329,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         norm_first=True,
         norm_epsilon=1e-6,
         intermediate_dropout=0.1,
-        attention_initializer=tf.keras.initializers.RandomUniform(minval=0.,
-                                                                   maxval=1.))
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     decoder_block_config = decoder_block.get_config()
     new_decoder_block = transformer.TransformerDecoderLayer.from_config(
         decoder_block_config)
...
@@ -204,5 +204,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
     expected_loss_data = 6.4222
     self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
 
+
 if __name__ == "__main__":
   tf.test.main()
@@ -44,8 +44,8 @@ class BertClassifier(tf.keras.Model):
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
     dropout_rate: The dropout probability of the cls head.
-    use_encoder_pooler: Whether to use the pooler layer pre-defined inside
-      the encoder.
+    use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
+      encoder.
   """
 
   def __init__(self,
...
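For readers of the BertClassifier docstring above, a hedged usage sketch follows. The argument names come from the Args section plus an assumed num_classes argument; the import paths mirror the test files touched by this change, and the exact constructor signature may differ.

import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

# A minimal sketch, assuming the constructor accepts the documented arguments.
# The encoder settings and num_classes here are illustrative only.
encoder = networks.TransformerEncoder(vocab_size=100, num_layers=2)
classifier = bert_classifier.BertClassifier(
    encoder,
    num_classes=3,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    dropout_rate=0.1,
    use_encoder_pooler=True)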
@@ -61,8 +61,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2)
+    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
 
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -82,8 +81,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2)
+    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
 
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -79,8 +79,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2)
+    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
 
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -99,8 +98,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2)
+    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
 
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -68,8 +68,8 @@ class BertTokenClassifier(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
     sequence_output, _ = network(inputs)
-    sequence_output = tf.keras.layers.Dropout(
-        rate=dropout_rate)(sequence_output)
+    sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
+        sequence_output)
     self.classifier = tf.keras.layers.Dense(
         num_classes,
...
@@ -41,8 +41,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network.
     num_classes = 3
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
-        test_network,
-        num_classes=num_classes)
+        test_network, num_classes=num_classes)
 
     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
 
 import tensorflow as tf
+
 from official.modeling import tf_utils
...
@@ -37,9 +37,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
   @parameterized.named_parameters(
       dict(testcase_name="default", expected_dtype=tf.float32),
-      dict(
-          testcase_name="with_float16_dtype",
-          expected_dtype=tf.float16),
+      dict(testcase_name="with_float16_dtype", expected_dtype=tf.float16),
   )
   def test_network_creation(self, expected_dtype):
     hidden_size = 32
@@ -94,8 +92,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types)
-    self.assertTrue(
-        test_network._position_embedding_layer._use_dynamic_slicing)
+    self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -71,8 +71,9 @@ class Classification(tf.keras.Model):
     if policy.name == 'mixed_bfloat16':
       # b/158514794: bf16 is not stable with post-softmax cross-entropy.
       policy = tf.float32
-    predictions = tf.keras.layers.Activation(tf.nn.log_softmax,
-                                             dtype=policy)(self.logits)
+    predictions = tf.keras.layers.Activation(
+        tf.nn.log_softmax, dtype=policy)(
+            self.logits)
 
     if output == 'logits':
       output_tensors = self.logits
...
@@ -92,8 +92,8 @@ class ClassificationTest(keras_parameterized.TestCase):
     self.assertAllClose(outputs, calculated_softmax)
 
   @parameterized.parameters(1, 10)
-  def test_network_invocation_with_internal_and_external_logits(self,
-                                                                 num_classes):
+  def test_network_invocation_with_internal_and_external_logits(
+      self, num_classes):
     """Validate that the logit outputs are correct."""
     input_width = 512
     test_object = classification.Classification(
...
@@ -54,14 +54,13 @@ class EncoderScaffold(tf.keras.Model):
   Arguments:
     pooled_output_dim: The dimension of pooled output.
-    pooler_layer_initializer: The initializer for the classification
-      layer.
+    pooler_layer_initializer: The initializer for the classification layer.
     embedding_cls: The class or instance to use to embed the input data. This
-      class or instance defines the inputs to this encoder and outputs
-      (1) embeddings tensor with shape [batch_size, seq_length, hidden_size] and
-      (2) attention masking with tensor [batch_size, seq_length, seq_length].
-      If embedding_cls is not set, a default embedding network
-      (from the original BERT paper) will be created.
+      class or instance defines the inputs to this encoder and outputs (1)
+      embeddings tensor with shape [batch_size, seq_length, hidden_size] and (2)
+      attention masking with tensor [batch_size, seq_length, seq_length]. If
+      embedding_cls is not set, a default embedding network (from the original
+      BERT paper) will be created.
     embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
       be instantiated. If embedding_cls is not set, a config dict must be
       passed to 'embedding_cfg' with the following values:
@@ -94,19 +93,18 @@ class EncoderScaffold(tf.keras.Model):
     all encoder transformer layers.
   """
 
-  def __init__(
-      self,
-      pooled_output_dim,
-      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=0.02),
-      embedding_cls=None,
-      embedding_cfg=None,
-      embedding_data=None,
-      num_hidden_instances=1,
-      hidden_cls=layers.Transformer,
-      hidden_cfg=None,
-      return_all_layer_outputs=False,
-      **kwargs):
+  def __init__(self,
+               pooled_output_dim,
+               pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+                   stddev=0.02),
+               embedding_cls=None,
+               embedding_cfg=None,
+               embedding_data=None,
+               num_hidden_instances=1,
+               hidden_cls=layers.Transformer,
+               hidden_cfg=None,
+               return_all_layer_outputs=False,
+               **kwargs):
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
@@ -131,17 +129,11 @@ class EncoderScaffold(tf.keras.Model):
       self._embedding_network = None
       seq_length = embedding_cfg.get('seq_length', None)
       word_ids = tf.keras.layers.Input(
-          shape=(seq_length,),
-          dtype=tf.int32,
-          name='input_word_ids')
+          shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
       mask = tf.keras.layers.Input(
-          shape=(seq_length,),
-          dtype=tf.int32,
-          name='input_mask')
+          shape=(seq_length,), dtype=tf.int32, name='input_mask')
       type_ids = tf.keras.layers.Input(
-          shape=(seq_length,),
-          dtype=tf.int32,
-          name='input_type_ids')
+          shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
       inputs = [word_ids, mask, type_ids]
 
       self._embedding_layer = layers.OnDeviceEmbedding(
@@ -215,20 +207,13 @@ class EncoderScaffold(tf.keras.Model):
 
   def get_config(self):
     config_dict = {
-        'num_hidden_instances':
-            self._num_hidden_instances,
-        'pooled_output_dim':
-            self._pooled_output_dim,
-        'pooler_layer_initializer':
-            self._pooler_layer_initializer,
-        'embedding_cls':
-            self._embedding_network,
-        'embedding_cfg':
-            self._embedding_cfg,
-        'hidden_cfg':
-            self._hidden_cfg,
-        'return_all_layer_outputs':
-            self._return_all_layer_outputs,
+        'num_hidden_instances': self._num_hidden_instances,
+        'pooled_output_dim': self._pooled_output_dim,
+        'pooler_layer_initializer': self._pooler_layer_initializer,
+        'embedding_cls': self._embedding_network,
+        'embedding_cfg': self._embedding_cfg,
+        'hidden_cfg': self._hidden_cfg,
+        'return_all_layer_outputs': self._return_all_layer_outputs,
     }
     if inspect.isclass(self._hidden_cls):
       config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
...
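To make the embedding_cls contract from the EncoderScaffold docstring concrete, here is a toy embedding network. It is a minimal sketch, not the library's default network: it takes word ids and a padding mask and returns (1) an embeddings tensor of shape [batch_size, seq_length, hidden_size] and (2) a [batch_size, seq_length, seq_length] attention mask. Sizes, input names, and layer choices are illustrative only.

import tensorflow as tf

def build_toy_embedding_network(vocab_size=100, seq_length=16, hidden_size=32):
  # Inputs mirror the default input names used elsewhere in this change.
  word_ids = tf.keras.layers.Input(
      shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
  mask = tf.keras.layers.Input(
      shape=(seq_length,), dtype=tf.int32, name='input_mask')
  # (1) Embeddings tensor of shape [batch_size, seq_length, hidden_size].
  embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size)(word_ids)
  # (2) Broadcast the per-token padding mask into [batch, seq_length, seq_length].
  attention_mask = tf.keras.layers.Lambda(
      lambda m: tf.cast(m[:, tf.newaxis, :], tf.float32) *
      tf.cast(m[:, :, tf.newaxis], tf.float32))(mask)
  return tf.keras.Model(
      inputs=[word_ids, mask], outputs=[embeddings, attention_mask])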
@@ -66,9 +66,9 @@ class TransformerEncoder(tf.keras.Model):
       target sequence of the last transformer layer. `None` means the entire
       target sequence will attend to the source sequence, which yeilds the full
       output.
-    embedding_width: The width of the word embeddings. If the embedding width
-      is not equal to hidden size, embedding parameters will be factorized into
-      two matrices in the shape of ['vocab_size', 'embedding_width'] and
+    embedding_width: The width of the word embeddings. If the embedding width is
+      not equal to hidden size, embedding parameters will be factorized into two
+      matrices in the shape of ['vocab_size', 'embedding_width'] and
       ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
       smaller than 'hidden_size').
     embedding_layer: The word embedding layer. `None` means we will create a new
@@ -159,8 +159,7 @@ class TransformerEncoder(tf.keras.Model):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32)(embeddings))
-    embeddings = (
-        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
+    embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
 
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
...
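The embedding_width documentation above describes an ALBERT-style factorization of the embedding table. Below is a small illustrative sketch of the idea; a plain Dense projection stands in for whatever projection the encoder actually uses, and the sizes are made up.

import tensorflow as tf

vocab_size, hidden_size, embedding_width = 30000, 768, 128

# Narrow lookup table followed by a projection up to hidden_size.
embedding = tf.keras.layers.Embedding(vocab_size, embedding_width)
projection = tf.keras.layers.Dense(hidden_size, use_bias=False)

word_ids = tf.constant([[5, 17, 2, 9, 0, 0, 0, 0]], dtype=tf.int32)
hidden = projection(embedding(word_ids))  # shape: (1, 8, 768)

# Parameter count: factorized vs. a full-width [vocab_size, hidden_size] table.
factorized = vocab_size * embedding_width + embedding_width * hidden_size
full = vocab_size * hidden_size
print(factorized, full)  # 3938304 23040000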
@@ -133,8 +133,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         num_layers=3,
         type_vocab_size=num_types,
         output_range=output_range)
-    self.assertTrue(
-        test_network._position_embedding_layer._use_dynamic_slicing)
+    self.assertTrue(test_network._position_embedding_layer._use_dynamic_slicing)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -31,8 +31,7 @@ class BeamSearchHelperTests(tf.test.TestCase):
     y = tf.constant(4.0)
     x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
     shape = beam_search._get_shape_keep_last_dim(x)
-    self.assertAllEqual([None, None, None, 5],
-                        shape.as_list())
+    self.assertAllEqual([None, None, None, 5], shape.as_list())
 
   def test_flatten_beam_dim(self):
     x = tf.ones([7, 4, 2, 5])
@@ -55,22 +54,18 @@ class BeamSearchHelperTests(tf.test.TestCase):
     # [20 21 22 23]]]
     y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
-    self.assertAllEqual([[[4, 5, 6, 7],
-                          [8, 9, 10, 11]],
-                         [[12, 13, 14, 15],
-                          [20, 21, 22, 23]]],
-                        y)
+    self.assertAllEqual(
+        [[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
+        y)
 
   def test_gather_topk_beams(self):
     x = tf.reshape(tf.range(24), [2, 3, 4])
     x_scores = [[0, 1, 1], [1, 0, 1]]
     y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
-    self.assertAllEqual([[[4, 5, 6, 7],
-                          [8, 9, 10, 11]],
-                         [[12, 13, 14, 15],
-                          [20, 21, 22, 23]]],
-                        y)
+    self.assertAllEqual(
+        [[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
+        y)
 
 
 if __name__ == "__main__":
...
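The _gather_beams and _gather_topk_beams tests above check per-batch beam selection from a [batch_size, beam_size, ...] tensor. The helper below is a small sketch, not the library's implementation, that reproduces the gathered values the tests expect.

import tensorflow as tf

def gather_beams_sketch(tensor, beam_indices):
  # tensor: [batch_size, beam_size, ...]; beam_indices: [batch_size, new_beam_size]
  return tf.gather(tensor, beam_indices, axis=1, batch_dims=1)

x = tf.reshape(tf.range(24), [2, 3, 4])
y = gather_beams_sketch(x, tf.constant([[1, 2], [0, 2]]))
print(y.numpy().tolist())
# [[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]]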
@@ -99,7 +99,6 @@ NHNET_CONFIG = {
     "pad_token_id": 0,
     "end_token_id": 102,
     "start_token_id": 101,
-
     "init_from_bert2bert": True,
 }
...
@@ -21,6 +21,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os
 
+# Import libraries
 from absl import logging
 import numpy as np
 import tensorflow as tf
...
@@ -583,7 +583,6 @@ def create_model(model_type: Text,
   elif model_type == "nhnet":
     return create_nhnet_model(params, init_checkpoint=init_checkpoint)
   elif "transformer" in model_type:
-    return create_transformer_model(
-        params, init_checkpoint=init_checkpoint)
+    return create_transformer_model(params, init_checkpoint=init_checkpoint)
   else:
     raise KeyError("The model type is not defined: %s" % model_type)