Commit 7f926353 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Refactor] TF models: move all contents of transformer to nlp/transformer

PiperOrigin-RevId: 294997928
parent 91c681af
@@ -23,8 +23,8 @@ import time
 from absl import flags
 import tensorflow as tf
-from official.transformer.v2 import misc
-from official.transformer.v2 import transformer_main as transformer_main
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main as transformer_main
 from official.utils.flags import core as flags_core
 from official.utils.testing import benchmark_wrappers
 from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
...
@@ -30,7 +30,7 @@ model.
 # https://github.com/tensorflow/models/tree/master/official#requirements
 export PYTHONPATH="$PYTHONPATH:/path/to/models"
-cd /path/to/models/official/transformer/v2
+cd /path/to/models/official/nlp/transformer
 # Export variables
 PARAM_SET=big
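The hunk above moves the README's working directory to the new package location. A minimal sketch of verifying that the relocated package resolves once the repo is on the Python path; "/path/to/models" is the same placeholder the README uses:

```python
# Sanity check that the new import path introduced by this commit resolves.
# "/path/to/models" is a placeholder for the models repo checkout, as above.
import sys
sys.path.append("/path/to/models")

# New location; the pre-refactor path was official.transformer.v2.
from official.nlp.transformer import misc  # noqa: F401
```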
@@ -94,7 +94,7 @@ tensorboard --logdir=$MODEL_DIR
 2. ### Model training and evaluation
-   [transformer_main.py](v2/transformer_main.py) creates a Transformer keras model,
+   [transformer_main.py](transformer_main.py) creates a Transformer keras model,
    and trains it using keras model.fit().
 Users need to adjust `batch_size` and `num_gpus` to get good performance
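The README text above says the model is trained through the standard Keras loop. A rough illustration of that compile()/fit() pattern with a toy model (not the actual TransformerTask code):

```python
import tensorflow as tf

# Toy stand-in for the Transformer: the point is only that training goes
# through the ordinary Keras compile()/fit() path.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

x = tf.random.normal([256, 16])
y = tf.random.uniform([256], maxval=10, dtype=tf.int32)
model.fit(x, y, batch_size=32, epochs=1)  # batch_size is tuned per the README
```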
@@ -199,16 +199,16 @@ tensorboard --logdir=$MODEL_DIR
 A brief look at each component in the code:

 ### Model Definition
-* [transformer.py](v2/transformer.py): Defines a tf.keras.Model: `Transformer`.
-* [embedding_layer.py](v2/embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
-* [attention_layer.py](v2/attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
-* [ffn_layer.py](v2/ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
+* [transformer.py](transformer.py): Defines a tf.keras.Model: `Transformer`.
+* [embedding_layer.py](embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
+* [attention_layer.py](attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
+* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.

 Other files:
-* [beam_search.py](v2/beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
+* [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.

 ### Model Trainer
-[transformer_main.py](v2/transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.
+[transformer_main.py](transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.

 ### Test dataset
 The [newstest2014 files](https://storage.googleapis.com/tf-perf-public/official_transformer/test_data/newstest2014.tgz)
...
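The embedding_layer.py bullet above notes that the embedding weights are reused to compute the pre-softmax logits from the decoder output. A condensed sketch of that weight-tying idea (not the file's actual class, which also handles padding masks and scaling):

```python
import tensorflow as tf

class TiedEmbedding(tf.keras.layers.Layer):
  """Embedding table whose weights double as the output projection."""

  def __init__(self, vocab_size, hidden_size):
    super().__init__()
    self.shared_weights = self.add_weight(
        name="weights", shape=[vocab_size, hidden_size])

  def embed(self, ids):
    # [batch, length] -> [batch, length, hidden]
    return tf.gather(self.shared_weights, ids)

  def linear(self, decoder_outputs):
    # Reuse the same table, transposed, to get pre-softmax logits:
    # [batch, length, hidden] x [hidden, vocab] -> [batch, length, vocab]
    return tf.matmul(decoder_outputs, self.shared_weights, transpose_b=True)
```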
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Beam search in TF v2.
-"""
+"""Beam search in TF v2."""
 import tensorflow as tf
 from official.nlp.transformer import beam_search_v1 as v1
-from official.transformer.v2 import misc
+from official.nlp.transformer import misc
 _StateKeys = v1._StateKeys  # pylint: disable=protected-access
...
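The module above delegates to the v1 beam search implementation. For orientation, a self-contained toy of the underlying idea: keep the beam_size highest-scoring prefixes at each step, ranked by log-probability under the length normalization ((5 + len) / 6) ** alpha used by this family of Transformer implementations. The scoring function here is a stand-in, not the module's actual API:

```python
import numpy as np

def toy_beam_search(log_probs_fn, vocab_size, beam_size=4, steps=5, alpha=0.6):
  """Keeps the beam_size best prefixes under length-normalized log-prob."""
  beams = [([], 0.0)]  # (token prefix, total log-probability)
  for _ in range(steps):
    candidates = []
    for prefix, score in beams:
      log_probs = log_probs_fn(prefix)  # [vocab_size] log-probabilities
      for tok in range(vocab_size):
        candidates.append((prefix + [tok], score + log_probs[tok]))
    # Length penalty from the "Attention Is All You Need" lineage.
    def normalized(cand):
      return cand[1] / (((5.0 + len(cand[0])) / 6.0) ** alpha)
    beams = sorted(candidates, key=normalized, reverse=True)[:beam_size]
  return beams

# Toy model that slightly prefers token 0, regardless of the prefix.
def log_probs_fn(prefix):
  p = np.full(8, 0.1)
  p[0] = 0.3
  return np.log(p / p.sum())

print(toy_beam_search(log_probs_fn, vocab_size=8)[0])
```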
@@ -33,8 +33,8 @@ from absl import flags
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
-from official.transformer.utils import metrics
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import metrics
+from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
...
@@ -16,9 +16,9 @@
 import tempfile
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf
-from official.transformer import compute_bleu
+from official.nlp.transformer import compute_bleu
 class ComputeBleuTest(tf.test.TestCase):
...
@@ -31,7 +31,7 @@ from absl import logging
 import tensorflow.compat.v1 as tf
 # pylint: enable=g-bad-import-order
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
 # Data sources for training/evaluating the transformer translation model.
@@ -88,7 +88,7 @@ VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE
 _PREFIX = "wmt32k"
 _TRAIN_TAG = "train"
 _EVAL_TAG = "dev"  # Following WMT and Tensor2Tensor conventions, in which the
                    # evaluation datasets are tagged as "dev" for development.
 # Number of files to split train and evaluation data
 _TRAIN_SHARDS = 100
...
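The constants in the hunk above (_PREFIX, _TRAIN_TAG, _TRAIN_SHARDS) drive how data_download.py names its sharded TFRecord files. A sketch of the conventional naming they imply; the exact "%.5d-of-%.5d" format string is an assumption here, not quoted from the file:

```python
# Hypothetical reconstruction of the shard naming implied by the constants
# above; the zero-padded "-of-" convention is assumed, not quoted.
_PREFIX = "wmt32k"
_TRAIN_TAG = "train"
_TRAIN_SHARDS = 100

def shard_filename(prefix, tag, shard, total_shards):
  return "%s-%s-%.5d-of-%.5d" % (prefix, tag, shard, total_shards)

print(shard_filename(_PREFIX, _TRAIN_TAG, 1, _TRAIN_SHARDS))
# -> wmt32k-train-00001-of-00100
```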
@@ -57,7 +57,7 @@ import os
 from absl import logging
 import tensorflow as tf
-from official.transformer.v2 import misc
+from official.nlp.transformer import misc
 from official.utils.misc import model_helpers
 # Buffer size for reading records from a TFRecord file. Each training file is
...
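The trailing comment above refers to the read buffer data_pipeline.py passes when opening TFRecord shards. A minimal sketch of that knob on the public tf.data API; the 8 MB value is a placeholder, not necessarily the module's constant:

```python
import tensorflow as tf

# buffer_size is in bytes; a larger read buffer amortizes filesystem reads
# when scanning large training shards. 8 MB here is a placeholder value.
_READ_RECORD_BUFFER = 8 * 1000 * 1000

def load_records(filename):
  """Reads one TFRecord shard as a dataset of serialized examples."""
  return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)

dataset = load_records("/tmp/wmt32k-train-00001-of-00100")  # hypothetical path
```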
@@ -22,14 +22,13 @@ from __future__ import division
 from __future__ import print_function
 import tensorflow as tf
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import beam_search
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
 from official.nlp.transformer import model_utils
-from official.transformer.utils.tokenizer import EOS_ID
+from official.nlp.transformer.utils.tokenizer import EOS_ID
-from official.transformer.v2 import attention_layer
-from official.transformer.v2 import beam_search
-from official.transformer.v2 import embedding_layer
-from official.transformer.v2 import ffn_layer
-from official.transformer.v2 import metrics
 # Disable the not-callable lint error, since it claims many objects are not
...
@@ -20,10 +20,10 @@ from __future__ import print_function
 import tensorflow as tf
-from official.transformer.v2 import attention_layer
-from official.transformer.v2 import embedding_layer
-from official.transformer.v2 import ffn_layer
-from official.transformer.v2 import metrics
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
 class TransformerLayersTest(tf.test.TestCase):
...
@@ -31,14 +31,14 @@ from absl import logging
 import tensorflow as tf
 # pylint: disable=g-bad-import-order
-from official.transformer import compute_bleu
-from official.transformer.utils import tokenizer
-from official.transformer.v2 import data_pipeline
-from official.transformer.v2 import metrics
-from official.transformer.v2 import misc
-from official.transformer.v2 import optimizer
-from official.transformer.v2 import transformer
-from official.transformer.v2 import translate
+from official.nlp.transformer import compute_bleu
+from official.nlp.transformer.utils import tokenizer
+from official.nlp.transformer import data_pipeline
+from official.nlp.transformer import metrics
+from official.nlp.transformer import misc
+from official.nlp.transformer import optimizer
+from official.nlp.transformer import transformer
+from official.nlp.transformer import translate
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import keras_utils
...
@@ -26,12 +26,10 @@ import unittest
 from absl import flags
 from absl.testing import flagsaver
 import tensorflow as tf
-from official.transformer.v2 import misc
-from official.transformer.v2 import transformer_main
-from official.utils.misc import keras_utils
 from tensorflow.python.eager import context  # pylint: disable=ungrouped-imports
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main
+from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
 FIXED_TIMESTAMP = 'my_time_stamp'
...
@@ -21,7 +21,7 @@ from __future__ import print_function
 import tensorflow as tf
 from official.nlp.transformer import model_params
-from official.transformer.v2 import transformer
+from official.nlp.transformer import transformer
 class TransformerV2Test(tf.test.TestCase):
...
@@ -21,7 +21,7 @@ from __future__ import print_function
 import numpy as np
 import tensorflow as tf
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import tokenizer
 _EXTRA_DECODE_LENGTH = 100
 _BEAM_SIZE = 4
...