Commit b6780b59 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Address a few legacy issues in the test.

PiperOrigin-RevId: 458268967
parent 08a9f1f8
......@@ -19,8 +19,6 @@ import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
......@@ -45,13 +43,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
self.skipTest('Disable the test for now since it imports '
'keras.testing_utils, will reenable this test after we '
'fix the b/184578869')
# TODO(scottzhu): Reenable after fix b/184578869
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
layer_test(
tf.keras.__internal__.utils.layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
......@@ -64,9 +58,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
tf.keras.utils.set_random_seed(0)
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
tf.keras.utils.set_random_seed(0)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment