transformer_main.py 22.8 KB
Newer Older
Katherine Wu's avatar
Katherine Wu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16
17
18
19
"""Train and evaluate the Transformer model.

See README for description of setting the training schedule and evaluating the
BLEU score.
"""
Katherine Wu's avatar
Katherine Wu committed
20
21
22
23
24

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

25
import functools
Katherine Wu's avatar
Katherine Wu committed
26
27
28
29
30
import os
import tempfile

# pylint: disable=g-bad-import-order
from six.moves import xrange  # pylint: disable=redefined-builtin
31
32
from absl import app as absl_app
from absl import flags
Katherine Wu's avatar
Katherine Wu committed
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
43
from official.transformer.utils import schedule
Katherine Wu's avatar
Katherine Wu committed
44
from official.transformer.utils import tokenizer
45
from official.utils.accelerator import tpu as tpu_util
46
47
48
49
50
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers

Katherine Wu's avatar
Katherine Wu committed
51

52
# Maps each --param_set flag value to its model hyperparameter set.
PARAMS_MAP = {
    "tiny": model_params.TINY_PARAMS,
    "base": model_params.BASE_PARAMS,
    "big": model_params.BIG_PARAMS,
}

# Passed to schedule.Manager as default_train_epochs.
DEFAULT_TRAIN_EPOCHS = 10
# Subdirectory of the model directory that receives the BLEU summaries.
BLEU_DIR = "bleu"
# Effectively unbounded iteration count; used as the train/eval iteration
# limit when training is stopped by a BLEU threshold instead.
INF = 10 ** 9

# Tensors logged by the logging hooks: display name -> full graph tensor name.
# Both tensors are created with tf.identity elsewhere in this file.
TENSORS_TO_LOG = {
    "learning_rate": "model/get_train_op/learning_rate/learning_rate",
    "cross_entropy_loss": "model/cross_entropy",
}

Katherine Wu's avatar
Katherine Wu committed
67
68
69
70
71
72
73
74
75

def model_fn(features, labels, mode, params):
  """Defines how to train, evaluate and predict from the transformer model.

  Args:
    features: Input sequences fed to the Transformer.
    labels: Target sequences; None in PREDICT mode.
    mode: A tf.estimator.ModeKeys value.
    params: Dict of hyperparameters and run settings (reads "use_tpu",
      "label_smoothing", "vocab_size" and "model_dir" here).

  Returns:
    A tf.estimator.EstimatorSpec. In EVAL and TRAIN modes with
    params["use_tpu"] set, a tf.contrib.tpu.TPUEstimatorSpec instead.

  Raises:
    NotImplementedError: if PREDICT mode is requested while using a TPU.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  with tf.variable_scope("model"):
    inputs, targets = features, labels

    # Build the network and compute the output logits.
    model = transformer.Transformer(params, is_training)
    logits = model(inputs, targets)

    # In prediction mode there are no targets; the model output itself is
    # the prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:
      if params["use_tpu"]:
        raise NotImplementedError("Prediction is not yet supported on TPUs.")
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT, predictions=logits)

    # XLA (TPU) needs a concrete logits shape because the logits are shipped
    # back to the host CPU for metric evaluation, and [?, ?, vocab_size] is
    # too vague. The first two logits dimensions are those of the targets.
    # The loss does not need this: padded_cross_entropy_loss resolves the
    # shape on the TPU.
    logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])

    # xentropy holds the cross entropy of every non-padding target token.
    # The loss is its sum divided by the sum of the returned weights.
    xentropy, weights = metrics.padded_cross_entropy_loss(
        logits, targets, params["label_smoothing"], params["vocab_size"])
    loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

    # Expose the loss under a fixed name so the logging hook can find it.
    tf.identity(loss, "cross_entropy")

    if mode == tf.estimator.ModeKeys.EVAL:
      if not params["use_tpu"]:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, predictions={"predictions": logits},
            eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
      # TPU host-call functions may only take tensors as arguments, so params
      # is bound ahead of time to keep metric_fn TPUEstimator compliant.
      metric_fn = functools.partial(metrics.get_eval_metrics, params=params)
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, predictions={"predictions": logits},
          eval_metrics=(metric_fn, [logits, labels]))

    # Training path.
    train_op, metric_dict = get_train_op_and_metrics(loss, params)
    # Epochs can be quite long; the per-minibatch loss gives some intermediate
    # information in TensorBoard.
    metric_dict["minibatch_loss"] = loss
    if params["use_tpu"]:
      scalar_host_call = tpu_util.construct_scalar_host_call(
          metric_dict=metric_dict, model_dir=params["model_dir"],
          prefix="training/")
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op, host_call=scalar_host_call)
    record_scalars(metric_dict)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


136
137
138
139
140
def record_scalars(metric_dict):
  """Emit one tf.contrib.summary scalar per entry of metric_dict.

  Args:
    metric_dict: Mapping from summary name to the scalar tensor to record.
  """
  for summary_name, scalar in metric_dict.items():
    tf.contrib.summary.scalar(name=summary_name, tensor=scalar)


Katherine Wu's avatar
Katherine Wu committed
141
142
143
144
145
146
147
148
149
150
151
152
def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
  """Calculate learning rate with linear warmup and rsqrt decay.

  The schedule is
    learning_rate * hidden_size**-0.5 * min(1, step / warmup)
      * rsqrt(max(step, warmup))
  where step is the global step.

  Args:
    learning_rate: Base learning rate.
    hidden_size: Model hidden size; the rate is scaled by its inverse sqrt.
    learning_rate_warmup_steps: Number of linear warmup steps.

  Returns:
    Scalar learning-rate tensor, also exposed under the name
    model/get_train_op/learning_rate/learning_rate when called from
    get_train_op_and_metrics.
  """
  with tf.name_scope("learning_rate"):
    warmup = tf.to_float(learning_rate_warmup_steps)
    step = tf.to_float(tf.train.get_or_create_global_step())

    scaled_lr = learning_rate * hidden_size ** -0.5
    # Linear warmup: the factor grows from 0 to 1 over the warmup steps.
    warmed_lr = scaled_lr * tf.minimum(1.0, step / warmup)
    # Inverse square-root decay; max() holds it constant during warmup.
    learning_rate = warmed_lr * tf.rsqrt(tf.maximum(step, warmup))

    # Named tensor for the logging hook. Its full name includes the variable
    # and name scopes, i.e. model/get_train_op/learning_rate/learning_rate.
    tf.identity(learning_rate, "learning_rate")

    return learning_rate


161
162
def get_train_op_and_metrics(loss, params):
  """Generate training op and metrics to save in TensorBoard.

  Args:
    loss: Scalar loss tensor to minimize.
    params: Dict providing the learning-rate and Adam settings, plus "use_tpu"
      and "tpu".

  Returns:
    A (train_op, metrics) tuple. metrics maps summary names to scalar tensors.
    It always contains "learning_rate" and, when not on TPU, also
    "global_norm/gradient_norm".
  """
  with tf.variable_scope("get_train_op"):
    lr = get_learning_rate(
        learning_rate=params["learning_rate"],
        hidden_size=params["hidden_size"],
        learning_rate_warmup_steps=params["learning_rate_warmup_steps"])

    # LazyAdamOptimizer from TF contrib is faster than the core TF Adam.
    optimizer = tf.contrib.opt.LazyAdamOptimizer(
        lr,
        beta1=params["optimizer_adam_beta1"],
        beta2=params["optimizer_adam_beta2"],
        epsilon=params["optimizer_adam_epsilon"])

    on_tpu = params["use_tpu"]
    if on_tpu and params["tpu"] != tpu_util.LOCAL:
      # Non-local TPU runs wrap the optimizer in CrossShardOptimizer.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Compute and apply gradients over all trainable variables.
    step_var = tf.train.get_global_step()
    grads_and_vars = optimizer.compute_gradients(
        loss, tf.trainable_variables(), colocate_gradients_with_ops=True)
    train_op = optimizer.apply_gradients(
        grads_and_vars, global_step=step_var, name="train")

    summaries = {"learning_rate": lr}
    if not on_tpu:
      # The gradient norm is left out on TPU because it can cause instability
      # between the TPU and the host controller.
      grads_only = [grad for grad, _ in grads_and_vars]
      summaries["global_norm/gradient_norm"] = tf.global_norm(grads_only)

    return train_op, summaries
Katherine Wu's avatar
Katherine Wu committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220


def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
  """Translate file and report the cased and uncased bleu scores.

  Args:
    estimator: Estimator used by translate.translate_file to translate.
    subtokenizer: Subtokenizer used to encode/decode the text.
    bleu_source: Path of the file to translate.
    bleu_ref: Path of the reference translations to score against.

  Returns:
    Tuple (uncased_score, cased_score) as returned by
    compute_bleu.bleu_wrapper.
  """
  # Reserve a temporary file for the translations. Only the path is needed,
  # because translate_file writes to it by name. Close our own handle
  # immediately so it does not leak; some platforms also refuse to reopen a
  # file that is still held open.
  tmp = tempfile.NamedTemporaryFile(delete=False)
  tmp_filename = tmp.name
  tmp.close()

  try:
    translate.translate_file(
        estimator, subtokenizer, bleu_source, output_file=tmp_filename,
        print_all_translations=False)

    # Compute uncased and cased bleu scores.
    uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
    cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
  finally:
    # Remove the temporary file even if translation or scoring fails.
    os.remove(tmp_filename)
  return uncased_score, cased_score


def get_global_step(estimator):
  """Return the global step of the estimator's latest checkpoint.

  The step is parsed as the integer after the last "-" in the checkpoint
  path. This assumes a name like "model.ckpt-<step>"; it raises ValueError if
  that suffix is not an integer.
  """
  checkpoint_path = estimator.latest_checkpoint()
  _, _, step_text = checkpoint_path.rpartition("-")
  return int(step_text)


221
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
  """Calculate and record the BLEU score.

  Args:
    estimator: Estimator used to translate bleu_source.
    bleu_source: Path of the file to translate.
    bleu_ref: Path of the reference translations.
    vocab_file_path: Path of the vocabulary file for the Subtokenizer.

  Returns:
    Tuple (uncased_score, cased_score).
  """
  subtokenizer = tokenizer.Subtokenizer(vocab_file_path)

  uncased_score, cased_score = translate_and_compute_bleu(
      estimator, subtokenizer, bleu_source, bleu_ref)

  # tf.logging formats lazily as msg % args, so each score needs a
  # placeholder. Without one, formatting fails and the score is never logged.
  tf.logging.info("Bleu score (uncased): %s", uncased_score)
  tf.logging.info("Bleu score (cased): %s", cased_score)
  return uncased_score, cased_score


233
234
def run_loop(
    estimator, schedule_manager, train_hooks=None, benchmark_logger=None,
    bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
  """Train and evaluate model, and optionally compute model's BLEU score.

  **Step vs. Epoch vs. Iteration**

  Steps and epochs are canonical terms used in TensorFlow and general machine
  learning. They describe running a single process (train/eval):
    - Step refers to running the process through a single or batch of examples.
    - Epoch refers to running the process through an entire dataset.

  E.g. training a dataset with 100 examples. The dataset is
  divided into 20 batches with 5 examples per batch. A single training step
  trains the model on one batch. After 20 training steps, the model will have
  trained on every batch in the dataset, or, in other words, one epoch.

  Meanwhile, iteration is used in this implementation to describe running
  multiple processes (training and eval).
    - A single iteration:
      1. trains the model for a specific number of steps or epochs.
      2. evaluates the model.
      3. (if source and ref files are provided) compute BLEU score.

  This function runs through multiple train+eval+bleu iterations.

  Args:
    estimator: tf.Estimator containing model to train.
    schedule_manager: A schedule.Manager object to guide the run loop.
    train_hooks: List of hooks to pass to the estimator during training.
    benchmark_logger: a BenchmarkLogger object that logs evaluation data.
      May be None, in which case nothing is sent to a benchmark logger.
    bleu_source: File containing text to be translated for BLEU calculation.
    bleu_ref: File containing reference translations for BLEU calculation.
    bleu_threshold: minimum BLEU score before training is stopped.
    vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.

  Raises:
    ValueError: if BLEU evaluation is requested while training on a TPU.
  """
  evaluate_bleu = bleu_source is not None and bleu_ref is not None
  if evaluate_bleu and schedule_manager.use_tpu:
    raise ValueError("BLEU score can not be computed when training with a TPU, "
                     "as it requires estimator.predict which is not yet "
                     "supported.")

  # Print details of training schedule.
  tf.logging.info("Training schedule:")
  tf.logging.info(
      "\t1. Train for {}".format(schedule_manager.train_increment_str))
  tf.logging.info("\t2. Evaluate model.")
  if evaluate_bleu:
    tf.logging.info("\t3. Compute BLEU score.")
    if bleu_threshold is not None:
      tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
                      bleu_threshold)
  if not evaluate_bleu or bleu_threshold is None:
    tf.logging.info("Repeat above steps %d times." %
                    schedule_manager.train_eval_iterations)

  bleu_writer = None
  if evaluate_bleu:
    # Create summary writer to log bleu score (values can be displayed in
    # Tensorboard).
    bleu_writer = tf.summary.FileWriter(
        os.path.join(estimator.model_dir, BLEU_DIR))
    if bleu_threshold is not None:
      # Change loop stopping condition if bleu_threshold is defined.
      schedule_manager.train_eval_iterations = INF

  try:
    # Loop training/evaluation/bleu cycles
    for i in xrange(schedule_manager.train_eval_iterations):
      tf.logging.info("Starting iteration %d" % (i + 1))

      # Train the model for single_iteration_train_steps or until the input
      # fn runs out of examples (if single_iteration_train_steps is None).
      estimator.train(
          dataset.train_input_fn,
          steps=schedule_manager.single_iteration_train_steps,
          hooks=train_hooks)

      eval_results = estimator.evaluate(
          input_fn=dataset.eval_input_fn,
          steps=schedule_manager.single_iteration_eval_steps)

      tf.logging.info("Evaluation results (iter %d/%d):" %
                      (i + 1, schedule_manager.train_eval_iterations))
      tf.logging.info(eval_results)
      if benchmark_logger is not None:
        benchmark_logger.log_evaluation_result(eval_results)

      # The results from estimator.evaluate() are measured on an approximate
      # translation, which uses the target golden values provided. The actual
      # bleu score must be computed using the estimator.predict() path, which
      # outputs translations that are not based on golden values. Those
      # translations are compared to the reference file to get the actual
      # bleu score.
      if evaluate_bleu:
        uncased_score, cased_score = evaluate_and_log_bleu(
            estimator, bleu_source, bleu_ref, vocab_file_path)

        # Write actual bleu scores using summary writer and benchmark logger
        global_step = get_global_step(estimator)
        summary = tf.Summary(value=[
            tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
            tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
        ])
        bleu_writer.add_summary(summary, global_step)
        bleu_writer.flush()
        if benchmark_logger is not None:
          benchmark_logger.log_metric(
              "bleu_uncased", uncased_score, global_step=global_step)
          benchmark_logger.log_metric(
              "bleu_cased", cased_score, global_step=global_step)

        # Stop training if bleu stopping threshold is met.
        if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
          break
  finally:
    # Close the writer on every exit path: threshold reached, iterations
    # exhausted, or an exception. Previously it was closed only when the
    # threshold was met.
    if bleu_writer is not None:
      bleu_writer.close()


347
348
349
350
351
352
353
354
355
356
357
358
359
def define_transformer_flags():
  """Add flags and flag validators for running transformer_main.

  This function:
    - defines the shared flags from flags_core (base, performance, benchmark
      and TPU device flags);
    - adds the transformer-specific flags;
    - overrides some defaults;
    - registers validators that reject setting both --train_epochs and
      --train_steps, and reject missing BLEU/vocab files.
  """
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=False
  )
  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  # Set flags from the flags_core module as "key flags" so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name="param_set", short_name="mp", default="big",
      enum_values=["base", "big", "tiny"],
      help=flags_core.help_wrap(
          "Parameter set to use when creating and training the model. The "
          "parameters define the input shape (batch size and max length), "
          "model configuration (size of embedding, # of hidden layers, etc.), "
          "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))

  flags.DEFINE_bool(
      name="static_batch", default=False,
      help=flags_core.help_wrap(
          "Whether the batches in the dataset should have static shapes. In "
          "general, this setting should be False. Dynamic shapes allow the "
          "inputs to be grouped so that the number of padding tokens is "
          "minimized, and helps model training. In cases where the input shape "
          "must be static (e.g. running on TPU), this setting will be ignored "
          "and static batching will always be used."))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name="train_steps", short_name="ts", default=None,
      help=flags_core.help_wrap("The number of steps used to train."))
  flags.DEFINE_integer(
      name="steps_between_evals", short_name="sbe", default=1000,
      help=flags_core.help_wrap(
          "The number of training steps to run between evaluations. This is "
          "used if --train_steps is defined."))

  # BLEU score computation
  flags.DEFINE_string(
      name="bleu_source", short_name="bls", default=None,
      help=flags_core.help_wrap(
          "Path to source file containing text to translate when calculating "
          "the official BLEU score. --bleu_source, --bleu_ref, and "
          "--vocab_file must be set. Use the flag --stop_threshold to stop the "
          "script based on the uncased BLEU score."))
  flags.DEFINE_string(
      name="bleu_ref", short_name="blr", default=None,
      help=flags_core.help_wrap(
          "Path to reference file containing the expected translations of "
          "--bleu_source, used when calculating the official BLEU score. "
          "--bleu_source, --bleu_ref, and --vocab_file must be set. Use the "
          "flag --stop_threshold to stop the script based on the uncased BLEU "
          "score."))
  flags.DEFINE_string(
      name="vocab_file", short_name="vf", default=VOCAB_FILE,
      help=flags_core.help_wrap(
          "Name of vocabulary file containing subtokens for subtokenizing the "
          "bleu_source file. This file is expected to be in the directory "
          "defined by --data_dir."))

  flags_core.set_defaults(data_dir="/tmp/translate_ende",
                          model_dir="/tmp/transformer_model",
                          batch_size=None,
                          train_epochs=None)

  @flags.multi_flags_validator(
      ["train_epochs", "train_steps"],
      message="Both --train_steps and --train_epochs were set. Only one may be "
              "defined.")
  def _check_train_limits(flag_dict):
    return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

  @flags.multi_flags_validator(
      ["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
      message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
              "Please ensure that the file paths are correct.")
  def _check_bleu_files(flags_dict):
    """Validate files when bleu_source and bleu_ref are defined."""
    if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
      return True
    # Ensure that bleu_source, bleu_ref, and vocab files exist.
    vocab_file_path = os.path.join(
        flags_dict["data_dir"], flags_dict["vocab_file"])
    return all([
        tf.gfile.Exists(flags_dict["bleu_source"]),
        tf.gfile.Exists(flags_dict["bleu_ref"]),
        tf.gfile.Exists(vocab_file_path)])

  flags_core.require_cloud_storage(["data_dir", "model_dir"])


def construct_estimator(flags_obj, params, schedule_manager):
  """Construct an estimator from either Estimator or TPUEstimator.

  Args:
    flags_obj: The FLAGS object parsed from command line.
    params: A dict of run specific parameters.
    schedule_manager: A schedule.Manager object containing the run schedule.

  Returns:
    A tf.estimator.Estimator when params["use_tpu"] is falsy, otherwise a
    tf.contrib.tpu.TPUEstimator, to be used for training and eval.
  """
  if params["use_tpu"]:
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=flags_obj.tpu,
        zone=flags_obj.tpu_zone,
        project=flags_obj.tpu_gcp_project)

    # One TPU loop runs a full training iteration's worth of steps.
    loop_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=schedule_manager.single_iteration_train_steps,
        num_shards=flags_obj.num_tpu_shards)

    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=True)
    tpu_run_config = tf.contrib.tpu.RunConfig(
        cluster=resolver,
        model_dir=flags_obj.model_dir,
        session_config=session_config,
        tpu_config=loop_config)

    batch_size = schedule_manager.batch_size
    # TPUEstimator populates batch_size itself due to sharding, so the key is
    # dropped from the params it receives.
    tpu_params = {k: v for k, v in params.items() if k != "batch_size"}
    return tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        params=tpu_params,
        config=tpu_run_config)

  return tf.estimator.Estimator(
      model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)

494
495
496
497
498
499
500

def run_transformer(flags_obj):
  """Create tf.Estimator to train and evaluate transformer model.

  Args:
    flags_obj: Object containing parsed flag values.
  """
  # Start from a copy of the selected parameter set. Without the copy, the
  # assignments below would write run-specific settings into the shared
  # module-level model_params dicts and leak them into later runs in the same
  # process. .copy() keeps the mapping type, e.g. if the parameter set is a
  # collections.defaultdict.
  params = PARAMS_MAP[flags_obj.param_set].copy()

  # Add flag-defined parameters to params object
  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["num_parallel_calls"] = flags_obj.num_parallel_calls

  params["tpu"] = flags_obj.tpu
  params["use_tpu"] = bool(flags_obj.tpu)  # was a tpu specified.
  params["batch_size"] = flags_obj.batch_size or (
      params["default_batch_size_tpu"] if params["use_tpu"]
      else params["default_batch_size"])
  # TPUs require static shapes, so static batching is forced on TPU.
  params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
  params["allow_ffn_pad"] = not params["use_tpu"]

  schedule_manager = schedule.Manager(
      train_steps=flags_obj.train_steps,
      steps_between_evals=flags_obj.steps_between_evals,
      train_epochs=flags_obj.train_epochs,
      epochs_between_evals=flags_obj.epochs_between_evals,
      default_train_epochs=DEFAULT_TRAIN_EPOCHS,
      batch_size=params["batch_size"],
      max_length=params["max_length"],
      use_tpu=params["use_tpu"],
      num_tpu_shards=flags_obj.num_tpu_shards
  )

  params["repeat_dataset"] = schedule_manager.repeat_dataset

  # Create hooks that log information about the training and metric values
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=schedule_manager.batch_size,  # for ExamplesPerSecondHook
      use_tpu=params["use_tpu"]  # Not all hooks can run with TPUs
  )
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
      run_params=params,
      test_id=flags_obj.benchmark_test_id)

  # Train and evaluate transformer model
  estimator = construct_estimator(flags_obj, params, schedule_manager)
  run_loop(
      estimator=estimator,
      # Training arguments
      schedule_manager=schedule_manager,
      train_hooks=train_hooks,
      benchmark_logger=benchmark_logger,
      # BLEU calculation arguments
      bleu_source=flags_obj.bleu_source,
      bleu_ref=flags_obj.bleu_ref,
      bleu_threshold=flags_obj.stop_threshold,
      vocab_file_path=os.path.join(flags_obj.data_dir, flags_obj.vocab_file))
Katherine Wu's avatar
Katherine Wu committed
556
557


558
def main(_):
  """absl entry point: run the transformer inside a benchmark-logging context.

  The positional argument (remaining argv from absl) is ignored.
  """
  parsed_flags = flags.FLAGS
  with logger.benchmark_context(parsed_flags):
    run_transformer(parsed_flags)
Katherine Wu's avatar
Katherine Wu committed
561
562


563
564
565
566
if __name__ == "__main__":
  # Log at INFO level so the tf.logging.info progress messages are shown.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Flags must be defined before absl_app.run parses the command line.
  define_transformer_flags()
  absl_app.run(main)