Unverified Commit a84e1ef9 authored by Katherine Wu's avatar Katherine Wu Committed by GitHub
Browse files

Add official flag-parsing and benchmarking logging utils to Transformer (#4163)

parent fe1857cd
...@@ -14,7 +14,7 @@ The model also applies embeddings on the input and output tokens, and adds a con ...@@ -14,7 +14,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Training times](#training-times) * [Training times](#training-times)
* [Evaluation results](#evaluation-results) * [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions) * [Detailed instructions](#detailed-instructions)
* [Export variables (optional)](#export-variables-optional) * [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets) * [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation) * [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model) * [Translate using the model](#translate-using-the-model)
...@@ -31,16 +31,23 @@ The model also applies embeddings on the input and output tokens, and adds a con ...@@ -31,16 +31,23 @@ The model also applies embeddings on the input and output tokens, and adds a con
Below are the commands for running the Transformer model. See the [Detailed instrutions](#detailed-instructions) for more details on running the model. Below are the commands for running the Transformer model. See the [Detailed instrutions](#detailed-instructions) for more details on running the model.
``` ```
PARAMS=big cd /path/to/models/official/transformer
# Ensure that PYTHONPATH is correctly defined as described in
# https://github.com/tensorflow/models/tree/master/official#running-the-models
# export PYTHONPATH="$PYTHONPATH:/path/to/models"
# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS MODEL_DIR=$HOME/transformer/model_$PARAM_SET
# Download training/evaluation datasets # Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR python data_download.py --data_dir=$DATA_DIR
# Train the model for 10 epochs, and evaluate after every epoch. # Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \ python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de --param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
# Run during training in a separate process to get continuous updates, # Run during training in a separate process to get continuous updates,
# or after training is complete. # or after training is complete.
...@@ -48,21 +55,21 @@ tensorboard --logdir=$MODEL_DIR ...@@ -48,21 +55,21 @@ tensorboard --logdir=$MODEL_DIR
# Translate some text using the trained model # Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \ python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --text="hello world" --param_set=$PARAM_SET --text="hello world"
# Compute model's BLEU score using the newstest2014 dataset. # Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \ python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
``` ```
## Benchmarks ## Benchmarks
### Training times ### Training times
Currently, both big and base params run on a single GPU. The measurements below Currently, both big and base parameter sets run on a single GPU. The measurements below
are reported from running the model on a P100 GPU. are reported from running the model on a P100 GPU.
Params | batches/sec | batches per epoch | time per epoch Param Set | batches/sec | batches per epoch | time per epoch
--- | --- | --- | --- --- | --- | --- | ---
base | 4.8 | 83244 | 4 hr base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr big | 1.1 | 41365 | 10 hr
...@@ -70,7 +77,7 @@ big | 1.1 | 41365 | 10 hr ...@@ -70,7 +77,7 @@ big | 1.1 | 41365 | 10 hr
### Evaluation results ### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs. Below are the case-insensitive BLEU scores after 10 epochs.
Params | Score Param Set | Score
--- | --- | --- | --- |
base | 27.7 base | 27.7
big | 28.9 big | 28.9
...@@ -79,13 +86,18 @@ big | 28.9 ...@@ -79,13 +86,18 @@ big | 28.9
## Detailed instructions ## Detailed instructions
0. ### Export variables (optional) 0. ### Environment preparation
#### Add models repo to PYTHONPATH
Follow the instructions described in the [Running the models](https://github.com/tensorflow/models/tree/master/official#running-the-models) section to add the models folder to the python path.
#### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below: Export the following variables, or modify the values in each of the snippets below:
``` ```
PARAMS=big PARAM_SET=big
DATA_DIR=$HOME/transformer/data DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS MODEL_DIR=$HOME/transformer/model_$PARAM_SET
``` ```
1. ### Download and preprocess datasets 1. ### Download and preprocess datasets
...@@ -109,13 +121,13 @@ big | 28.9 ...@@ -109,13 +121,13 @@ big | 28.9
Command to run: Command to run:
``` ```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET
``` ```
Arguments: Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument. * `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints. * `--model_dir`: Directory to save Transformer model training checkpoints.
* `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default). * `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments. * Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Customizing training schedule #### Customizing training schedule
...@@ -123,12 +135,12 @@ big | 28.9 ...@@ -123,12 +135,12 @@ big | 28.9
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags: By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default): * Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset * `--train_epochs`: The total number of complete passes to make through the dataset
* `--epochs_between_eval`: The number of epochs to train between evaluations. * `--epochs_between_evals`: The number of epochs to train between evaluations.
* Training with steps: * Training with steps:
* `--train_steps`: sets the total number of training steps to run. * `--train_steps`: sets the total number of training steps to run.
* `--steps_between_eval`: Number of training steps to run between evaluations. * `--steps_between_evals`: Number of training steps to run between evaluations.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`. Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more than others. Therefore, it is recommended to use epochs when the model quality is important. Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more than others. Therefore, it is recommended to use epochs when the model quality is important.
...@@ -137,7 +149,7 @@ big | 28.9 ...@@ -137,7 +149,7 @@ big | 28.9
Use these flags to compute the BLEU when the model evaluates: Use these flags to compute the BLEU when the model evaluates:
* `--bleu_source`: Path to file containing text to translate. * `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation. * `--bleu_ref`: Path to file containing the reference translation.
* `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags. * `--stop_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data). The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
...@@ -155,12 +167,12 @@ big | 28.9 ...@@ -155,12 +167,12 @@ big | 28.9
Command to run: Command to run:
``` ```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world" python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=PARAM_SET --text="hello world"
``` ```
Arguments for initializing the Subtokenizer and trained model: Arguments for initializing the Subtokenizer and trained model:
* `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output. * `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
* `--model_dir` and `--params`: These parameters are used to rebuild the trained model * `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model
Arguments for specifying what to translate: Arguments for specifying what to translate:
* `--text`: Text to translate * `--text`: Text to translate
...@@ -170,7 +182,7 @@ big | 28.9 ...@@ -170,7 +182,7 @@ big | 28.9
To translate the newstest2014 data, run: To translate the newstest2014 data, run:
``` ```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \ python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en --param_set=PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
``` ```
Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100. Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
......
...@@ -22,17 +22,19 @@ from __future__ import absolute_import ...@@ -22,17 +22,19 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import re import re
import sys import sys
import unicodedata import unicodedata
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
import six import six
from absl import app as absl_app
from absl import flags
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.transformer.utils import metrics from official.transformer.utils import metrics
from official.utils.flags import core as flags_core
class UnicodeRegex(object): class UnicodeRegex(object):
...@@ -99,31 +101,37 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False): ...@@ -99,31 +101,37 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
def main(unused_argv): def main(unused_argv):
if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant: if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False) score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
print("Case-insensitive results:", score) tf.logging.info("Case-insensitive results: %f" % score)
if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant: if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True) score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
print("Case-sensitive results:", score) tf.logging.info("Case-sensitive results: %f" % score)
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation", default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")
flags.DEFINE_string(
name="reference", default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")
flags.DEFINE_enum(
name="bleu_variant", short_name="bv", default="both",
enum_values=["both", "uncased", "cased"], case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() tf.logging.set_verbosity(tf.logging.INFO)
parser.add_argument( define_compute_bleu_flags()
"--translation", "-t", type=str, default=None, required=True, FLAGS = flags.FLAGS
help="[default: %(default)s] File containing translated text.", absl_app.run(main)
metavar="<T>")
parser.add_argument(
"--reference", "-r", type=str, default=None, required=True,
help="[default: %(default)s] File containing reference translation",
metavar="<R>")
parser.add_argument(
"--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
nargs="*", default=None,
help="Specify one or more BLEU variants to calculate (both are "
"calculated by default. Variants: \"cased\" or \"uncased\".",
metavar="<BV>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
...@@ -18,19 +18,20 @@ from __future__ import absolute_import ...@@ -18,19 +18,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import random import random
import sys
import tarfile import tarfile
import urllib
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
import six import six
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# Data sources for training/evaluating the transformer translation model. # Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either: # If any of the training sources are changed, then either:
...@@ -156,7 +157,7 @@ def download_from_url(path, url): ...@@ -156,7 +157,7 @@ def download_from_url(path, url):
filename = os.path.join(path, filename) filename = os.path.join(path, filename)
tf.logging.info("Downloading from %s to %s." % (url, filename)) tf.logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete" inprogress_filepath = filename + ".incomplete"
inprogress_filepath, _ = urllib.urlretrieve( inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook) url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress. # Print newline to clear the carriage return from the download progress.
print() print()
...@@ -302,7 +303,7 @@ def encode_and_save_files( ...@@ -302,7 +303,7 @@ def encode_and_save_files(
for tmp_name, final_name in zip(tmp_filepaths, filepaths): for tmp_name, final_name in zip(tmp_filepaths, filepaths):
tf.gfile.Rename(tmp_name, final_name) tf.gfile.Rename(tmp_name, final_name)
tf.logging.info("Saved %d Examples", counter) tf.logging.info("Saved %d Examples", counter + 1)
return filepaths return filepaths
...@@ -363,8 +364,6 @@ def make_dir(path): ...@@ -363,8 +364,6 @@ def make_dir(path):
def main(unused_argv): def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model.""" """Obtain training and evaluation data for the Transformer model."""
tf.logging.set_verbosity(tf.logging.INFO)
make_dir(FLAGS.raw_dir) make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir) make_dir(FLAGS.data_dir)
...@@ -398,22 +397,25 @@ def main(unused_argv): ...@@ -398,22 +397,25 @@ def main(unused_argv):
shuffle_records(fname) shuffle_records(fname)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
help=flags_core.help_wrap(
"Path where the raw data will be downloaded and extracted."))
flags.DEFINE_bool(
name="search", default=False,
help=flags_core.help_wrap(
"If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() tf.logging.set_verbosity(tf.logging.INFO)
parser.add_argument( define_data_download_flags()
"--data_dir", "-dd", type=str, default="/tmp/translate_ende", FLAGS = flags.FLAGS
help="[default: %(default)s] Directory for where the " absl_app.run(main)
"translate_ende_wmt32k dataset is saved.",
metavar="<DD>")
parser.add_argument(
"--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
help="[default: %(default)s] Path where the raw data will be downloaded "
"and extracted.",
metavar="<RD>")
parser.add_argument(
"--search", action="store_true",
help="If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
...@@ -12,19 +12,23 @@ ...@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Creates an estimator to train the Transformer model.""" """Train and evaluate the Transformer model.
See README for description of setting the training schedule and evaluating the
BLEU score.
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
import tempfile import tempfile
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
...@@ -36,11 +40,26 @@ from official.transformer.model import transformer ...@@ -36,11 +40,26 @@ from official.transformer.model import transformer
from official.transformer.utils import dataset from official.transformer.utils import dataset
from official.transformer.utils import metrics from official.transformer.utils import metrics
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
PARAMS_MAP = {
"base": model_params.TransformerBaseParams,
"big": model_params.TransformerBigParams,
}
DEFAULT_TRAIN_EPOCHS = 10 DEFAULT_TRAIN_EPOCHS = 10
BLEU_DIR = "bleu" BLEU_DIR = "bleu"
INF = int(1e9) INF = int(1e9)
# Dictionary containing tensors that are logged by the logging hooks. Each item
# maps a string to the tensor name.
TENSORS_TO_LOG = {
"learning_rate": "model/get_train_op/learning_rate/learning_rate",
"cross_entropy_loss": "model/cross_entropy"}
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""Defines how to train, evaluate and predict from the transformer model.""" """Defines how to train, evaluate and predict from the transformer model."""
...@@ -66,6 +85,9 @@ def model_fn(features, labels, mode, params): ...@@ -66,6 +85,9 @@ def model_fn(features, labels, mode, params):
logits, targets, params.label_smoothing, params.vocab_size) logits, targets, params.label_smoothing, params.vocab_size)
loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights) loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
# Save loss as named tensor that will be logged with the logging hook.
tf.identity(loss, "cross_entropy")
if mode == tf.estimator.ModeKeys.EVAL: if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits}, mode=mode, loss=loss, predictions={"predictions": logits},
...@@ -87,6 +109,10 @@ def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps): ...@@ -87,6 +109,10 @@ def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
# Apply rsqrt decay # Apply rsqrt decay
learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps)) learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps))
# Create a named tensor that will be logged using the logging hook.
# The full name includes variable and names scope. In this case, the name
# is model/get_train_op/learning_rate/learning_rate
tf.identity(learning_rate, "learning_rate")
# Save learning rate value to TensorBoard summary. # Save learning rate value to TensorBoard summary.
tf.summary.scalar("learning_rate", learning_rate) tf.summary.scalar("learning_rate", learning_rate)
...@@ -145,31 +171,22 @@ def get_global_step(estimator): ...@@ -145,31 +171,22 @@ def get_global_step(estimator):
return int(estimator.latest_checkpoint().split("-")[-1]) return int(estimator.latest_checkpoint().split("-")[-1])
def evaluate_and_log_bleu(estimator, bleu_writer, bleu_source, bleu_ref): def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
"""Calculate and record the BLEU score.""" """Calculate and record the BLEU score."""
subtokenizer = tokenizer.Subtokenizer( subtokenizer = tokenizer.Subtokenizer(vocab_file_path)
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
uncased_score, cased_score = translate_and_compute_bleu( uncased_score, cased_score = translate_and_compute_bleu(
estimator, subtokenizer, bleu_source, bleu_ref) estimator, subtokenizer, bleu_source, bleu_ref)
print("Bleu score (uncased):", uncased_score) tf.logging.info("Bleu score (uncased):", uncased_score)
print("Bleu score (cased):", cased_score) tf.logging.info("Bleu score (cased):", cased_score)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, get_global_step(estimator))
bleu_writer.flush()
return uncased_score, cased_score return uncased_score, cased_score
def train_schedule( def train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps=None, estimator, train_eval_iterations, single_iteration_train_steps=None,
single_iteration_train_epochs=None, bleu_source=None, bleu_ref=None, single_iteration_train_epochs=None, train_hooks=None, benchmark_logger=None,
bleu_threshold=None): bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
"""Train and evaluate model, and optionally compute model's BLEU score. """Train and evaluate model, and optionally compute model's BLEU score.
**Step vs. Epoch vs. Iteration** **Step vs. Epoch vs. Iteration**
...@@ -198,9 +215,12 @@ def train_schedule( ...@@ -198,9 +215,12 @@ def train_schedule(
train_eval_iterations: Number of times to repeat the train+eval iteration. train_eval_iterations: Number of times to repeat the train+eval iteration.
single_iteration_train_steps: Number of steps to train in one iteration. single_iteration_train_steps: Number of steps to train in one iteration.
single_iteration_train_epochs: Number of epochs to train in one iteration. single_iteration_train_epochs: Number of epochs to train in one iteration.
train_hooks: List of hooks to pass to the estimator during training.
benchmark_logger: a BenchmarkLogger object that logs evaluation data
bleu_source: File containing text to be translated for BLEU calculation. bleu_source: File containing text to be translated for BLEU calculation.
bleu_ref: File containing reference translations for BLEU calculation. bleu_ref: File containing reference translations for BLEU calculation.
bleu_threshold: minimum BLEU score before training is stopped. bleu_threshold: minimum BLEU score before training is stopped.
vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.
Raises: Raises:
ValueError: if both or none of single_iteration_train_steps and ValueError: if both or none of single_iteration_train_steps and
...@@ -221,22 +241,24 @@ def train_schedule( ...@@ -221,22 +241,24 @@ def train_schedule(
evaluate_bleu = bleu_source is not None and bleu_ref is not None evaluate_bleu = bleu_source is not None and bleu_ref is not None
# Print out training schedule # Print details of training schedule.
print("Training schedule:") tf.logging.info("Training schedule:")
if single_iteration_train_epochs is not None: if single_iteration_train_epochs is not None:
print("\t1. Train for %d epochs." % single_iteration_train_epochs) tf.logging.info("\t1. Train for %d epochs." % single_iteration_train_epochs)
else: else:
print("\t1. Train for %d steps." % single_iteration_train_steps) tf.logging.info("\t1. Train for %d steps." % single_iteration_train_steps)
print("\t2. Evaluate model.") tf.logging.info("\t2. Evaluate model.")
if evaluate_bleu: if evaluate_bleu:
print("\t3. Compute BLEU score.") tf.logging.info("\t3. Compute BLEU score.")
if bleu_threshold is not None: if bleu_threshold is not None:
print("Repeat above steps until the BLEU score reaches", bleu_threshold) tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
bleu_threshold)
if not evaluate_bleu or bleu_threshold is None: if not evaluate_bleu or bleu_threshold is None:
print("Repeat above steps %d times." % train_eval_iterations) tf.logging.info("Repeat above steps %d times." % train_eval_iterations)
if evaluate_bleu: if evaluate_bleu:
# Set summary writer to log bleu score. # Create summary writer to log bleu score (values can be displayed in
# Tensorboard).
bleu_writer = tf.summary.FileWriter( bleu_writer = tf.summary.FileWriter(
os.path.join(estimator.model_dir, BLEU_DIR)) os.path.join(estimator.model_dir, BLEU_DIR))
if bleu_threshold is not None: if bleu_threshold is not None:
...@@ -245,145 +267,201 @@ def train_schedule( ...@@ -245,145 +267,201 @@ def train_schedule(
# Loop training/evaluation/bleu cycles # Loop training/evaluation/bleu cycles
for i in xrange(train_eval_iterations): for i in xrange(train_eval_iterations):
print("Starting iteration", i + 1) tf.logging.info("Starting iteration %d" % (i + 1))
# Train the model for single_iteration_train_steps or until the input fn # Train the model for single_iteration_train_steps or until the input fn
# runs out of examples (if single_iteration_train_steps is None). # runs out of examples (if single_iteration_train_steps is None).
estimator.train(dataset.train_input_fn, steps=single_iteration_train_steps) estimator.train(
dataset.train_input_fn, steps=single_iteration_train_steps,
hooks=train_hooks)
eval_results = estimator.evaluate(dataset.eval_input_fn) eval_results = estimator.evaluate(dataset.eval_input_fn)
print("Evaluation results (iter %d/%d):" % (i + 1, train_eval_iterations), tf.logging.info("Evaluation results (iter %d/%d):" %
eval_results) (i + 1, train_eval_iterations))
tf.logging.info(eval_results)
benchmark_logger.log_evaluation_result(eval_results)
# The results from estimator.evaluate() are measured on an approximate
# translation, which utilize the target golden values provided. The actual
# bleu score must be computed using the estimator.predict() path, which
# outputs translations that are not based on golden values. The translations
# are compared to reference file to get the actual bleu score.
if evaluate_bleu: if evaluate_bleu:
uncased_score, _ = evaluate_and_log_bleu( uncased_score, cased_score = evaluate_and_log_bleu(
estimator, bleu_writer, bleu_source, bleu_ref) estimator, bleu_source, bleu_ref, vocab_file_path)
if bleu_threshold is not None and uncased_score > bleu_threshold:
# Write actual bleu scores using summary writer and benchmark logger
global_step = get_global_step(estimator)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, global_step)
bleu_writer.flush()
benchmark_logger.log_metric(
"bleu_uncased", uncased_score, global_step=global_step)
benchmark_logger.log_metric(
"bleu_cased", cased_score, global_step=global_step)
# Stop training if bleu stopping threshold is met.
if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
bleu_writer.close() bleu_writer.close()
break break
def main(_): def define_transformer_flags():
# Set logging level to INFO to display training progress (logged by the """Add flags and flag validators for running transformer_main."""
# estimator) # Add common flags (data_dir, model_dir, train_epochs, etc.).
tf.logging.set_verbosity(tf.logging.INFO) flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=False
)
flags_core.define_benchmark()
# Set flags from the flags_core module as "key flags" so they're listed when
# the '-h' flag is used. Without this line, the flags defined above are
# only shown in the full `--helpful` help text.
flags.adopt_module_key_flags(flags_core)
# Add transformer-specific flags
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
if FLAGS.params == "base": # Flags for training with steps (may be used for debugging)
params = model_params.TransformerBaseParams flags.DEFINE_integer(
elif FLAGS.params == "big": name="train_steps", short_name="ts", default=None,
params = model_params.TransformerBigParams help=flags_core.help_wrap("The number of steps used to train."))
else: flags.DEFINE_integer(
raise ValueError("Invalid parameter set defined: %s." name="steps_between_evals", short_name="sbe", default=1000,
"Expected 'base' or 'big.'" % FLAGS.params) help=flags_core.help_wrap(
"The Number of training steps to run between evaluations. This is "
"used if --train_steps is defined."))
# BLEU score computation
flags.DEFINE_string(
name="bleu_source", short_name="bls", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
"must be set. Use the flag --stop_threshold to stop the script based "
"on the uncased BLEU score."))
flags.DEFINE_string(
name="bleu_ref", short_name="blr", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
"must be set. Use the flag --stop_threshold to stop the script based "
"on the uncased BLEU score."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=VOCAB_FILE,
help=flags_core.help_wrap(
"Name of vocabulary file containing subtokens for subtokenizing the "
"bleu_source file. This file is expected to be in the directory "
"defined by --data_dir."))
flags_core.set_defaults(data_dir="/tmp/translate_ende",
model_dir="/tmp/transformer_model",
batch_size=None,
train_epochs=None)
@flags.multi_flags_validator(
["train_epochs", "train_steps"],
message="Both --train_steps and --train_epochs were set. Only one may be "
"defined.")
def _check_train_limits(flag_dict):
return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None
@flags.multi_flags_validator(
["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
"Please ensure that the file paths are correct.")
def _check_bleu_files(flags_dict):
"""Validate files when bleu_source and bleu_ref are defined."""
if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
return True
# Ensure that bleu_source, bleu_ref, and vocab files exist.
vocab_file_path = os.path.join(
flags_dict["data_dir"], flags_dict["vocab_file"])
return all([
tf.gfile.Exists(flags_dict["bleu_source"]),
tf.gfile.Exists(flags_dict["bleu_ref"]),
tf.gfile.Exists(vocab_file_path)])
def run_transformer(flags_obj):
"""Create tf.Estimator to train and evaluate transformer model.
Args:
flags_obj: Object containing parsed flag values.
"""
# Determine training schedule based on flags. # Determine training schedule based on flags.
if FLAGS.train_steps is not None and FLAGS.train_epochs is not None: if flags_obj.train_steps is not None:
raise ValueError("Both --train_steps and --train_epochs were set. Only one " train_eval_iterations = (
"may be defined.") flags_obj.train_steps // flags_obj.steps_between_evals)
if FLAGS.train_steps is not None: single_iteration_train_steps = flags_obj.steps_between_evals
train_eval_iterations = FLAGS.train_steps // FLAGS.steps_between_eval
single_iteration_train_steps = FLAGS.steps_between_eval
single_iteration_train_epochs = None single_iteration_train_epochs = None
else: else:
if FLAGS.train_epochs is None: train_epochs = flags_obj.train_epochs or DEFAULT_TRAIN_EPOCHS
FLAGS.train_epochs = DEFAULT_TRAIN_EPOCHS train_eval_iterations = train_epochs // flags_obj.epochs_between_evals
train_eval_iterations = FLAGS.train_epochs // FLAGS.epochs_between_eval
single_iteration_train_steps = None single_iteration_train_steps = None
single_iteration_train_epochs = FLAGS.epochs_between_eval single_iteration_train_epochs = flags_obj.epochs_between_evals
# Make sure that the BLEU source and ref files if set
if FLAGS.bleu_source is not None and FLAGS.bleu_ref is not None:
if not tf.gfile.Exists(FLAGS.bleu_source):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_source)
if not tf.gfile.Exists(FLAGS.bleu_ref):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_ref)
# Add flag-defined parameters to params object # Add flag-defined parameters to params object
params.data_dir = FLAGS.data_dir params = PARAMS_MAP[flags_obj.param_set]
params.num_cpu_cores = FLAGS.num_cpu_cores params.data_dir = flags_obj.data_dir
params.epochs_between_eval = FLAGS.epochs_between_eval params.num_parallel_calls = flags_obj.num_parallel_calls
params.epochs_between_evals = flags_obj.epochs_between_evals
params.repeat_dataset = single_iteration_train_epochs params.repeat_dataset = single_iteration_train_epochs
params.batch_size = flags_obj.batch_size or params.batch_size
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
batch_size=params.batch_size # for ExamplesPerSecondHook
)
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info(
model_name="transformer",
dataset_name="wmt_translate_ende",
run_params=params.__dict__)
# Train and evaluate transformer model
estimator = tf.estimator.Estimator( estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=FLAGS.model_dir, params=params) model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
train_schedule( train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps, estimator=estimator,
single_iteration_train_epochs, FLAGS.bleu_source, FLAGS.bleu_ref, # Training arguments
FLAGS.bleu_threshold) train_eval_iterations=train_eval_iterations,
single_iteration_train_steps=single_iteration_train_steps,
single_iteration_train_epochs=single_iteration_train_epochs,
train_hooks=train_hooks,
benchmark_logger=benchmark_logger,
# BLEU calculation arguments
bleu_source=flags_obj.bleu_source,
bleu_ref=flags_obj.bleu_ref,
bleu_threshold=flags_obj.stop_threshold,
vocab_file_path=os.path.join(flags_obj.data_dir, flags_obj.vocab_file))
if __name__ == "__main__": def main(_):
parser = argparse.ArgumentParser() run_transformer(flags.FLAGS)
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory containing training and "
"evaluation data, and vocab file used for encoding.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory to save Transformer model "
"training checkpoints",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter set to use when creating and "
"training the model.",
metavar="<P>")
parser.add_argument(
"--num_cpu_cores", "-nc", type=int, default=4,
help="[default: %(default)s] Number of CPU cores to use in the input "
"pipeline.",
metavar="<NC>")
# Flags for training with epochs. (default)
parser.add_argument(
"--train_epochs", "-te", type=int, default=None,
help="The number of epochs used to train. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TE>")
parser.add_argument(
"--epochs_between_eval", "-ebe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<TE>")
# Flags for training with steps (may be used for debugging)
parser.add_argument(
"--train_steps", "-ts", type=int, default=None,
help="Total number of training steps. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TS>")
parser.add_argument(
"--steps_between_eval", "-sbe", type=int, default=1000,
help="[default: %(default)s] Number of training steps to run between "
"evaluations.",
metavar="<SBE>")
# BLEU score computation if __name__ == "__main__":
parser.add_argument( tf.logging.set_verbosity(tf.logging.INFO)
"--bleu_source", "-bs", type=str, default=None, define_transformer_flags()
help="Path to source file containing text translate when calculating the " absl_app.run(main)
"official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BS>")
parser.add_argument(
"--bleu_ref", "-br", type=str, default=None,
help="Path to file containing the reference translation for calculating "
"the official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BR>")
parser.add_argument(
"--bleu_threshold", "-bt", type=float, default=None,
help="Stop training when the uncased BLEU score reaches this value. "
"Setting this overrides the total number of steps or epochs set by "
"--train_steps or --train_epochs.",
metavar="<BT>")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
...@@ -18,18 +18,19 @@ from __future__ import absolute_import ...@@ -18,18 +18,19 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.transformer.data_download import VOCAB_FILE from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params from official.transformer.model import model_params
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
_DECODE_BATCH_SIZE = 32 _DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100 _EXTRA_DECODE_LENGTH = 100
...@@ -106,7 +107,8 @@ def translate_file( ...@@ -106,7 +107,8 @@ def translate_file(
if i % batch_size == 0: if i % batch_size == 0:
batch_num = (i // batch_size) + 1 batch_num = (i // batch_size) + 1
print("Decoding batch %d out of %d." % (batch_num, num_decode_batches)) tf.logging.info("Decoding batch %d out of %d." %
(batch_num, num_decode_batches))
yield _encode_and_add_eos(line, subtokenizer) yield _encode_and_add_eos(line, subtokenizer)
def input_fn(): def input_fn():
...@@ -122,10 +124,8 @@ def translate_file( ...@@ -122,10 +124,8 @@ def translate_file(
translations.append(translation) translations.append(translation)
if print_all_translations: if print_all_translations:
print("Translating:") tf.logging.info("Translating:\n\tInput: %s\n\tOutput: %s" %
print("\tInput: %s" % sorted_inputs[i]) (sorted_inputs[i], translation))
print("\tOutput: %s\n" % translation)
print("=" * 100)
# Write translations in the order they appeared in the original file. # Write translations in the order they appeared in the original file.
if output_file is not None: if output_file is not None:
...@@ -150,7 +150,7 @@ def translate_text(estimator, subtokenizer, txt): ...@@ -150,7 +150,7 @@ def translate_text(estimator, subtokenizer, txt):
predictions = estimator.predict(input_fn) predictions = estimator.predict(input_fn)
translation = next(predictions)["outputs"] translation = next(predictions)["outputs"]
translation = _trim_and_decode(translation, subtokenizer) translation = _trim_and_decode(translation, subtokenizer)
print("Translation of \"%s\": \"%s\"" % (txt, translation)) tf.logging.info("Translation of \"%s\": \"%s\"" % (txt, translation))
def main(unused_argv): def main(unused_argv):
...@@ -166,15 +166,8 @@ def main(unused_argv): ...@@ -166,15 +166,8 @@ def main(unused_argv):
subtokenizer = tokenizer.Subtokenizer( subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file)) os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Set up estimator and params # Set up estimator and params
params = transformer_main.PARAMS_MAP[FLAGS.param_set]
params.beam_size = _BEAM_SIZE params.beam_size = _BEAM_SIZE
params.alpha = _ALPHA params.alpha = _ALPHA
params.extra_decode_length = _EXTRA_DECODE_LENGTH params.extra_decode_length = _EXTRA_DECODE_LENGTH
...@@ -201,45 +194,50 @@ def main(unused_argv): ...@@ -201,45 +194,50 @@ def main(unused_argv):
translate_file(estimator, subtokenizer, input_file, output_file) translate_file(estimator, subtokenizer, input_file, output_file)
def define_translate_flags():
"""Define flags used for translation script."""
# Model and vocab file flags
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=VOCAB_FILE,
help=flags_core.help_wrap(
"Name of vocabulary file containing subtokens for subtokenizing the "
"input text or file. This file is expected to be in the directory "
"defined by --data_dir."))
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp/transformer_model",
help=flags_core.help_wrap(
"Directory containing Transformer model checkpoints."))
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_string(
name="text", default=None,
help=flags_core.help_wrap(
"Text to translate. Output will be printed to console."))
flags.DEFINE_string(
name="file", default=None,
help=flags_core.help_wrap(
"File containing text to translate. Translation will be printed to "
"console and, if --file_out is provided, saved to an output file."))
flags.DEFINE_string(
name="file_out", default=None,
help=flags_core.help_wrap(
"If --file flag is specified, save translation to this file."))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() define_translate_flags()
FLAGS = flags.FLAGS
# Model arguments absl_app.run(main)
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/data/translate_ende",
help="[default: %(default)s] Directory where vocab file is stored.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory containing Transformer model "
"checkpoints.",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter used for trained model.",
metavar="<P>")
# Flags for specifying text/file to be translated.
parser.add_argument(
"--text", "-t", type=str, default=None,
help="[default: %(default)s] Text to translate. Output will be printed "
"to console.",
metavar="<T>")
parser.add_argument(
"--file", "-f", type=str, default=None,
help="[default: %(default)s] File containing text to translate. "
"Translation will be printed to console and, if --file_out is "
"provided, saved to an output file.",
metavar="<F>")
parser.add_argument(
"--file_out", "-fo", type=str, default=None,
help="[default: %(default)s] If --file flag is specified, save "
"translation to this file.",
metavar="<FO>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
...@@ -190,14 +190,14 @@ def _batch_examples(dataset, batch_size, max_length): ...@@ -190,14 +190,14 @@ def _batch_examples(dataset, batch_size, max_length):
def _read_and_batch_from_files( def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_cpu_cores, shuffle, repeat): file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat):
"""Create dataset where each item is a dict of "inputs" and "targets". """Create dataset where each item is a dict of "inputs" and "targets".
Args: Args:
file_pattern: String used to match the input TFRecord files. file_pattern: String used to match the input TFRecord files.
batch_size: Maximum number of tokens per batch of examples batch_size: Maximum number of tokens per batch of examples
max_length: Maximum number of tokens per example max_length: Maximum number of tokens per example
num_cpu_cores: Number of cpu cores for parallel input processing. num_parallel_calls: Number of cpu cores for parallel input processing.
shuffle: If true, randomizes order of elements. shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever. repeated forever.
...@@ -215,12 +215,12 @@ def _read_and_batch_from_files( ...@@ -215,12 +215,12 @@ def _read_and_batch_from_files(
# will be non-deterministic. # will be non-deterministic.
dataset = dataset.apply( dataset = dataset.apply(
tf.contrib.data.parallel_interleave( tf.contrib.data.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_cpu_cores)) _load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary # Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization. # TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example, dataset = dataset.map(_parse_example,
num_parallel_calls=num_cpu_cores) num_parallel_calls=num_parallel_calls)
# Remove examples where the input or target length exceeds the maximum length, # Remove examples where the input or target length exceeds the maximum length,
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length)) dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
...@@ -238,13 +238,13 @@ def train_input_fn(params): ...@@ -238,13 +238,13 @@ def train_input_fn(params):
"""Load and return dataset of batched examples for use during training.""" """Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*") file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*")
return _read_and_batch_from_files( return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores, file_pattern, params.batch_size, params.max_length,
shuffle=True, repeat=params.repeat_dataset) params.num_parallel_calls, shuffle=True, repeat=params.repeat_dataset)
def eval_input_fn(params): def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation.""" """Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*") file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*")
return _read_and_batch_from_files( return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores, file_pattern, params.batch_size, params.max_length,
shuffle=False, repeat=1) params.num_parallel_calls, shuffle=False, repeat=1)
...@@ -89,8 +89,9 @@ def define_flags(): ...@@ -89,8 +89,9 @@ def define_flags():
flags.adopt_key_flags(flags_core) flags.adopt_key_flags(flags_core)
def main(flags_obj): def main(_):
pass flags_obj = flags.FLAGS
print(flags_obj)
if __name__ == "__main__" if __name__ == "__main__"
...@@ -120,3 +121,32 @@ class BaseTester(unittest.TestCase): ...@@ -120,3 +121,32 @@ class BaseTester(unittest.TestCase):
self.AssertEqual(flags.FLAGS.test_flag, "def") self.AssertEqual(flags.FLAGS.test_flag, "def")
``` ```
## Immutability
Flag values should not be mutated. Instead, use getter functions to return
the desired values. An example getter function is `get_loss_scale` function
below:
```
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
def main(_):
flags_obj = flags.FLAGS()
# Do not mutate flags_obj
# if flags_obj.loss_scale is None:
# flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1] # Don't do this
print(get_loss_scale(flags_obj))
...
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment