Commit 1722b691 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Refactor: move transformer/model to nlp/transformer.

PiperOrigin-RevId: 286325224
parent b6161f67
...@@ -19,7 +19,7 @@ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam ...@@ -19,7 +19,7 @@ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam
""" """
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from tensorflow.python.util import nest from tensorflow.python.util import nest
......
...@@ -14,13 +14,9 @@ ...@@ -14,13 +14,9 @@
# ============================================================================== # ==============================================================================
"""Test beam search helper methods.""" """Test beam search helper methods."""
from __future__ import absolute_import import tensorflow.compat.v1 as tf
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order from official.nlp.transformer import beam_search_v1 as beam_search
from official.transformer.model import beam_search
class BeamSearchHelperTests(tf.test.TestCase): class BeamSearchHelperTests(tf.test.TestCase):
......
...@@ -18,26 +18,18 @@ from __future__ import absolute_import ...@@ -18,26 +18,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.transformer.model import model_utils from official.nlp.transformer import model_utils
from official.utils.misc import keras_utils
NEG_INF = -1e9 NEG_INF = -1e9
class ModelUtilsTest(tf.test.TestCase): class ModelUtilsTest(tf.test.TestCase):
def setUp(self):
super(ModelUtilsTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_get_padding(self): def test_get_padding(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]]) x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
padding = model_utils.get_padding(x, padding_value=0) padding = model_utils.get_padding(x, padding_value=0)
with self.session() as sess:
padding = sess.run(padding)
self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]], self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
padding) padding)
...@@ -47,8 +39,6 @@ class ModelUtilsTest(tf.test.TestCase): ...@@ -47,8 +39,6 @@ class ModelUtilsTest(tf.test.TestCase):
bias = model_utils.get_padding_bias(x) bias = model_utils.get_padding_bias(x)
bias_shape = tf.shape(bias) bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5]) flattened_bias = tf.reshape(bias, [3, 5])
with self.session() as sess:
flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape))
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0], self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
[0, 0, NEG_INF, NEG_INF, NEG_INF], [0, 0, NEG_INF, NEG_INF, NEG_INF],
...@@ -59,8 +49,6 @@ class ModelUtilsTest(tf.test.TestCase): ...@@ -59,8 +49,6 @@ class ModelUtilsTest(tf.test.TestCase):
def test_get_decoder_self_attention_bias(self): def test_get_decoder_self_attention_bias(self):
length = 5 length = 5
bias = model_utils.get_decoder_self_attention_bias(length) bias = model_utils.get_decoder_self_attention_bias(length)
with self.session() as sess:
bias = sess.run(bias)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF], self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF], [0, 0, NEG_INF, NEG_INF, NEG_INF],
...@@ -71,4 +59,5 @@ class ModelUtilsTest(tf.test.TestCase): ...@@ -71,4 +59,5 @@ class ModelUtilsTest(tf.test.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
...@@ -20,7 +20,6 @@ from __future__ import print_function ...@@ -20,7 +20,6 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
from official.r1.utils import tpu as tpu_utils from official.r1.utils import tpu as tpu_utils
......
...@@ -24,11 +24,11 @@ from __future__ import print_function ...@@ -24,11 +24,11 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.nlp.transformer import beam_search_v1 as beam_search
from official.nlp.transformer import model_utils
from official.r1.transformer import attention_layer from official.r1.transformer import attention_layer
from official.r1.transformer import embedding_layer from official.r1.transformer import embedding_layer
from official.r1.transformer import ffn_layer from official.r1.transformer import ffn_layer
from official.transformer.model import beam_search
from official.transformer.model import model_utils
from official.transformer.utils.tokenizer import EOS_ID from official.transformer.utils.tokenizer import EOS_ID
_NEG_INF = -1e9 _NEG_INF = -1e9
......
...@@ -32,6 +32,7 @@ from absl import flags ...@@ -32,6 +32,7 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.nlp.transformer import model_params
from official.r1.utils import export from official.r1.utils import export
from official.r1.utils import tpu as tpu_util from official.r1.utils import tpu as tpu_util
from official.r1.transformer import translate from official.r1.transformer import translate
...@@ -39,7 +40,6 @@ from official.r1.transformer import transformer ...@@ -39,7 +40,6 @@ from official.r1.transformer import transformer
from official.r1.transformer import dataset from official.r1.transformer import dataset
from official.r1.transformer import schedule from official.r1.transformer import schedule
from official.transformer import compute_bleu from official.transformer import compute_bleu
from official.transformer.model import model_params
from official.transformer.utils import metrics from official.transformer.utils import metrics
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared legacy Transformer modules into this module."""
from official.r1.transformer import transformer
from official.r1.transformer import ffn_layer
from official.r1.transformer import embedding_layer
from official.r1.transformer import attention_layer
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import tensorflow as tf import tensorflow as tf
from official.transformer.model import beam_search as v1 from official.nlp.transformer import beam_search_v1 as v1
from official.transformer.v2 import misc from official.transformer.v2 import misc
_StateKeys = v1._StateKeys # pylint: disable=protected-access _StateKeys = v1._StateKeys # pylint: disable=protected-access
......
...@@ -26,7 +26,7 @@ import tensorflow as tf ...@@ -26,7 +26,7 @@ import tensorflow as tf
# different TF versions are fixed. # different TF versions are fixed.
from tensorflow.python import tf2 as tf2_internal from tensorflow.python import tf2 as tf2_internal
from official.transformer.model import model_params from official.nlp.transformer import model_params
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
......
...@@ -23,7 +23,7 @@ from __future__ import print_function ...@@ -23,7 +23,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.transformer.model import model_utils from official.nlp.transformer import model_utils
from official.transformer.utils.tokenizer import EOS_ID from official.transformer.utils.tokenizer import EOS_ID
from official.transformer.v2 import attention_layer from official.transformer.v2 import attention_layer
from official.transformer.v2 import beam_search from official.transformer.v2 import beam_search
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.transformer.model import model_params from official.nlp.transformer import model_params
from official.transformer.v2 import transformer from official.transformer.v2 import transformer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment