Unverified Commit a84e1ef9 authored by Katherine Wu's avatar Katherine Wu Committed by GitHub
Browse files

Add official flag-parsing and benchmarking logging utils to Transformer (#4163)

parent fe1857cd
......@@ -14,7 +14,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Training times](#training-times)
* [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
* [Export variables (optional)](#export-variables-optional)
* [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
......@@ -31,16 +31,23 @@ The model also applies embeddings on the input and output tokens, and adds a con
Below are the commands for running the Transformer model. See the [Detailed instrutions](#detailed-instructions) for more details on running the model.
```
PARAMS=big
cd /path/to/models/official/transformer
# Ensure that PYTHONPATH is correctly defined as described in
# https://github.com/tensorflow/models/tree/master/official#running-the-models
# export PYTHONPATH="$PYTHONPATH:/path/to/models"
# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR
# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
--param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
# Run during training in a separate process to get continuous updates,
# or after training is complete.
......@@ -48,21 +55,21 @@ tensorboard --logdir=$MODEL_DIR
# Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --text="hello world"
--param_set=$PARAM_SET --text="hello world"
# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
--param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
## Benchmarks
### Training times
Currently, both big and base params run on a single GPU. The measurements below
Currently, both big and base parameter sets run on a single GPU. The measurements below
are reported from running the model on a P100 GPU.
Params | batches/sec | batches per epoch | time per epoch
Param Set | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr
......@@ -70,7 +77,7 @@ big | 1.1 | 41365 | 10 hr
### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.
Params | Score
Param Set | Score
--- | --- |
base | 27.7
big | 28.9
......@@ -79,13 +86,18 @@ big | 28.9
## Detailed instructions
0. ### Export variables (optional)
0. ### Environment preparation
#### Add models repo to PYTHONPATH
Follow the instructions described in the [Running the models](https://github.com/tensorflow/models/tree/master/official#running-the-models) section to add the models folder to the python path.
#### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below:
```
PARAMS=big
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
```
1. ### Download and preprocess datasets
......@@ -109,13 +121,13 @@ big | 28.9
Command to run:
```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET
```
Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
* `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Customizing training schedule
......@@ -123,12 +135,12 @@ big | 28.9
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset
* `--epochs_between_eval`: The number of epochs to train between evaluations.
* `--epochs_between_evals`: The number of epochs to train between evaluations.
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
* `--steps_between_eval`: Number of training steps to run between evaluations.
* `--steps_between_evals`: Number of training steps to run between evaluations.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more than others. Therefore, it is recommended to use epochs when the model quality is important.
......@@ -137,7 +149,7 @@ big | 28.9
Use these flags to compute the BLEU when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
* `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
* `--stop_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
......@@ -155,12 +167,12 @@ big | 28.9
Command to run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world"
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=PARAM_SET --text="hello world"
```
Arguments for initializing the Subtokenizer and trained model:
* `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
* `--model_dir` and `--params`: These parameters are used to rebuild the trained model
* `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model
Arguments for specifying what to translate:
* `--text`: Text to translate
......@@ -170,7 +182,7 @@ big | 28.9
To translate the newstest2014 data, run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
--param_set=PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
```
Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
......
......@@ -22,17 +22,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import re
import sys
import unicodedata
# pylint: disable=g-bad-import-order
import six
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import metrics
from official.utils.flags import core as flags_core
class UnicodeRegex(object):
......@@ -99,31 +101,37 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
def main(unused_argv):
if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
print("Case-insensitive results:", score)
tf.logging.info("Case-insensitive results: %f" % score)
if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant:
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
print("Case-sensitive results:", score)
tf.logging.info("Case-sensitive results: %f" % score)
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation", default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")
flags.DEFINE_string(
name="reference", default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")
flags.DEFINE_enum(
name="bleu_variant", short_name="bv", default="both",
enum_values=["both", "uncased", "cased"], case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--translation", "-t", type=str, default=None, required=True,
help="[default: %(default)s] File containing translated text.",
metavar="<T>")
parser.add_argument(
"--reference", "-r", type=str, default=None, required=True,
help="[default: %(default)s] File containing reference translation",
metavar="<R>")
parser.add_argument(
"--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
nargs="*", default=None,
help="Specify one or more BLEU variants to calculate (both are "
"calculated by default. Variants: \"cased\" or \"uncased\".",
metavar="<BV>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.logging.set_verbosity(tf.logging.INFO)
define_compute_bleu_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
......@@ -18,19 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import random
import sys
import tarfile
import urllib
# pylint: disable=g-bad-import-order
import six
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
......@@ -156,7 +157,7 @@ def download_from_url(path, url):
filename = os.path.join(path, filename)
tf.logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete"
inprogress_filepath, _ = urllib.urlretrieve(
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
......@@ -302,7 +303,7 @@ def encode_and_save_files(
for tmp_name, final_name in zip(tmp_filepaths, filepaths):
tf.gfile.Rename(tmp_name, final_name)
tf.logging.info("Saved %d Examples", counter)
tf.logging.info("Saved %d Examples", counter + 1)
return filepaths
......@@ -363,8 +364,6 @@ def make_dir(path):
def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model."""
tf.logging.set_verbosity(tf.logging.INFO)
make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir)
......@@ -398,22 +397,25 @@ def main(unused_argv):
shuffle_records(fname)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
help=flags_core.help_wrap(
"Path where the raw data will be downloaded and extracted."))
flags.DEFINE_bool(
name="search", default=False,
help=flags_core.help_wrap(
"If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory for where the "
"translate_ende_wmt32k dataset is saved.",
metavar="<DD>")
parser.add_argument(
"--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
help="[default: %(default)s] Path where the raw data will be downloaded "
"and extracted.",
metavar="<RD>")
parser.add_argument(
"--search", action="store_true",
help="If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
......@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates an estimator to train the Transformer model."""
"""Train and evaluate the Transformer model.
See README for description of setting the training schedule and evaluating the
BLEU score.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tempfile
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
......@@ -36,11 +40,26 @@ from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
PARAMS_MAP = {
"base": model_params.TransformerBaseParams,
"big": model_params.TransformerBigParams,
}
DEFAULT_TRAIN_EPOCHS = 10
BLEU_DIR = "bleu"
INF = int(1e9)
# Dictionary containing tensors that are logged by the logging hooks. Each item
# maps a string to the tensor name.
TENSORS_TO_LOG = {
"learning_rate": "model/get_train_op/learning_rate/learning_rate",
"cross_entropy_loss": "model/cross_entropy"}
def model_fn(features, labels, mode, params):
"""Defines how to train, evaluate and predict from the transformer model."""
......@@ -66,6 +85,9 @@ def model_fn(features, labels, mode, params):
logits, targets, params.label_smoothing, params.vocab_size)
loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
# Save loss as named tensor that will be logged with the logging hook.
tf.identity(loss, "cross_entropy")
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
......@@ -87,6 +109,10 @@ def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
# Apply rsqrt decay
learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps))
# Create a named tensor that will be logged using the logging hook.
# The full name includes variable and names scope. In this case, the name
# is model/get_train_op/learning_rate/learning_rate
tf.identity(learning_rate, "learning_rate")
# Save learning rate value to TensorBoard summary.
tf.summary.scalar("learning_rate", learning_rate)
......@@ -145,31 +171,22 @@ def get_global_step(estimator):
return int(estimator.latest_checkpoint().split("-")[-1])
def evaluate_and_log_bleu(estimator, bleu_writer, bleu_source, bleu_ref):
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
"""Calculate and record the BLEU score."""
subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
subtokenizer = tokenizer.Subtokenizer(vocab_file_path)
uncased_score, cased_score = translate_and_compute_bleu(
estimator, subtokenizer, bleu_source, bleu_ref)
print("Bleu score (uncased):", uncased_score)
print("Bleu score (cased):", cased_score)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, get_global_step(estimator))
bleu_writer.flush()
tf.logging.info("Bleu score (uncased):", uncased_score)
tf.logging.info("Bleu score (cased):", cased_score)
return uncased_score, cased_score
def train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps=None,
single_iteration_train_epochs=None, bleu_source=None, bleu_ref=None,
bleu_threshold=None):
single_iteration_train_epochs=None, train_hooks=None, benchmark_logger=None,
bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
"""Train and evaluate model, and optionally compute model's BLEU score.
**Step vs. Epoch vs. Iteration**
......@@ -198,9 +215,12 @@ def train_schedule(
train_eval_iterations: Number of times to repeat the train+eval iteration.
single_iteration_train_steps: Number of steps to train in one iteration.
single_iteration_train_epochs: Number of epochs to train in one iteration.
train_hooks: List of hooks to pass to the estimator during training.
benchmark_logger: a BenchmarkLogger object that logs evaluation data
bleu_source: File containing text to be translated for BLEU calculation.
bleu_ref: File containing reference translations for BLEU calculation.
bleu_threshold: minimum BLEU score before training is stopped.
vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.
Raises:
ValueError: if both or none of single_iteration_train_steps and
......@@ -221,22 +241,24 @@ def train_schedule(
evaluate_bleu = bleu_source is not None and bleu_ref is not None
# Print out training schedule
print("Training schedule:")
# Print details of training schedule.
tf.logging.info("Training schedule:")
if single_iteration_train_epochs is not None:
print("\t1. Train for %d epochs." % single_iteration_train_epochs)
tf.logging.info("\t1. Train for %d epochs." % single_iteration_train_epochs)
else:
print("\t1. Train for %d steps." % single_iteration_train_steps)
print("\t2. Evaluate model.")
tf.logging.info("\t1. Train for %d steps." % single_iteration_train_steps)
tf.logging.info("\t2. Evaluate model.")
if evaluate_bleu:
print("\t3. Compute BLEU score.")
tf.logging.info("\t3. Compute BLEU score.")
if bleu_threshold is not None:
print("Repeat above steps until the BLEU score reaches", bleu_threshold)
tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
bleu_threshold)
if not evaluate_bleu or bleu_threshold is None:
print("Repeat above steps %d times." % train_eval_iterations)
tf.logging.info("Repeat above steps %d times." % train_eval_iterations)
if evaluate_bleu:
# Set summary writer to log bleu score.
# Create summary writer to log bleu score (values can be displayed in
# Tensorboard).
bleu_writer = tf.summary.FileWriter(
os.path.join(estimator.model_dir, BLEU_DIR))
if bleu_threshold is not None:
......@@ -245,145 +267,201 @@ def train_schedule(
# Loop training/evaluation/bleu cycles
for i in xrange(train_eval_iterations):
print("Starting iteration", i + 1)
tf.logging.info("Starting iteration %d" % (i + 1))
# Train the model for single_iteration_train_steps or until the input fn
# runs out of examples (if single_iteration_train_steps is None).
estimator.train(dataset.train_input_fn, steps=single_iteration_train_steps)
estimator.train(
dataset.train_input_fn, steps=single_iteration_train_steps,
hooks=train_hooks)
eval_results = estimator.evaluate(dataset.eval_input_fn)
print("Evaluation results (iter %d/%d):" % (i + 1, train_eval_iterations),
eval_results)
tf.logging.info("Evaluation results (iter %d/%d):" %
(i + 1, train_eval_iterations))
tf.logging.info(eval_results)
benchmark_logger.log_evaluation_result(eval_results)
# The results from estimator.evaluate() are measured on an approximate
# translation, which utilize the target golden values provided. The actual
# bleu score must be computed using the estimator.predict() path, which
# outputs translations that are not based on golden values. The translations
# are compared to reference file to get the actual bleu score.
if evaluate_bleu:
uncased_score, _ = evaluate_and_log_bleu(
estimator, bleu_writer, bleu_source, bleu_ref)
if bleu_threshold is not None and uncased_score > bleu_threshold:
uncased_score, cased_score = evaluate_and_log_bleu(
estimator, bleu_source, bleu_ref, vocab_file_path)
# Write actual bleu scores using summary writer and benchmark logger
global_step = get_global_step(estimator)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, global_step)
bleu_writer.flush()
benchmark_logger.log_metric(
"bleu_uncased", uncased_score, global_step=global_step)
benchmark_logger.log_metric(
"bleu_cased", cased_score, global_step=global_step)
# Stop training if bleu stopping threshold is met.
if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
bleu_writer.close()
break
def main(_):
# Set logging level to INFO to display training progress (logged by the
# estimator)
tf.logging.set_verbosity(tf.logging.INFO)
def define_transformer_flags():
"""Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, train_epochs, etc.).
flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=False
)
flags_core.define_benchmark()
# Set flags from the flags_core module as "key flags" so they're listed when
# the '-h' flag is used. Without this line, the flags defined above are
# only shown in the full `--helpful` help text.
flags.adopt_module_key_flags(flags_core)
# Add transformer-specific flags
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Flags for training with steps (may be used for debugging)
flags.DEFINE_integer(
name="train_steps", short_name="ts", default=None,
help=flags_core.help_wrap("The number of steps used to train."))
flags.DEFINE_integer(
name="steps_between_evals", short_name="sbe", default=1000,
help=flags_core.help_wrap(
"The Number of training steps to run between evaluations. This is "
"used if --train_steps is defined."))
# BLEU score computation
flags.DEFINE_string(
name="bleu_source", short_name="bls", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
"must be set. Use the flag --stop_threshold to stop the script based "
"on the uncased BLEU score."))
flags.DEFINE_string(
name="bleu_ref", short_name="blr", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
"must be set. Use the flag --stop_threshold to stop the script based "
"on the uncased BLEU score."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=VOCAB_FILE,
help=flags_core.help_wrap(
"Name of vocabulary file containing subtokens for subtokenizing the "
"bleu_source file. This file is expected to be in the directory "
"defined by --data_dir."))
flags_core.set_defaults(data_dir="/tmp/translate_ende",
model_dir="/tmp/transformer_model",
batch_size=None,
train_epochs=None)
@flags.multi_flags_validator(
["train_epochs", "train_steps"],
message="Both --train_steps and --train_epochs were set. Only one may be "
"defined.")
def _check_train_limits(flag_dict):
return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None
@flags.multi_flags_validator(
["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
"Please ensure that the file paths are correct.")
def _check_bleu_files(flags_dict):
"""Validate files when bleu_source and bleu_ref are defined."""
if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
return True
# Ensure that bleu_source, bleu_ref, and vocab files exist.
vocab_file_path = os.path.join(
flags_dict["data_dir"], flags_dict["vocab_file"])
return all([
tf.gfile.Exists(flags_dict["bleu_source"]),
tf.gfile.Exists(flags_dict["bleu_ref"]),
tf.gfile.Exists(vocab_file_path)])
def run_transformer(flags_obj):
"""Create tf.Estimator to train and evaluate transformer model.
Args:
flags_obj: Object containing parsed flag values.
"""
# Determine training schedule based on flags.
if FLAGS.train_steps is not None and FLAGS.train_epochs is not None:
raise ValueError("Both --train_steps and --train_epochs were set. Only one "
"may be defined.")
if FLAGS.train_steps is not None:
train_eval_iterations = FLAGS.train_steps // FLAGS.steps_between_eval
single_iteration_train_steps = FLAGS.steps_between_eval
if flags_obj.train_steps is not None:
train_eval_iterations = (
flags_obj.train_steps // flags_obj.steps_between_evals)
single_iteration_train_steps = flags_obj.steps_between_evals
single_iteration_train_epochs = None
else:
if FLAGS.train_epochs is None:
FLAGS.train_epochs = DEFAULT_TRAIN_EPOCHS
train_eval_iterations = FLAGS.train_epochs // FLAGS.epochs_between_eval
train_epochs = flags_obj.train_epochs or DEFAULT_TRAIN_EPOCHS
train_eval_iterations = train_epochs // flags_obj.epochs_between_evals
single_iteration_train_steps = None
single_iteration_train_epochs = FLAGS.epochs_between_eval
# Make sure that the BLEU source and ref files if set
if FLAGS.bleu_source is not None and FLAGS.bleu_ref is not None:
if not tf.gfile.Exists(FLAGS.bleu_source):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_source)
if not tf.gfile.Exists(FLAGS.bleu_ref):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_ref)
single_iteration_train_epochs = flags_obj.epochs_between_evals
# Add flag-defined parameters to params object
params.data_dir = FLAGS.data_dir
params.num_cpu_cores = FLAGS.num_cpu_cores
params.epochs_between_eval = FLAGS.epochs_between_eval
params = PARAMS_MAP[flags_obj.param_set]
params.data_dir = flags_obj.data_dir
params.num_parallel_calls = flags_obj.num_parallel_calls
params.epochs_between_evals = flags_obj.epochs_between_evals
params.repeat_dataset = single_iteration_train_epochs
params.batch_size = flags_obj.batch_size or params.batch_size
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
batch_size=params.batch_size # for ExamplesPerSecondHook
)
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info(
model_name="transformer",
dataset_name="wmt_translate_ende",
run_params=params.__dict__)
# Train and evaluate transformer model
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=FLAGS.model_dir, params=params)
model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps,
single_iteration_train_epochs, FLAGS.bleu_source, FLAGS.bleu_ref,
FLAGS.bleu_threshold)
estimator=estimator,
# Training arguments
train_eval_iterations=train_eval_iterations,
single_iteration_train_steps=single_iteration_train_steps,
single_iteration_train_epochs=single_iteration_train_epochs,
train_hooks=train_hooks,
benchmark_logger=benchmark_logger,
# BLEU calculation arguments
bleu_source=flags_obj.bleu_source,
bleu_ref=flags_obj.bleu_ref,
bleu_threshold=flags_obj.stop_threshold,
vocab_file_path=os.path.join(flags_obj.data_dir, flags_obj.vocab_file))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory containing training and "
"evaluation data, and vocab file used for encoding.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory to save Transformer model "
"training checkpoints",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter set to use when creating and "
"training the model.",
metavar="<P>")
parser.add_argument(
"--num_cpu_cores", "-nc", type=int, default=4,
help="[default: %(default)s] Number of CPU cores to use in the input "
"pipeline.",
metavar="<NC>")
# Flags for training with epochs. (default)
parser.add_argument(
"--train_epochs", "-te", type=int, default=None,
help="The number of epochs used to train. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TE>")
parser.add_argument(
"--epochs_between_eval", "-ebe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<TE>")
def main(_):
run_transformer(flags.FLAGS)
# Flags for training with steps (may be used for debugging)
parser.add_argument(
"--train_steps", "-ts", type=int, default=None,
help="Total number of training steps. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TS>")
parser.add_argument(
"--steps_between_eval", "-sbe", type=int, default=1000,
help="[default: %(default)s] Number of training steps to run between "
"evaluations.",
metavar="<SBE>")
# BLEU score computation
parser.add_argument(
"--bleu_source", "-bs", type=str, default=None,
help="Path to source file containing text translate when calculating the "
"official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BS>")
parser.add_argument(
"--bleu_ref", "-br", type=str, default=None,
help="Path to file containing the reference translation for calculating "
"the official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BR>")
parser.add_argument(
"--bleu_threshold", "-bt", type=float, default=None,
help="Stop training when the uncased BLEU score reaches this value. "
"Setting this overrides the total number of steps or epochs set by "
"--train_steps or --train_epochs.",
metavar="<BT>")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_transformer_flags()
absl_app.run(main)
......@@ -18,18 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100
......@@ -106,7 +107,8 @@ def translate_file(
if i % batch_size == 0:
batch_num = (i // batch_size) + 1
print("Decoding batch %d out of %d." % (batch_num, num_decode_batches))
tf.logging.info("Decoding batch %d out of %d." %
(batch_num, num_decode_batches))
yield _encode_and_add_eos(line, subtokenizer)
def input_fn():
......@@ -122,10 +124,8 @@ def translate_file(
translations.append(translation)
if print_all_translations:
print("Translating:")
print("\tInput: %s" % sorted_inputs[i])
print("\tOutput: %s\n" % translation)
print("=" * 100)
tf.logging.info("Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[i], translation))
# Write translations in the order they appeared in the original file.
if output_file is not None:
......@@ -150,7 +150,7 @@ def translate_text(estimator, subtokenizer, txt):
predictions = estimator.predict(input_fn)
translation = next(predictions)["outputs"]
translation = _trim_and_decode(translation, subtokenizer)
print("Translation of \"%s\": \"%s\"" % (txt, translation))
tf.logging.info("Translation of \"%s\": \"%s\"" % (txt, translation))
def main(unused_argv):
......@@ -166,15 +166,8 @@ def main(unused_argv):
subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Set up estimator and params
params = transformer_main.PARAMS_MAP[FLAGS.param_set]
params.beam_size = _BEAM_SIZE
params.alpha = _ALPHA
params.extra_decode_length = _EXTRA_DECODE_LENGTH
......@@ -201,45 +194,50 @@ def main(unused_argv):
translate_file(estimator, subtokenizer, input_file, output_file)
def define_translate_flags():
"""Define flags used for translation script."""
# Model and vocab file flags
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=VOCAB_FILE,
help=flags_core.help_wrap(
"Name of vocabulary file containing subtokens for subtokenizing the "
"input text or file. This file is expected to be in the directory "
"defined by --data_dir."))
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp/transformer_model",
help=flags_core.help_wrap(
"Directory containing Transformer model checkpoints."))
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_string(
name="text", default=None,
help=flags_core.help_wrap(
"Text to translate. Output will be printed to console."))
flags.DEFINE_string(
name="file", default=None,
help=flags_core.help_wrap(
"File containing text to translate. Translation will be printed to "
"console and, if --file_out is provided, saved to an output file."))
flags.DEFINE_string(
name="file_out", default=None,
help=flags_core.help_wrap(
"If --file flag is specified, save translation to this file."))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Model arguments
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/data/translate_ende",
help="[default: %(default)s] Directory where vocab file is stored.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory containing Transformer model "
"checkpoints.",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter used for trained model.",
metavar="<P>")
# Flags for specifying text/file to be translated.
parser.add_argument(
"--text", "-t", type=str, default=None,
help="[default: %(default)s] Text to translate. Output will be printed "
"to console.",
metavar="<T>")
parser.add_argument(
"--file", "-f", type=str, default=None,
help="[default: %(default)s] File containing text to translate. "
"Translation will be printed to console and, if --file_out is "
"provided, saved to an output file.",
metavar="<F>")
parser.add_argument(
"--file_out", "-fo", type=str, default=None,
help="[default: %(default)s] If --file flag is specified, save "
"translation to this file.",
metavar="<FO>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
define_translate_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
......@@ -190,14 +190,14 @@ def _batch_examples(dataset, batch_size, max_length):
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_cpu_cores, shuffle, repeat):
file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
file_pattern: String used to match the input TFRecord files.
batch_size: Maximum number of tokens per batch of examples
max_length: Maximum number of tokens per example
num_cpu_cores: Number of cpu cores for parallel input processing.
num_parallel_calls: Number of cpu cores for parallel input processing.
shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
......@@ -215,12 +215,12 @@ def _read_and_batch_from_files(
# will be non-deterministic.
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_cpu_cores))
_load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example,
num_parallel_calls=num_cpu_cores)
num_parallel_calls=num_parallel_calls)
# Remove examples where the input or target length exceeds the maximum length,
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
......@@ -238,13 +238,13 @@ def train_input_fn(params):
"""Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=True, repeat=params.repeat_dataset)
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=True, repeat=params.repeat_dataset)
def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=False, repeat=1)
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=False, repeat=1)
......@@ -89,8 +89,9 @@ def define_flags():
flags.adopt_key_flags(flags_core)
def main(flags_obj):
pass
def main(_):
flags_obj = flags.FLAGS
print(flags_obj)
if __name__ == "__main__"
......@@ -120,3 +121,32 @@ class BaseTester(unittest.TestCase):
self.AssertEqual(flags.FLAGS.test_flag, "def")
```
## Immutability
Flag values should not be mutated. Instead, use getter functions to return
the desired values. An example getter function is `get_loss_scale` function
below:
```
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
def main(_):
flags_obj = flags.FLAGS()
# Do not mutate flags_obj
# if flags_obj.loss_scale is None:
# flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1] # Don't do this
print(get_loss_scale(flags_obj))
...
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment