Commit a12cec09 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300477605
parent e8afeaee
......@@ -200,5 +200,4 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -20,7 +20,7 @@ from __future__ import print_function
import re
import tensorflow as tf
import tensorflow.compat.v2 as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
......
......@@ -59,5 +59,4 @@ class ModelUtilsTest(tf.test.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -94,5 +94,4 @@ class TransformerLayersTest(tf.test.TestCase):
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
......@@ -474,7 +474,6 @@ def main(_):
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
logging.set_verbosity(logging.INFO)
misc.define_transformer_flags()
app.run(main)
......@@ -187,5 +187,4 @@ class TransformerTaskTest(tf.test.TestCase):
if __name__ == '__main__':
tf.compat.v1.enable_v2_behavior()
tf.test.main()
......@@ -65,5 +65,4 @@ class TransformerV2Test(tf.test.TestCase):
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
......@@ -26,7 +26,7 @@ import unicodedata
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import tensorflow.compat.v2 as tf
PAD = "<pad>"
PAD_ID = 0
......
......@@ -17,7 +17,7 @@
import collections
import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow.compat.v2 as tf # pylint: disable=g-bad-import-order
from official.nlp.transformer.utils import tokenizer
......
......@@ -455,5 +455,5 @@ def main(_):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.enable_v2_behavior()
app.run(main)
......@@ -193,5 +193,4 @@ def main(unused_argv):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
......@@ -153,5 +153,4 @@ def main(unused_argv):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
......@@ -301,5 +301,4 @@ def main(unused_argv):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
......@@ -49,5 +49,4 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
self.assertAllClose(pos_emb, target)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment