Unverified commit a84e1ef9, authored by Katherine Wu and committed by GitHub

Add official flag-parsing and benchmarking logging utils to Transformer (#4163)

parent fe1857cd
......@@ -14,7 +14,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Training times](#training-times)
* [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
* [Export variables (optional)](#export-variables-optional)
* [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
......@@ -31,16 +31,23 @@ The model also applies embeddings on the input and output tokens, and adds a con
Below are the commands for running the Transformer model. See the [Detailed instructions](#detailed-instructions) for more details on running the model.
```
PARAMS=big
cd /path/to/models/official/transformer
# Ensure that PYTHONPATH is correctly defined as described in
# https://github.com/tensorflow/models/tree/master/official#running-the-models
# export PYTHONPATH="$PYTHONPATH:/path/to/models"
# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR
# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
--param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
# Run during training in a separate process to get continuous updates,
# or after training is complete.
......@@ -48,21 +55,21 @@ tensorboard --logdir=$MODEL_DIR
# Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --text="hello world"
--param_set=$PARAM_SET --text="hello world"
# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
--param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
## Benchmarks
### Training times
Currently, both big and base params run on a single GPU. The measurements below
Currently, both big and base parameter sets run on a single GPU. The measurements below
are reported from running the model on a P100 GPU.
Params | batches/sec | batches per epoch | time per epoch
Param Set | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr
......@@ -70,7 +77,7 @@ big | 1.1 | 41365 | 10 hr
### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.
Params | Score
Param Set | Score
--- | --- |
base | 27.7
big | 28.9
......@@ -79,13 +86,18 @@ big | 28.9
## Detailed instructions
0. ### Export variables (optional)
0. ### Environment preparation
#### Add models repo to PYTHONPATH
Follow the instructions described in the [Running the models](https://github.com/tensorflow/models/tree/master/official#running-the-models) section to add the models folder to the python path.
#### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below:
```
PARAMS=big
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
```
1. ### Download and preprocess datasets
......@@ -109,13 +121,13 @@ big | 28.9
Command to run:
```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET
```
Arguments:
* `--data_dir`: This should be set to the same directory passed to `data_download.py`'s `--data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
* `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Customizing training schedule
......@@ -123,12 +135,12 @@ big | 28.9
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset
* `--epochs_between_eval`: The number of epochs to train between evaluations.
* `--epochs_between_evals`: The number of epochs to train between evaluations.
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
* `--steps_between_eval`: Number of training steps to run between evaluations.
* `--steps_between_evals`: Number of training steps to run between evaluations.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping training before completing an epoch may result in worse model quality, since some examples may be seen more often than others. Therefore, it is recommended to use epochs when model quality is important.
......@@ -137,7 +149,7 @@ big | 28.9
Use these flags to compute the BLEU score when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
* `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
* `--stop_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
......@@ -155,12 +167,12 @@ big | 28.9
Command to run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world"
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET --text="hello world"
```
Arguments for initializing the Subtokenizer and trained model:
* `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
* `--model_dir` and `--params`: These parameters are used to rebuild the trained model
* `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model
Arguments for specifying what to translate:
* `--text`: Text to translate
......@@ -170,7 +182,7 @@ big | 28.9
To translate the newstest2014 data, run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
--param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
```
Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
......
......@@ -22,17 +22,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import re
import sys
import unicodedata
# pylint: disable=g-bad-import-order
import six
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import metrics
from official.utils.flags import core as flags_core
class UnicodeRegex(object):
......@@ -99,31 +101,37 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
def main(unused_argv):
if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
print("Case-insensitive results:", score)
tf.logging.info("Case-insensitive results: %f" % score)
if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
print("Case-sensitive results:", score)
tf.logging.info("Case-sensitive results: %f" % score)
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation", default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")
flags.DEFINE_string(
name="reference", default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")
flags.DEFINE_enum(
name="bleu_variant", short_name="bv", default="both",
enum_values=["both", "uncased", "cased"], case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--translation", "-t", type=str, default=None, required=True,
help="[default: %(default)s] File containing translated text.",
metavar="<T>")
parser.add_argument(
"--reference", "-r", type=str, default=None, required=True,
help="[default: %(default)s] File containing reference translation",
metavar="<R>")
parser.add_argument(
"--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
nargs="*", default=None,
help="Specify one or more BLEU variants to calculate (both are "
"calculated by default. Variants: \"cased\" or \"uncased\".",
metavar="<BV>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.logging.set_verbosity(tf.logging.INFO)
define_compute_bleu_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
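This commit moves `compute_bleu.py` (and, below, `data_download.py` and `translate.py`) from `argparse` to the shared absl pattern: define flags in a helper, bind `FLAGS`, and hand control to `absl_app.run`. A minimal, self-contained sketch of that pattern follows; the flag names are illustrative only, not the scripts' real flags.

```
# Minimal sketch of the absl flag-parsing pattern adopted in this commit.
# Flag names here are illustrative; the real scripts define their own flags.
from absl import app as absl_app
from absl import flags

FLAGS = flags.FLAGS


def define_example_flags():
  flags.DEFINE_string(
      name="input_file", default=None,
      help="File to process (illustrative flag).")
  flags.mark_flag_as_required("input_file")
  flags.DEFINE_enum(
      name="variant", default="both",
      enum_values=["both", "uncased", "cased"], case_sensitive=False,
      help="Which variant(s) to compute (illustrative flag).")


def main(unused_argv):
  # Flags are parsed by absl_app.run before main is invoked.
  print(FLAGS.input_file, FLAGS.variant)


if __name__ == "__main__":
  define_example_flags()
  absl_app.run(main)
```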
......@@ -18,19 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import random
import sys
import tarfile
import urllib
# pylint: disable=g-bad-import-order
import six
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
......@@ -156,7 +157,7 @@ def download_from_url(path, url):
filename = os.path.join(path, filename)
tf.logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete"
inprogress_filepath, _ = urllib.urlretrieve(
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
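The `reporthook` passed to `urllib.request.urlretrieve` above is what produces the in-place progress output that the following newline clears. The repository's actual `download_report_hook` is not shown in this diff; a hypothetical sketch of such a hook:

```
# Hypothetical sketch of a urlretrieve report hook; the actual
# download_report_hook in data_download.py is not part of this diff.
import sys


def download_report_hook(count, block_size, total_size):
  """Print download progress on a single line using a carriage return."""
  percent = int(count * block_size * 100 / total_size)
  sys.stdout.write("\r%d%%" % min(percent, 100))
  sys.stdout.flush()
```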
......@@ -302,7 +303,7 @@ def encode_and_save_files(
for tmp_name, final_name in zip(tmp_filepaths, filepaths):
tf.gfile.Rename(tmp_name, final_name)
tf.logging.info("Saved %d Examples", counter)
tf.logging.info("Saved %d Examples", counter + 1)
return filepaths
......@@ -363,8 +364,6 @@ def make_dir(path):
def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model."""
tf.logging.set_verbosity(tf.logging.INFO)
make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir)
......@@ -398,22 +397,25 @@ def main(unused_argv):
shuffle_records(fname)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
help=flags_core.help_wrap(
"Path where the raw data will be downloaded and extracted."))
flags.DEFINE_bool(
name="search", default=False,
help=flags_core.help_wrap(
"If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory for where the "
"translate_ende_wmt32k dataset is saved.",
metavar="<DD>")
parser.add_argument(
"--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
help="[default: %(default)s] Path where the raw data will be downloaded "
"and extracted.",
metavar="<RD>")
parser.add_argument(
"--search", action="store_true",
help="If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
......@@ -18,18 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100
......@@ -106,7 +107,8 @@ def translate_file(
if i % batch_size == 0:
batch_num = (i // batch_size) + 1
print("Decoding batch %d out of %d." % (batch_num, num_decode_batches))
tf.logging.info("Decoding batch %d out of %d." %
(batch_num, num_decode_batches))
yield _encode_and_add_eos(line, subtokenizer)
def input_fn():
......@@ -122,10 +124,8 @@ def translate_file(
translations.append(translation)
if print_all_translations:
print("Translating:")
print("\tInput: %s" % sorted_inputs[i])
print("\tOutput: %s\n" % translation)
print("=" * 100)
tf.logging.info("Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[i], translation))
# Write translations in the order they appeared in the original file.
if output_file is not None:
......@@ -150,7 +150,7 @@ def translate_text(estimator, subtokenizer, txt):
predictions = estimator.predict(input_fn)
translation = next(predictions)["outputs"]
translation = _trim_and_decode(translation, subtokenizer)
print("Translation of \"%s\": \"%s\"" % (txt, translation))
tf.logging.info("Translation of \"%s\": \"%s\"" % (txt, translation))
def main(unused_argv):
......@@ -166,15 +166,8 @@ def main(unused_argv):
subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Set up estimator and params
params = transformer_main.PARAMS_MAP[FLAGS.param_set]
params.beam_size = _BEAM_SIZE
params.alpha = _ALPHA
params.extra_decode_length = _EXTRA_DECODE_LENGTH
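The `PARAMS_MAP` lookup replaces the removed if/elif branch above. Based on the parameter classes that branch referenced, the mapping defined in `transformer_main.py` presumably resembles this sketch:

```
# Sketch of the lookup assumed to live in transformer_main.py; it mirrors the
# parameter classes referenced by the removed if/elif branch.
from official.transformer.model import model_params

PARAMS_MAP = {
    "base": model_params.TransformerBaseParams,
    "big": model_params.TransformerBigParams,
}
```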
......@@ -201,45 +194,50 @@ def main(unused_argv):
translate_file(estimator, subtokenizer, input_file, output_file)
def define_translate_flags():
"""Define flags used for translation script."""
# Model and vocab file flags
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=VOCAB_FILE,
help=flags_core.help_wrap(
"Name of vocabulary file containing subtokens for subtokenizing the "
"input text or file. This file is expected to be in the directory "
"defined by --data_dir."))
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp/transformer_model",
help=flags_core.help_wrap(
"Directory containing Transformer model checkpoints."))
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_string(
name="text", default=None,
help=flags_core.help_wrap(
"Text to translate. Output will be printed to console."))
flags.DEFINE_string(
name="file", default=None,
help=flags_core.help_wrap(
"File containing text to translate. Translation will be printed to "
"console and, if --file_out is provided, saved to an output file."))
flags.DEFINE_string(
name="file_out", default=None,
help=flags_core.help_wrap(
"If --file flag is specified, save translation to this file."))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Model arguments
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/data/translate_ende",
help="[default: %(default)s] Directory where vocab file is stored.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory containing Transformer model "
"checkpoints.",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter used for trained model.",
metavar="<P>")
# Flags for specifying text/file to be translated.
parser.add_argument(
"--text", "-t", type=str, default=None,
help="[default: %(default)s] Text to translate. Output will be printed "
"to console.",
metavar="<T>")
parser.add_argument(
"--file", "-f", type=str, default=None,
help="[default: %(default)s] File containing text to translate. "
"Translation will be printed to console and, if --file_out is "
"provided, saved to an output file.",
metavar="<F>")
parser.add_argument(
"--file_out", "-fo", type=str, default=None,
help="[default: %(default)s] If --file flag is specified, save "
"translation to this file.",
metavar="<FO>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
define_translate_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
......@@ -190,14 +190,14 @@ def _batch_examples(dataset, batch_size, max_length):
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_cpu_cores, shuffle, repeat):
file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
file_pattern: String used to match the input TFRecord files.
batch_size: Maximum number of tokens per batch of examples
max_length: Maximum number of tokens per example
num_cpu_cores: Number of cpu cores for parallel input processing.
num_parallel_calls: Number of cpu cores for parallel input processing.
shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
......@@ -215,12 +215,12 @@ def _read_and_batch_from_files(
# will be non-deterministic.
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_cpu_cores))
_load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example,
num_parallel_calls=num_cpu_cores)
num_parallel_calls=num_parallel_calls)
# Remove examples where the input or target length exceeds the maximum length,
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
......@@ -238,13 +238,13 @@ def train_input_fn(params):
"""Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=True, repeat=params.repeat_dataset)
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=True, repeat=params.repeat_dataset)
def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=False, repeat=1)
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=False, repeat=1)
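The renamed `num_parallel_calls` field is read off the same `params` object as `batch_size`, `max_length`, and `repeat_dataset`. Below is a sketch of exercising the training input function directly with a stand-in params object; the field values are placeholders, not the model defaults:

```
# Stand-in params object for calling train_input_fn directly. The field names
# mirror those used above; the values are placeholders, not the real defaults.
class _StubParams(object):
  data_dir = "/tmp/translate_ende"
  batch_size = 4096          # maximum number of tokens per batch
  max_length = 256           # maximum number of tokens per example
  num_parallel_calls = 4     # parallelism for interleave and map
  repeat_dataset = 1

dataset = train_input_fn(_StubParams())
```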
......@@ -89,8 +89,9 @@ def define_flags():
flags.adopt_key_flags(flags_core)
def main(flags_obj):
pass
def main(_):
flags_obj = flags.FLAGS
print(flags_obj)
if __name__ == "__main__"
......@@ -120,3 +121,32 @@ class BaseTester(unittest.TestCase):
self.assertEqual(flags.FLAGS.test_flag, "def")
```
## Immutability
Flag values should not be mutated. Instead, use getter functions to return
the desired values. An example getter function, `get_loss_scale`, is shown
below:
```
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
    "fp16": (tf.float16, 128),
    "fp32": (tf.float32, 1),
}


def get_loss_scale(flags_obj):
  if flags_obj.loss_scale is not None:
    return flags_obj.loss_scale
  return DTYPE_MAP[flags_obj.dtype][1]


def main(_):
  flags_obj = flags.FLAGS
  # Do not mutate flags_obj
  # if flags_obj.loss_scale is None:
  #   flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1]  # Don't do this
  print(get_loss_scale(flags_obj))
  ...
```
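The example assumes `dtype` and `loss_scale` flags have already been defined elsewhere; a hedged sketch of what such definitions could look like (the real ones live in the official/utils/flags package and may differ):

```
# Sketch of flag definitions the immutability example assumes; the real
# definitions live in the official/utils/flags package and may differ.
from absl import flags

flags.DEFINE_enum(
    name="dtype", default="fp32", enum_values=["fp16", "fp32"],
    help="TensorFlow datatype used for computations.")
flags.DEFINE_integer(
    name="loss_scale", default=None,
    help="Loss scale to use; when unset, a per-dtype default is applied.")
```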