Commit 6d259f7f authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix mobile_bert_encoder_test.py

PiperOrigin-RevId: 329003089
parent 02f8d387
...@@ -18,7 +18,18 @@ import numpy as np ...@@ -18,7 +18,18 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.modeling.networks import mobile_bert_encoder from official.nlp.modeling.networks import mobile_bert_encoder
from official.nlp.projects.mobilebert import utils
def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
  """Generate consistent fake integer input sequences.

  Seeds the global numpy RNG so repeated calls with the same arguments
  return identical data, which keeps the tests deterministic.

  Args:
    batch_size: Number of sequences to generate.
    seq_len: Number of token ids per sequence.
    vocab_size: Exclusive upper bound for the random token ids.
    seed: Seed passed to `np.random.seed` for reproducibility.

  Returns:
    A `[batch_size, seq_len]` numpy array of ints in `[0, vocab_size)`.
  """
  np.random.seed(seed)
  # Draw one id at a time, in the same order as a nested loop would, so the
  # seeded value sequence is unchanged from the original implementation.
  fake_input = np.asarray(
      [[np.random.randint(0, vocab_size) for _ in range(seq_len)]
       for _ in range(batch_size)])
  return fake_input
class ModelingTest(parameterized.TestCase, tf.test.TestCase): class ModelingTest(parameterized.TestCase, tf.test.TestCase):
...@@ -48,8 +59,7 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -48,8 +59,7 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
expected_shape = [2, 3, 4] expected_shape = [2, 3, 4]
self.assertListEqual(output_shape, expected_shape, msg=None) self.assertListEqual(output_shape, expected_shape, msg=None)
@parameterized.named_parameters( @parameterized.named_parameters(('with_kq_shared_bottleneck', False),
('with_kq_shared_bottleneck', False),
('without_kq_shared_bottleneck', True)) ('without_kq_shared_bottleneck', True))
def test_transfomer_kq_shared_bottleneck(self, is_kq_shared): def test_transfomer_kq_shared_bottleneck(self, is_kq_shared):
feature = tf.random.uniform([2, 3, 512]) feature = tf.random.uniform([2, 3, 512])
...@@ -62,12 +72,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -62,12 +72,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
def test_transfomer_with_mask(self): def test_transfomer_with_mask(self):
feature = tf.random.uniform([2, 3, 512]) feature = tf.random.uniform([2, 3, 512])
input_mask = [[[0., 0., 1.], input_mask = [[[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
[0., 0., 1.], [[0., 1., 1.], [0., 1., 1.], [0., 1., 1.]]]
[0., 0., 1.]],
[[0., 1., 1.],
[0., 1., 1.],
[0., 1., 1.]]]
input_mask = np.asarray(input_mask) input_mask = np.asarray(input_mask)
layer = mobile_bert_encoder.TransformerLayer() layer = mobile_bert_encoder.TransformerLayer()
output = layer(feature, input_mask) output = layer(feature, input_mask)
...@@ -83,8 +89,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -83,8 +89,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
num_attention_heads=num_attention_heads) num_attention_heads=num_attention_heads)
_, attention_score = layer(feature, return_attention_scores=True) _, attention_score = layer(feature, return_attention_scores=True)
expected_shape = [2, num_attention_heads, sequence_length, sequence_length] expected_shape = [2, num_attention_heads, sequence_length, sequence_length]
self.assertListEqual(attention_score.shape.as_list(), expected_shape, self.assertListEqual(
msg=None) attention_score.shape.as_list(), expected_shape, msg=None)
@parameterized.named_parameters( @parameterized.named_parameters(
('default_setting', 'relu', True, 'no_norm', False), ('default_setting', 'relu', True, 'no_norm', False),
...@@ -153,22 +159,19 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -153,22 +159,19 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
layer_out_tensor, pooler_out_tensor = test_network([word_ids, layer_out_tensor, pooler_out_tensor = test_network(
mask, type_ids]) [word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], model = tf.keras.Model([word_ids, mask, type_ids],
[layer_out_tensor, pooler_out_tensor]) [layer_out_tensor, pooler_out_tensor])
input_seq = utils.generate_fake_input(batch_size=1, input_seq = generate_fake_input(
seq_len=sequence_length, batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
vocab_size=vocab_size) input_mask = generate_fake_input(
input_mask = utils.generate_fake_input(batch_size=1, batch_size=1, seq_len=sequence_length, vocab_size=2)
seq_len=sequence_length, token_type = generate_fake_input(
vocab_size=2) batch_size=1, seq_len=sequence_length, vocab_size=2)
token_type = utils.generate_fake_input(batch_size=1, layer_output, pooler_output = model.predict(
seq_len=sequence_length, [input_seq, input_mask, token_type])
vocab_size=2)
layer_output, pooler_output = model.predict([input_seq, input_mask,
token_type])
layer_output_shape = [1, sequence_length, hidden_size] layer_output_shape = [1, sequence_length, hidden_size]
self.assertAllEqual(layer_output.shape, layer_output_shape) self.assertAllEqual(layer_output.shape, layer_output_shape)
...@@ -192,21 +195,18 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -192,21 +195,18 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
layer_out_tensor, pooler_out_tensor, attention_out_tensor = test_network( layer_out_tensor, pooler_out_tensor, attention_out_tensor = test_network(
[word_ids, mask, type_ids]) [word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], model = tf.keras.Model(
[layer_out_tensor, pooler_out_tensor, [word_ids, mask, type_ids],
attention_out_tensor]) [layer_out_tensor, pooler_out_tensor, attention_out_tensor])
input_seq = utils.generate_fake_input(batch_size=1, input_seq = generate_fake_input(
seq_len=sequence_length, batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
vocab_size=vocab_size) input_mask = generate_fake_input(
input_mask = utils.generate_fake_input(batch_size=1, batch_size=1, seq_len=sequence_length, vocab_size=2)
seq_len=sequence_length, token_type = generate_fake_input(
vocab_size=2) batch_size=1, seq_len=sequence_length, vocab_size=2)
token_type = utils.generate_fake_input(batch_size=1, _, _, attention_score_output = model.predict(
seq_len=sequence_length, [input_seq, input_mask, token_type])
vocab_size=2)
_, _, attention_score_output = model.predict([input_seq, input_mask,
token_type])
self.assertLen(attention_score_output, num_blocks) self.assertLen(attention_score_output, num_blocks)
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -218,8 +218,7 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -218,8 +218,7 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
mobilebert_encoder = mobile_bert_encoder.MobileBERTEncoder( mobilebert_encoder = mobile_bert_encoder.MobileBERTEncoder(
word_vocab_size=100, hidden_size=hidden_size) word_vocab_size=100, hidden_size=hidden_size)
num_classes = 5 num_classes = 5
classifier = task(network=mobilebert_encoder, classifier = task(network=mobilebert_encoder, num_classes=num_classes)
num_classes=num_classes)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -227,5 +226,6 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase): ...@@ -227,5 +226,6 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
prediction = classifier([word_ids, mask, type_ids]) prediction = classifier([word_ids, mask, type_ids])
self.assertAllEqual(prediction.shape.as_list(), prediction_shape) self.assertAllEqual(prediction.shape.as_list(), prediction_shape)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment