Unverified commit ab1c1dfc authored by Zhang Xunkai, committed by GitHub

Make max_length and static_batch configurable (#6893)

* Make max_length and static_batch configurable.

* Fix line length.

* Fix incorrect parameters in building eval input.

* Improve comments for readability.
parent e80b385a
@@ -193,12 +193,12 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
     file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
-    static_batch=False):
+    static_batch=False, num_replicas=1):
   """Create dataset where each item is a dict of "inputs" and "targets".

   Args:
     file_pattern: String used to match the input TFRecord files.
-    batch_size: Maximum number of tokens per batch of examples
+    batch_size: Maximum number of tokens per global batch of examples.
     max_length: Maximum number of tokens per example
     num_parallel_calls: Number of cpu cores for parallel input processing.
     shuffle: If true, randomizes order of elements.
@@ -215,6 +215,10 @@ def _read_and_batch_from_files(
     to be grouped so that the number of padding tokens is minimized, and helps
     model training. In cases where the input shape must be static
     (e.g. running on TPU), this setting should be set to True.
+    num_replicas: Number of GPUs or other workers. We will generate global
+      batches, and each global batch is equally divisible by the number of
+      replicas. Currently it is only effective when static_batch==True.
+      TODO: make it effective when static_batch=False.

   Returns:
     tf.data.Dataset object containing examples loaded from the files.
@@ -223,10 +227,12 @@ def _read_and_batch_from_files(
   # Read files and interleave results. When training, the order of the examples
   # will be non-deterministic.
+  options = tf.data.Options()
+  options.experimental_deterministic = False
   dataset = dataset.interleave(
       _load_records,
       cycle_length=num_parallel_calls,
-      num_parallel_calls=num_parallel_calls)
+      num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
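For readers following along, the interleave change above can be exercised in isolation. A minimal sketch, assuming TFRecord input files and a TF release that has tf.data.experimental.AUTOTUNE; the pipeline's _load_records is essentially a thin wrapper around tf.data.TFRecordDataset:

import tensorflow as tf

def make_raw_dataset(file_pattern, num_parallel_calls):
  # List the input files; record-level shuffling is handled elsewhere.
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  # Opt out of deterministic ordering: interleaved readers may emit
  # records in any order, which improves throughput during training.
  options = tf.data.Options()
  options.experimental_deterministic = False
  return dataset.interleave(
      tf.data.TFRecordDataset,
      cycle_length=num_parallel_calls,
      num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)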
@@ -238,10 +244,14 @@ def _read_and_batch_from_files(
   if static_batch:
     dataset = dataset.padded_batch(
-        batch_size // max_length, ([max_length], [max_length]),
-        drop_remainder=True)
+        # First calculate the batch size (token count) per worker, then divide
+        # it into sentences, and finally expand to a global batch. This makes
+        # the global batch evenly divisible for the distribution strategy.
+        ((batch_size // num_replicas) // max_length) * num_replicas,
+        ([max_length], [max_length]), drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
+    # TODO: _batch_examples might need to do something special for num_replicas.
     dataset = _batch_examples(dataset, batch_size, max_length)

   dataset = dataset.repeat(repeat)
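A quick sanity check of the new batch-size arithmetic (illustrative numbers, not from the commit): with batch_size=4096 tokens, max_length=64 and num_replicas=3, the old expression yields 64 sentences per global batch, which does not split evenly across 3 replicas; the new expression yields 63, which does.

batch_size, max_length, num_replicas = 4096, 64, 3

old_sentences = batch_size // max_length                   # 64
per_replica = (batch_size // num_replicas) // max_length   # 21
new_sentences = per_replica * num_replicas                 # 63

assert old_sentences % num_replicas != 0  # 64 sentences: uneven split
assert new_sentences % num_replicas == 0  # 63 sentences: 21 per replica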
@@ -272,7 +282,8 @@ def train_input_fn(params):
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=True,
-      repeat=params["repeat_dataset"], static_batch=params["static_batch"])
+      repeat=params["repeat_dataset"], static_batch=params["static_batch"],
+      num_replicas=params["num_gpus"])


 def eval_input_fn(params):
@@ -283,7 +294,7 @@ def eval_input_fn(params):
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"])
+      static_batch=params["static_batch"], num_replicas=params["num_gpus"])


 def map_data_for_transformer_fn(x, y):
@@ -82,6 +82,9 @@ def define_transformer_flags():
       help=flags_core.help_wrap(
           'The Number of training steps to run between evaluations. This is '
           'used if --train_steps is defined.'))
+  flags.DEFINE_boolean(
+      name='enable_time_history', default=True,
+      help='Whether to enable TimeHistory callback.')
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
       help='Whether to enable Tensorboard callback.')
@@ -111,7 +114,7 @@ def define_transformer_flags():
           'complete list of parameters, please see model/model_params.py.'))
   flags.DEFINE_bool(
-      name='static_batch', default=False,
+      name='static_batch', short_name='sb', default=False,
       help=flags_core.help_wrap(
           'Whether the batches in the dataset should have static shapes. In '
           'general, this setting should be False. Dynamic shapes allow the '
@@ -119,6 +122,12 @@ def define_transformer_flags():
           'minimized, and helps model training. In cases where the input shape '
           'must be static (e.g. running on TPU), this setting will be ignored '
           'and static batching will always be used.'))
+  flags.DEFINE_integer(
+      name='max_length', short_name='ml', default=256,
+      help=flags_core.help_wrap(
+          'Max sentence length for Transformer. Default is 256. Note: it is '
+          'usually more effective to use a smaller max length if static_batch '
+          'is enabled, e.g. 64.'))

   # Flags for training with steps (may be used for debugging)
   flags.DEFINE_integer(
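As a standalone illustration of the two new flags with their short names (a sketch using absl directly; flags_core.help_wrap in the repo only reflows help text, and the file name demo.py is made up):

from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_bool(name='static_batch', short_name='sb', default=False,
                  help='Use static input shapes (required on TPU).')
flags.DEFINE_integer(name='max_length', short_name='ml', default=256,
                     help='Max sentence length for Transformer.')

def main(_):
  print('static_batch=%s max_length=%d' % (FLAGS.static_batch,
                                           FLAGS.max_length))

if __name__ == '__main__':
  app.run(main)  # e.g. python demo.py --sb --ml=64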
@@ -195,8 +204,9 @@ def define_transformer_flags():
 def get_callbacks():
   """Returns common callbacks."""
   callbacks = []
-  time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
-  callbacks.append(time_callback)
+  if FLAGS.enable_time_history:
+    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
+    callbacks.append(time_callback)
   if FLAGS.enable_tensorboard:
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
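The gating above keeps fit() callbacks opt-in. A self-contained sketch of the same shape, where TimeHistory is a hypothetical stub standing in for the real implementation in the repo's keras_utils:

import tensorflow as tf

class TimeHistory(tf.keras.callbacks.Callback):
  """Stand-in for keras_utils.TimeHistory: logs progress as batches finish."""

  def __init__(self, batch_size, log_steps):
    super(TimeHistory, self).__init__()
    self.batch_size, self.log_steps = batch_size, log_steps

  def on_batch_end(self, batch, logs=None):
    if batch % self.log_steps == 0:
      print('step %d done (batch_size=%d)' % (batch, self.batch_size))

def get_callbacks(enable_time_history, enable_tensorboard,
                  batch_size, log_steps, log_dir):
  callbacks = []
  if enable_time_history:
    callbacks.append(TimeHistory(batch_size, log_steps))
  if enable_tensorboard:
    callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir))
  return callbacks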
@@ -99,9 +99,11 @@ class TransformerTask(object):
     self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

     params["num_gpus"] = num_gpus
     params["data_dir"] = flags_obj.data_dir
     params["model_dir"] = flags_obj.model_dir
+    params["static_batch"] = flags_obj.static_batch
+    params["max_length"] = flags_obj.max_length
     params["num_parallel_calls"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
@@ -148,7 +150,9 @@ class TransformerTask(object):
           epochs=i,
           steps_per_epoch=flags_obj.steps_between_evals,
           callbacks=callbacks,
-          verbose=2)
+          # If TimeHistory is enabled, the progress bar would be messy.
+          # Increase the verbosity level to get rid of it.
+          verbose=(2 if flags_obj.enable_time_history else 1))
       print("End train iteration:{}/{} global step:{}".format(
           i,
           iterations,
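On the verbose switch: in tf.keras, fit(verbose=1) renders an in-place progress bar while verbose=2 prints one plain line per epoch, so the latter coexists better with TimeHistory's own log lines. A toy demonstration of the two levels:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
x, y = np.random.rand(64, 4), np.random.rand(64, 1)

model.fit(x, y, epochs=1, verbose=1)  # interactive progress bar
model.fit(x, y, epochs=1, verbose=2)  # one summary line per epoch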
@@ -159,6 +163,8 @@ class TransformerTask(object):
       if (flags_obj.bleu_source and flags_obj.bleu_ref):
         uncased_score, cased_score = self.eval()
+        print("BLEU: uncased={}, cased={}".format(uncased_score, cased_score))
+
     stats = misc.build_stats(history, callbacks)
     if uncased_score and cased_score:
       stats["bleu_uncased"] = uncased_score