Commit 7f926353 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Refactor] TF models: move all contents of transformer to nlp/transformer

PiperOrigin-RevId: 294997928
parent 91c681af
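For downstream code, the change amounts to updating import paths; a minimal before/after sketch using `misc` as one example module taken from the hunks below:

```python
# Old location (before this commit):
#   from official.transformer.v2 import misc
# New location (after this commit):
from official.nlp.transformer import misc
```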
@@ -23,8 +23,8 @@ import time
from absl import flags
import tensorflow as tf
-from official.transformer.v2 import misc
-from official.transformer.v2 import transformer_main as transformer_main
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core
from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
@@ -30,7 +30,7 @@ model.
# https://github.com/tensorflow/models/tree/master/official#requirements
export PYTHONPATH="$PYTHONPATH:/path/to/models"
-cd /path/to/models/official/transformer/v2
+cd /path/to/models/official/nlp/transformer
# Export variables
PARAM_SET=big
@@ -94,7 +94,7 @@ tensorboard --logdir=$MODEL_DIR
2. ### Model training and evaluation
-[transformer_main.py](v2/transformer_main.py) creates a Transformer keras model,
+[transformer_main.py](transformer_main.py) creates a Transformer keras model,
and trains it using keras model.fit().
Users need to adjust `batch_size` and `num_gpus` to get good performance
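For orientation, here is a rough sketch of driving the trainer from its new package path. It is illustrative only: `misc.define_transformer_flags` and `transformer_main.TransformerTask` are assumptions based on the modules this commit touches, and the flag values are placeholders.

```python
from absl import flags

from official.nlp.transformer import misc              # flag definitions (assumed API)
from official.nlp.transformer import transformer_main  # TransformerTask (assumed API)

misc.define_transformer_flags()                         # registers --param_set, --data_dir, ...
FLAGS = flags.FLAGS
FLAGS([                                                 # parse illustrative flag values
    "sketch",
    "--param_set=big",
    "--data_dir=/tmp/translate_ende",
    "--model_dir=/tmp/transformer_model",
    "--train_steps=10",
])

task = transformer_main.TransformerTask(FLAGS)          # builds the Keras model internally
task.train()                                            # wraps keras model.fit()
```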
@@ -199,16 +199,16 @@ tensorboard --logdir=$MODEL_DIR
A brief look at each component in the code:
### Model Definition
-* [transformer.py](v2/transformer.py): Defines a tf.keras.Model: `Transformer`.
-* [embedding_layer.py](v2/embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
-* [attention_layer.py](v2/attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
-* [ffn_layer.py](v2/ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
+* [transformer.py](transformer.py): Defines a tf.keras.Model: `Transformer`.
+* [embedding_layer.py](embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
+* [attention_layer.py](attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
+* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
Other files:
-* [beam_search.py](v2/beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
+* [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
### Model Trainer
-[transformer_main.py](v2/transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.
+[transformer_main.py](transformer_main.py) creates a `TransformerTask` to train and evaluate the model using tf.keras.
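As a rough illustration of how the layer classes listed above compose into one encoder-style block. Class names, constructor arguments, and call signatures below are assumptions about the relocated modules, not something this commit documents:

```python
import tensorflow as tf

from official.nlp.transformer import attention_layer  # SelfAttention (assumed)
from official.nlp.transformer import embedding_layer  # EmbeddingSharedWeights (assumed)
from official.nlp.transformer import ffn_layer        # FeedForwardNetwork (assumed)

hidden_size, num_heads, vocab_size = 512, 8, 32768

embedding = embedding_layer.EmbeddingSharedWeights(vocab_size, hidden_size)
self_attention = attention_layer.SelfAttention(hidden_size, num_heads, attention_dropout=0.1)
ffn = ffn_layer.FeedForwardNetwork(hidden_size, filter_size=2048, relu_dropout=0.1)

tokens = tf.constant([[5, 17, 42, 1]])        # toy batch of token ids
x = embedding(tokens)                         # [batch, length, hidden_size]
bias = tf.zeros([1, 1, 1, 4])                 # no padding masked in this toy example
x = self_attention(x, bias, training=False)   # multi-headed self attention
x = ffn(x, training=False)                    # two fully connected layers
logits = embedding(x, mode="linear")          # embedding weights reused for pre-softmax logits
```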
### Test dataset
The [newstest2014 files](https://storage.googleapis.com/tf-perf-public/official_transformer/test_data/newstest2014.tgz)
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2.
"""
"""Beam search in TF v2."""
import tensorflow as tf
from official.nlp.transformer import beam_search_v1 as v1
-from official.transformer.v2 import misc
+from official.nlp.transformer import misc
_StateKeys = v1._StateKeys # pylint: disable=protected-access
@@ -33,8 +33,8 @@ from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
-from official.transformer.utils import metrics
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import metrics
+from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
@@ -16,9 +16,9 @@
import tempfile
-import tensorflow as tf # pylint: disable=g-bad-import-order
+import tensorflow as tf
-from official.transformer import compute_bleu
+from official.nlp.transformer import compute_bleu
class ComputeBleuTest(tf.test.TestCase):
@@ -31,7 +31,7 @@ from absl import logging
import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# Data sources for training/evaluating the transformer translation model.
@@ -88,7 +88,7 @@ VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE
_PREFIX = "wmt32k"
_TRAIN_TAG = "train"
_EVAL_TAG = "dev" # Following WMT and Tensor2Tensor conventions, in which the
-# evaluation datasets are tagged as "dev" for development.
+# evaluation datasets are tagged as "dev" for development.
# Number of files to split train and evaluation data
_TRAIN_SHARDS = 100
@@ -57,7 +57,7 @@ import os
from absl import logging
import tensorflow as tf
-from official.transformer.v2 import misc
+from official.nlp.transformer import misc
from official.utils.misc import model_helpers
# Buffer size for reading records from a TFRecord file. Each training file is
@@ -22,14 +22,13 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import beam_search
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
-from official.transformer.utils.tokenizer import EOS_ID
-from official.transformer.v2 import attention_layer
-from official.transformer.v2 import beam_search
-from official.transformer.v2 import embedding_layer
-from official.transformer.v2 import ffn_layer
-from official.transformer.v2 import metrics
+from official.nlp.transformer.utils.tokenizer import EOS_ID
# Disable the not-callable lint error, since it claims many objects are not
@@ -20,10 +20,10 @@ from __future__ import print_function
import tensorflow as tf
-from official.transformer.v2 import attention_layer
-from official.transformer.v2 import embedding_layer
-from official.transformer.v2 import ffn_layer
-from official.transformer.v2 import metrics
+from official.nlp.transformer import attention_layer
+from official.nlp.transformer import embedding_layer
+from official.nlp.transformer import ffn_layer
+from official.nlp.transformer import metrics
class TransformerLayersTest(tf.test.TestCase):
@@ -31,14 +31,14 @@ from absl import logging
import tensorflow as tf
# pylint: disable=g-bad-import-order
-from official.transformer import compute_bleu
-from official.transformer.utils import tokenizer
-from official.transformer.v2 import data_pipeline
-from official.transformer.v2 import metrics
-from official.transformer.v2 import misc
-from official.transformer.v2 import optimizer
-from official.transformer.v2 import transformer
-from official.transformer.v2 import translate
+from official.nlp.transformer import compute_bleu
+from official.nlp.transformer.utils import tokenizer
+from official.nlp.transformer import data_pipeline
+from official.nlp.transformer import metrics
+from official.nlp.transformer import misc
+from official.nlp.transformer import optimizer
+from official.nlp.transformer import transformer
+from official.nlp.transformer import translate
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import keras_utils
@@ -26,12 +26,10 @@ import unittest
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
-from official.transformer.v2 import misc
-from official.transformer.v2 import transformer_main
-from official.utils.misc import keras_utils
-from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
+from official.nlp.transformer import misc
+from official.nlp.transformer import transformer_main
+from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp'
@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from official.nlp.transformer import model_params
-from official.transformer.v2 import transformer
+from official.nlp.transformer import transformer
class TransformerV2Test(tf.test.TestCase):
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
-from official.transformer.utils import tokenizer
+from official.nlp.transformer.utils import tokenizer
_EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4