Commit 1722b691 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Refactor: move transformer/model to nlp/transformer.

PiperOrigin-RevId: 286325224
parent b6161f67
......@@ -19,7 +19,7 @@ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam
"""
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.python.util import nest
......
......@@ -14,13 +14,9 @@
# ==============================================================================
"""Test beam search helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import beam_search
from official.nlp.transformer import beam_search_v1 as beam_search
class BeamSearchHelperTests(tf.test.TestCase):
......
......@@ -18,26 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.transformer.model import model_utils
from official.utils.misc import keras_utils
from official.nlp.transformer import model_utils
NEG_INF = -1e9
class ModelUtilsTest(tf.test.TestCase):
def setUp(self):
super(ModelUtilsTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_get_padding(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
padding = model_utils.get_padding(x, padding_value=0)
with self.session() as sess:
padding = sess.run(padding)
self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
padding)
......@@ -47,8 +39,6 @@ class ModelUtilsTest(tf.test.TestCase):
bias = model_utils.get_padding_bias(x)
bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5])
with self.session() as sess:
flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape))
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
......@@ -59,8 +49,6 @@ class ModelUtilsTest(tf.test.TestCase):
def test_get_decoder_self_attention_bias(self):
length = 5
bias = model_utils.get_decoder_self_attention_bias(length)
with self.session() as sess:
bias = sess.run(bias)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
......@@ -71,4 +59,5 @@ class ModelUtilsTest(tf.test.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -20,7 +20,6 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
from official.r1.utils import tpu as tpu_utils
......
......@@ -24,11 +24,11 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.nlp.transformer import beam_search_v1 as beam_search
from official.nlp.transformer import model_utils
from official.r1.transformer import attention_layer
from official.r1.transformer import embedding_layer
from official.r1.transformer import ffn_layer
from official.transformer.model import beam_search
from official.transformer.model import model_utils
from official.transformer.utils.tokenizer import EOS_ID
_NEG_INF = -1e9
......
......@@ -32,6 +32,7 @@ from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.nlp.transformer import model_params
from official.r1.utils import export
from official.r1.utils import tpu as tpu_util
from official.r1.transformer import translate
......@@ -39,7 +40,6 @@ from official.r1.transformer import transformer
from official.r1.transformer import dataset
from official.r1.transformer import schedule
from official.transformer import compute_bleu
from official.transformer.model import model_params
from official.transformer.utils import metrics
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared legacy Transformer modules into this module."""
from official.r1.transformer import transformer
from official.r1.transformer import ffn_layer
from official.r1.transformer import embedding_layer
from official.r1.transformer import attention_layer
......@@ -17,7 +17,7 @@
import tensorflow as tf
from official.transformer.model import beam_search as v1
from official.nlp.transformer import beam_search_v1 as v1
from official.transformer.v2 import misc
_StateKeys = v1._StateKeys # pylint: disable=protected-access
......
......@@ -26,7 +26,7 @@ import tensorflow as tf
# different TF versions are fixed.
from tensorflow.python import tf2 as tf2_internal
from official.transformer.model import model_params
from official.nlp.transformer import model_params
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
......
......@@ -23,7 +23,7 @@ from __future__ import print_function
import tensorflow as tf
from official.transformer.model import model_utils
from official.nlp.transformer import model_utils
from official.transformer.utils.tokenizer import EOS_ID
from official.transformer.v2 import attention_layer
from official.transformer.v2 import beam_search
......
......@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
from official.transformer.model import model_params
from official.nlp.transformer import model_params
from official.transformer.v2 import transformer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment