# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train and evaluate the Transformer model.

See README for description of setting the training schedule and evaluating the
BLEU score.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

# pylint: disable=g-bad-import-order
from six.moves import xrange  # pylint: disable=redefined-builtin
30
31
from absl import app as absl_app
from absl import flags
Katherine Wu's avatar
Katherine Wu committed
32
33
34
35
36
37
38
39
40
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
41
from official.transformer.utils import schedule
Katherine Wu's avatar
Katherine Wu committed
42
from official.transformer.utils import tokenizer
43
from official.utils.accelerator import tpu as tpu_util
44
from official.utils.export import export
45
46
47
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
48
from official.utils.misc import distribution_utils
49
50
51
from official.utils.misc import model_helpers

# Maps each --param_set flag value to its hyperparameter set.
PARAMS_MAP = {
    "tiny": model_params.TINY_PARAMS,
    "base": model_params.BASE_PARAMS,
    "big": model_params.BIG_PARAMS,
}


# Passed to schedule.Manager as the default number of training epochs.
DEFAULT_TRAIN_EPOCHS = 10
# Effectively unbounded iteration count; run_loop uses it when training should
# stop on a BLEU threshold rather than after a fixed number of iterations.
INF = int(1e9)
# Sub-directory of the model dir where BLEU score summaries are written.
BLEU_DIR = "bleu"

# Dictionary containing tensors that are logged by the logging hooks. Each item
# maps a string to the tensor name. The names match the tf.identity tensors
# created in get_learning_rate and model_fn.
TENSORS_TO_LOG = {
    "learning_rate": "model/get_train_op/learning_rate/learning_rate",
    "cross_entropy_loss": "model/cross_entropy"}

def model_fn(features, labels, mode, params):
  """Defines how to train, evaluate and predict from the transformer model.

  Args:
    features: Model inputs, passed to the Transformer as `inputs`
      (presumably a batch of input token ids; confirm against dataset.py).
    labels: Model targets, passed to the Transformer as `targets`. None in
      PREDICT mode.
    mode: A tf.estimator.ModeKeys value.
    params: Dict of hyperparameters. This function reads "use_tpu",
      "label_smoothing", "vocab_size" and "model_dir". It also forwards params
      to the Transformer, the eval metrics and get_train_op_and_metrics.

  Returns:
    A tf.estimator.EstimatorSpec. When params["use_tpu"] is set in TRAIN or
    EVAL mode, a tf.contrib.tpu.TPUEstimatorSpec instead.

  Raises:
    NotImplementedError: if prediction is requested with params["use_tpu"].
  """
  with tf.variable_scope("model"):
    inputs, targets = features, labels

    # Create model and get output logits.
    model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)

    logits = model(inputs, targets)

    # When in prediction mode, the labels/targets are None. The model output
    # is the prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:
      if params["use_tpu"]:
        raise NotImplementedError("Prediction is not yet supported on TPUs.")
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions=logits,
          export_outputs={
              "translate": tf.estimator.export.PredictOutput(logits)
          })

    # Explicitly set the shape of the logits for XLA (TPU). This is needed
    # because the logits are passed back to the host VM CPU for metric
    # evaluation, and the shape of [?, ?, vocab_size] is too vague. However
    # it is known from Transformer that the first two dimensions of logits
    # are the dimensions of targets. Note that the ambiguous shape of logits is
    # not a problem when computing xentropy, because padded_cross_entropy_loss
    # resolves the shape on the TPU.
    logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])

    # Calculate model loss.
    # xentropy contains the cross entropy loss of every nonpadding token in the
    # targets.
    xentropy, weights = metrics.padded_cross_entropy_loss(
        logits, targets, params["label_smoothing"], params["vocab_size"])
    # Average over the weights returned with xentropy (presumably 0 for
    # padding positions).
    loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

    # Save loss as named tensor that will be logged with the logging hook.
    tf.identity(loss, "cross_entropy")

    if mode == tf.estimator.ModeKeys.EVAL:
      if params["use_tpu"]:
        # host call functions should only have tensors as arguments.
        # This lambda pre-populates params so that metric_fn is
        # TPUEstimator compliant.
        metric_fn = lambda logits, labels: (
            metrics.get_eval_metrics(logits, labels, params=params))
        eval_metrics = (metric_fn, [logits, labels])
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, predictions={"predictions": logits},
            eval_metrics=eval_metrics)
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss, predictions={"predictions": logits},
          eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
    else:
      train_op, metric_dict = get_train_op_and_metrics(loss, params)

      # Epochs can be quite long. This gives some intermediate information
      # in TensorBoard.
      metric_dict["minibatch_loss"] = loss
      if params["use_tpu"]:
        # On TPU the scalars are recorded through a host call instead of
        # record_scalars.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            host_call=tpu_util.construct_scalar_host_call(
                metric_dict=metric_dict, model_dir=params["model_dir"],
                prefix="training/")
        )
      record_scalars(metric_dict)
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


def record_scalars(metric_dict):
  """Record every entry of `metric_dict` as a scalar summary.

  Args:
    metric_dict: Dict mapping summary names to scalar tensors. Each pair is
      passed to tf.contrib.summary.scalar.
  """
  for key, value in metric_dict.items():
    tf.contrib.summary.scalar(name=key, tensor=value)


def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
  """Calculate learning rate with linear warmup and rsqrt decay.

  The returned tensor evaluates to
    learning_rate * hidden_size ** -0.5
      * min(1, step / warmup_steps)
      * rsqrt(max(step, warmup_steps))
  where `step` is the global step.

  Args:
    learning_rate: Base learning rate (Python number).
    hidden_size: Model hidden size, used for the hidden_size ** -0.5 scaling.
    learning_rate_warmup_steps: Number of steps of linear warmup.

  Returns:
    A float tensor with the learning rate for the current global step.
  """
  with tf.name_scope("learning_rate"):
    warmup_steps = tf.to_float(learning_rate_warmup_steps)
    step = tf.to_float(tf.train.get_or_create_global_step())

    learning_rate *= (hidden_size ** -0.5)
    # Apply linear warmup
    learning_rate *= tf.minimum(1.0, step / warmup_steps)
    # Apply rsqrt decay
    learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps))

    # Create a named tensor that will be logged using the logging hook.
    # The full name includes variable and names scope. In this case, the name
    # is model/get_train_op/learning_rate/learning_rate
    tf.identity(learning_rate, "learning_rate")

    return learning_rate


def get_train_op_and_metrics(loss, params):
  """Generate training op and metrics to save in TensorBoard.

  Args:
    loss: Scalar loss tensor to minimize.
    params: Dict of hyperparameters. Reads "learning_rate", "hidden_size",
      "learning_rate_warmup_steps", the "optimizer_adam_*" values, "use_tpu",
      "tpu" and "dtype".

  Returns:
    Tuple (train_op, train_metrics). train_op applies the gradients and runs
    the UPDATE_OPS collection. train_metrics maps summary names to tensors:
    "learning_rate" always, plus "global_norm/gradient_norm" when not on TPU.
  """
  with tf.variable_scope("get_train_op"):
    learning_rate = get_learning_rate(
        learning_rate=params["learning_rate"],
        hidden_size=params["hidden_size"],
        learning_rate_warmup_steps=params["learning_rate_warmup_steps"])

    # Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
    # than the TF core Adam optimizer.
    optimizer = tf.contrib.opt.LazyAdamOptimizer(
        learning_rate,
        beta1=params["optimizer_adam_beta1"],
        beta2=params["optimizer_adam_beta2"],
        epsilon=params["optimizer_adam_epsilon"])

    # Aggregate gradients across TPU shards, except for the local TPU setting.
    if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Uses automatic mixed precision FP16 training if on GPU.
    if params["dtype"] == "fp16":
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # Calculate and apply gradients using the (possibly wrapped) optimizer.
    global_step = tf.train.get_global_step()
    tvars = tf.trainable_variables()
    gradients = optimizer.compute_gradients(
        loss, tvars, colocate_gradients_with_ops=True)
    minimize_op = optimizer.apply_gradients(
        gradients, global_step=global_step, name="train")
    # Run any registered update ops together with the weight update.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(minimize_op, update_ops)

    train_metrics = {"learning_rate": learning_rate}

    if not params["use_tpu"]:
      # gradient norm is not included as a summary when running on TPU, as
      # it can cause instability between the TPU and the host controller.
      gradient_norm = tf.global_norm([grad for grad, _ in gradients])
      train_metrics["global_norm/gradient_norm"] = gradient_norm

    return train_op, train_metrics


def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
  """Translate file and report the cased and uncased bleu scores.

  Args:
    estimator: tf.Estimator used to translate `bleu_source`.
    subtokenizer: Subtokenizer passed to translate.translate_file.
    bleu_source: Path to the file containing text to translate.
    bleu_ref: Path to the file containing reference translations.

  Returns:
    Tuple (uncased_score, cased_score) from compute_bleu.bleu_wrapper.
  """
  # Create a temporary file to store the translation. Only its path is needed.
  # Close our handle immediately so the descriptor is not leaked and
  # translate_file can open the path itself (required on some platforms).
  tmp = tempfile.NamedTemporaryFile(delete=False)
  tmp_filename = tmp.name
  tmp.close()

  try:
    translate.translate_file(
        estimator, subtokenizer, bleu_source, output_file=tmp_filename,
        print_all_translations=False)

    # Compute uncased and cased bleu scores.
    uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
    cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
  finally:
    # Remove the temporary file even if translation or scoring fails.
    os.remove(tmp_filename)
  return uncased_score, cased_score


def get_global_step(estimator):
  """Return the step number of the estimator's latest checkpoint.

  The step is taken to be the integer after the last "-" in the checkpoint
  path returned by estimator.latest_checkpoint() (assumes the usual
  "<prefix>-<step>" checkpoint naming).
  """
  checkpoint_path = estimator.latest_checkpoint()
  _, _, step_text = checkpoint_path.rpartition("-")
  return int(step_text)


def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
  """Calculate and record the BLEU score.

  Args:
    estimator: tf.Estimator used to translate `bleu_source`.
    bleu_source: Path to the file containing text to translate.
    bleu_ref: Path to the file containing reference translations.
    vocab_file: Path to the subtoken vocabulary used to build the
      Subtokenizer for `bleu_source`.

  Returns:
    Tuple (uncased_score, cased_score).
  """
  subtokenizer = tokenizer.Subtokenizer(vocab_file)

  uncased_score, cased_score = translate_and_compute_bleu(
      estimator, subtokenizer, bleu_source, bleu_ref)

  tf.logging.info("Bleu score (uncased): %f", uncased_score)
  tf.logging.info("Bleu score (cased): %f", cased_score)
  return uncased_score, cased_score

def _validate_file(filepath):
  """Make sure that file exists.

  Args:
    filepath: Path checked with tf.io.gfile.exists.

  Raises:
    tf.errors.NotFoundError: if `filepath` does not exist.
  """
  if not tf.io.gfile.exists(filepath):
    raise tf.errors.NotFoundError(None, None, "File %s not found." % filepath)


def _log_training_schedule(schedule_manager, evaluate_bleu, bleu_threshold):
  """Log the train/eval(/BLEU) iteration schedule that run_loop follows."""
  tf.logging.info("Training schedule:")
  tf.logging.info(
      "\t1. Train for {}".format(schedule_manager.train_increment_str))
  tf.logging.info("\t2. Evaluate model.")
  if evaluate_bleu:
    tf.logging.info("\t3. Compute BLEU score.")
    if bleu_threshold is not None:
      tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
                      bleu_threshold)
  if not evaluate_bleu or bleu_threshold is None:
    tf.logging.info("Repeat above steps %d times." %
                    schedule_manager.train_eval_iterations)


def _record_bleu_scores(bleu_writer, benchmark_logger, global_step,
                        uncased_score, cased_score):
  """Write BLEU scores to the summary writer and, if given, benchmark logger."""
  summary = tf.Summary(value=[
      tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
      tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
  ])
  bleu_writer.add_summary(summary, global_step)
  bleu_writer.flush()
  if benchmark_logger is not None:
    benchmark_logger.log_metric(
        "bleu_uncased", uncased_score, global_step=global_step)
    benchmark_logger.log_metric(
        "bleu_cased", cased_score, global_step=global_step)


def run_loop(
    estimator, schedule_manager, train_hooks=None, benchmark_logger=None,
    bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file=None):
  """Train and evaluate model, and optionally compute model's BLEU score.

  **Step vs. Epoch vs. Iteration**

  Steps and epochs are canonical terms used in TensorFlow and general machine
  learning. They are used to describe running a single process (train/eval):
    - Step refers to running the process through a single or batch of examples.
    - Epoch refers to running the process through an entire dataset.

  E.g. training a dataset with 100 examples. The dataset is
  divided into 20 batches with 5 examples per batch. A single training step
  trains the model on one batch. After 20 training steps, the model will have
  trained on every batch in the dataset, or, in other words, one epoch.

  Meanwhile, iteration is used in this implementation to describe running
  multiple processes (training and eval).
    - A single iteration:
      1. trains the model for a specific number of steps or epochs.
      2. evaluates the model.
      3. (if source and ref files are provided) compute BLEU score.

  This function runs through multiple train+eval+bleu iterations.

  Args:
    estimator: tf.Estimator containing model to train.
    schedule_manager: A schedule.Manager object to guide the run loop.
    train_hooks: List of hooks to pass to the estimator during training.
    benchmark_logger: Optional BenchmarkLogger object that logs evaluation
      data. When None, nothing is sent to a benchmark logger.
    bleu_source: File containing text to be translated for BLEU calculation.
    bleu_ref: File containing reference translations for BLEU calculation.
    bleu_threshold: minimum BLEU score before training is stopped.
    vocab_file: Path to vocab file that will be used to subtokenize bleu_source.

  Returns:
    Dict of results of the run. Always contains the keys `eval_results` and
    `train_hooks`. Also contains `bleu_cased` and `bleu_uncased` when BLEU is
    evaluated. `train_hooks` is a list of the instances of hooks used during
    training. `eval_results` is None if no iteration was run.

  Raises:
    ValueError: if BLEU evaluation is requested while training on a TPU.
    NotFoundError: if the vocab file or bleu files don't exist.
  """
  if bleu_source:
    _validate_file(bleu_source)
  if bleu_ref:
    _validate_file(bleu_ref)
  if vocab_file:
    _validate_file(vocab_file)

  evaluate_bleu = bleu_source is not None and bleu_ref is not None
  if evaluate_bleu and schedule_manager.use_tpu:
    raise ValueError("BLEU score can not be computed when training with a TPU, "
                     "as it requires estimator.predict which is not yet "
                     "supported.")

  # Print details of training schedule (before the iteration count may be
  # overridden below).
  _log_training_schedule(schedule_manager, evaluate_bleu, bleu_threshold)

  bleu_writer = None
  if evaluate_bleu:
    # Create summary writer to log bleu score (values can be displayed in
    # Tensorboard).
    bleu_writer = tf.summary.FileWriter(
        os.path.join(estimator.model_dir, BLEU_DIR))
    if bleu_threshold is not None:
      # Change loop stopping condition if bleu_threshold is defined.
      schedule_manager.train_eval_iterations = INF

  stats = {}
  # Stays None only if the schedule requests zero iterations. Initialized so
  # the stats assignment after the loop cannot hit an unbound name.
  eval_results = None
  try:
    # Loop training/evaluation/bleu cycles
    for i in xrange(schedule_manager.train_eval_iterations):
      tf.logging.info("Starting iteration %d" % (i + 1))

      # Train the model for single_iteration_train_steps or until the input fn
      # runs out of examples (if single_iteration_train_steps is None).
      estimator.train(
          dataset.train_input_fn,
          steps=schedule_manager.single_iteration_train_steps,
          hooks=train_hooks)

      eval_results = estimator.evaluate(
          input_fn=dataset.eval_input_fn,
          steps=schedule_manager.single_iteration_eval_steps)

      tf.logging.info("Evaluation results (iter %d/%d):" %
                      (i + 1, schedule_manager.train_eval_iterations))
      tf.logging.info(eval_results)
      if benchmark_logger is not None:
        benchmark_logger.log_evaluation_result(eval_results)

      # The results from estimator.evaluate() are measured on an approximate
      # translation, which utilize the target golden values provided. The
      # actual bleu score must be computed using the estimator.predict() path,
      # which outputs translations that are not based on golden values. The
      # translations are compared to reference file to get the actual bleu
      # score.
      if evaluate_bleu:
        uncased_score, cased_score = evaluate_and_log_bleu(
            estimator, bleu_source, bleu_ref, vocab_file)

        stats["bleu_uncased"] = uncased_score
        stats["bleu_cased"] = cased_score

        # Write actual bleu scores using summary writer and benchmark logger
        global_step = get_global_step(estimator)
        _record_bleu_scores(bleu_writer, benchmark_logger, global_step,
                            uncased_score, cased_score)

        # Stop training if bleu stopping threshold is met.
        if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
          break
  finally:
    # Close the writer on every exit path (threshold met, iterations
    # exhausted, or an exception), not only when the threshold is reached.
    if bleu_writer is not None:
      bleu_writer.close()

  stats["eval_results"] = eval_results
  stats["train_hooks"] = train_hooks

  return stats

def define_transformer_flags():
  """Add flags and flag validators for running transformer_main."""
  # When set, overrides the max_length of the selected parameter set.
  flags.DEFINE_integer(
      name="max_length", short_name="ml", default=None,
      help=flags_core.help_wrap("Max length."))

  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base()
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=True,
      max_train_steps=False,
      dtype=True,
      all_reduce_alg=True
  )
  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  # Set flags from the flags_core module as "key flags" so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name="param_set", short_name="mp", default="big",
      enum_values=list(PARAMS_MAP),
      help=flags_core.help_wrap(
          "Parameter set to use when creating and training the model. The "
          "parameters define the input shape (batch size and max length), "
          "model configuration (size of embedding, # of hidden layers, etc.), "
          "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))

  flags.DEFINE_bool(
      name="static_batch", default=False,
      help=flags_core.help_wrap(
          "Whether the batches in the dataset should have static shapes. In "
          "general, this setting should be False. Dynamic shapes allow the "
          "inputs to be grouped so that the number of padding tokens is "
          "minimized, and helps model training. In cases where the input shape "
          "must be static (e.g. running on TPU), this setting will be ignored "
          "and static batching will always be used."))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name="train_steps", short_name="ts", default=None,
      help=flags_core.help_wrap("The number of steps used to train."))
  flags.DEFINE_integer(
      name="steps_between_evals", short_name="sbe", default=1000,
      help=flags_core.help_wrap(
          "The Number of training steps to run between evaluations. This is "
          "used if --train_steps is defined."))

  # BLEU score computation
  flags.DEFINE_string(
      name="bleu_source", short_name="bls", default=None,
      help=flags_core.help_wrap(
          "Path to source file containing text to translate when calculating "
          "the official BLEU score. Both --bleu_source and --bleu_ref must be "
          "set. Use the flag --stop_threshold to stop the script based on the "
          "uncased BLEU score."))
  flags.DEFINE_string(
      name="bleu_ref", short_name="blr", default=None,
      help=flags_core.help_wrap(
          "Path to file containing the reference translations used when "
          "calculating the official BLEU score. Both --bleu_source and "
          "--bleu_ref must be set. Use the flag --stop_threshold to stop the "
          "script based on the uncased BLEU score."))
  flags.DEFINE_string(
      name="vocab_file", short_name="vf", default=None,
      help=flags_core.help_wrap(
          "Path to subtoken vocabulary file. If data_download.py was used to "
          "download and encode the training data, look in the data_dir to find "
          "the vocab file."))

  flags_core.set_defaults(data_dir="/tmp/translate_ende",
                          model_dir="/tmp/transformer_model",
                          batch_size=None,
                          train_epochs=None)

  @flags.multi_flags_validator(
      ["train_epochs", "train_steps"],
      message="Both --train_steps and --train_epochs were set. Only one may be "
              "defined.")
  def _check_train_limits(flag_dict):
    return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref"],
      message="Both or neither --bleu_source and --bleu_ref must be defined.")
  def _check_bleu_files(flags_dict):
    return (flags_dict["bleu_source"] is None) == (
        flags_dict["bleu_ref"] is None)

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref", "vocab_file"],
      message="--vocab_file must be defined if --bleu_source and --bleu_ref "
              "are defined.")
  def _check_bleu_vocab_file(flags_dict):
    if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
      return flags_dict["vocab_file"] is not None
    return True

  @flags.multi_flags_validator(
      ["export_dir", "vocab_file"],
      message="--vocab_file must be defined if --export_dir is set.")
  def _check_export_vocab_file(flags_dict):
    if flags_dict["export_dir"]:
      return flags_dict["vocab_file"] is not None
    return True

  flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])


def construct_estimator(flags_obj, params, schedule_manager):
  """Construct an estimator from either Estimator or TPUEstimator.

  When params["use_tpu"] is false, a tf.estimator.Estimator is returned. It
  uses a distribution strategy built from the device flags. Otherwise a
  TPUEstimator configured from the TPU flags is returned.

  Args:
    flags_obj: The FLAGS object parsed from command line.
    params: A dict of run specific parameters.
    schedule_manager: A schedule.Manager object containing the run schedule.

  Returns:
    An estimator object to be used for training and eval.
  """
  if not params["use_tpu"]:
    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        all_reduce_alg=flags_obj.all_reduce_alg)
    return tf.estimator.Estimator(
        model_fn=model_fn, model_dir=flags_obj.model_dir, params=params,
        config=tf.estimator.RunConfig(train_distribute=distribution_strategy))

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu=flags_obj.tpu,
      zone=flags_obj.tpu_zone,
      project=flags_obj.tpu_gcp_project
  )

  # Each TPU loop runs one iteration's worth of training steps.
  tpu_config = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=schedule_manager.single_iteration_train_steps,
      num_shards=flags_obj.num_tpu_shards)

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=flags_obj.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config)

  return tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      # The tpu_util.LOCAL setting runs the TPUEstimator with use_tpu=False.
      use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
      train_batch_size=schedule_manager.batch_size,
      eval_batch_size=schedule_manager.batch_size,
      params={
          # TPUEstimator needs to populate batch_size itself due to sharding.
          key: value for key, value in params.items() if key != "batch_size"},
      config=run_config)

def run_transformer(flags_obj):
  """Create tf.Estimator to train and evaluate transformer model.

  Args:
    flags_obj: Object containing parsed flag values.

  Returns:
    Dict of results of the run. Contains the keys `eval_results` and
    `train_hooks`, plus `bleu_cased` and `bleu_uncased` when BLEU is evaluated.
    `train_hooks` is a list of the instances of hooks used during training.
  """
  num_gpus = flags_core.get_num_gpus(flags_obj)

  # Add flag-defined parameters to params object
  params = PARAMS_MAP[flags_obj.param_set]
  if num_gpus > 1:
    if flags_obj.param_set == "big":
      params = model_params.BIG_MULTI_GPU_PARAMS
    elif flags_obj.param_set == "base":
      params = model_params.BASE_MULTI_GPU_PARAMS
  # The selected parameter set is a shared module-level object. Copy it so the
  # run-specific assignments below do not leak into PARAMS_MAP / model_params
  # (and into any later run in the same process). dict.copy() also preserves
  # the type (and default_factory) if it is a defaultdict.
  params = params.copy()

  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["num_parallel_calls"] = flags_obj.num_parallel_calls

  params["tpu"] = flags_obj.tpu
  params["use_tpu"] = bool(flags_obj.tpu)  # was a tpu specified.
  params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
  params["allow_ffn_pad"] = not params["use_tpu"]

  params["max_length"] = flags_obj.max_length or params["max_length"]

  params["use_synthetic_data"] = flags_obj.use_synthetic_data

  # Set batch size parameter, which depends on the availability of
  # TPU and GPU, and distribution settings.
  params["batch_size"] = (flags_obj.batch_size or (
      params["default_batch_size_tpu"] if params["use_tpu"]
      else params["default_batch_size"]))

  # Keep the global batch size for the examples/sec hook before splitting it
  # across GPU replicas.
  total_batch_size = params["batch_size"]
  if not params["use_tpu"]:
    params["batch_size"] = distribution_utils.per_replica_batch_size(
        params["batch_size"], num_gpus)

  schedule_manager = schedule.Manager(
      train_steps=flags_obj.train_steps,
      steps_between_evals=flags_obj.steps_between_evals,
      train_epochs=flags_obj.train_epochs,
      epochs_between_evals=flags_obj.epochs_between_evals,
      default_train_epochs=DEFAULT_TRAIN_EPOCHS,
      batch_size=params["batch_size"],
      max_length=params["max_length"],
      use_tpu=params["use_tpu"],
      num_tpu_shards=flags_obj.num_tpu_shards
  )

  params["repeat_dataset"] = schedule_manager.repeat_dataset

  model_helpers.apply_clean(flags.FLAGS)

  # Create hooks that log information about the training and metric values
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=total_batch_size,  # for ExamplesPerSecondHook
      use_tpu=params["use_tpu"]  # Not all hooks can run with TPUs
  )
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
      run_params=params,
      test_id=flags_obj.benchmark_test_id)

  # Train and evaluate transformer model
  estimator = construct_estimator(flags_obj, params, schedule_manager)
  stats = run_loop(
      estimator=estimator,
      # Training arguments
      schedule_manager=schedule_manager,
      train_hooks=train_hooks,
      benchmark_logger=benchmark_logger,
      # BLEU calculation arguments
      bleu_source=flags_obj.bleu_source,
      bleu_ref=flags_obj.bleu_ref,
      bleu_threshold=flags_obj.stop_threshold,
      vocab_file=flags_obj.vocab_file)

  if flags_obj.export_dir and not params["use_tpu"]:
    serving_input_fn = export.build_tensor_serving_input_receiver_fn(
        shape=[None], dtype=tf.int64, batch_size=None)
    # Export saved model, and save the vocab file as an extra asset. The vocab
    # file is saved to allow consistent input encoding and output decoding.
    # (See the "Export trained model" section in the README for an example of
    # how to use the vocab file.)
    # Since the model itself does not use the vocab file, this file is saved as
    # an extra asset rather than a core asset.
    estimator.export_savedmodel(
        flags_obj.export_dir, serving_input_fn,
        assets_extra={"vocab.txt": flags_obj.vocab_file},
        strip_default_attrs=True)
  return stats


def main(_):
  """Program entry point passed to absl_app.run.

  Runs run_transformer with the parsed flags inside a benchmark logging
  context. The positional argument (the argv list from absl) is unused.
  """
  with logger.benchmark_context(flags.FLAGS):
    run_transformer(flags.FLAGS)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  define_transformer_flags()
  absl_app.run(main)