transformer_main.py 19.1 KB
Newer Older
Katherine Wu's avatar
Katherine Wu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16
17
18
19
"""Train and evaluate the Transformer model.

See README for description of setting the training schedule and evaluating the
BLEU score.
"""
Katherine Wu's avatar
Katherine Wu committed
20
21
22
23
24
25
26
27
28
29

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

# pylint: disable=g-bad-import-order
from six.moves import xrange  # pylint: disable=redefined-builtin
30
31
from absl import app as absl_app
from absl import flags
Katherine Wu's avatar
Katherine Wu committed
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
from official.transformer.utils import tokenizer
43
44
45
46
47
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers

Katherine Wu's avatar
Katherine Wu committed
48

49
50
51
52
# Hyperparameter set selected by the --param_set flag value.
PARAMS_MAP = {
    "base": model_params.TransformerBaseParams,
    "big": model_params.TransformerBigParams,
}

# Epochs trained when --train_steps is unset and --train_epochs is unset/0.
DEFAULT_TRAIN_EPOCHS = 10
# Subdirectory of the model directory that receives BLEU summaries.
BLEU_DIR = "bleu"
# Iteration count used as "unbounded" when training until a BLEU threshold.
INF = 10 ** 9

# Tensors printed by the logging hooks: display name -> fully-scoped graph name
# of a tf.identity op created during model construction.
TENSORS_TO_LOG = {
    "learning_rate": "model/get_train_op/learning_rate/learning_rate",
    "cross_entropy_loss": "model/cross_entropy",
}

Katherine Wu's avatar
Katherine Wu committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

def model_fn(features, labels, mode, params):
  """Builds the EstimatorSpec for training, evaluation or prediction.

  Args:
    features: Input (source) tensor(s) passed to the Transformer.
    labels: Target tensor(s); None in PREDICT mode.
    mode: A tf.estimator.ModeKeys value.
    params: Hyperparameter object (reads label_smoothing, vocab_size, and the
      optimizer settings used by get_train_op).

  Returns:
    A tf.estimator.EstimatorSpec for `mode`.
  """
  with tf.variable_scope("model"):
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    transformer_model = transformer.Transformer(params, is_training)

    # Without targets (PREDICT) the model call yields predictions; otherwise
    # it yields logits.
    model_output = transformer_model(features, labels)

    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT, predictions=model_output)

    # Per-token cross entropy of the non-padding targets, plus the weights
    # used to take their weighted mean.
    token_xentropy, token_weights = metrics.padded_cross_entropy_loss(
        model_output, labels, params.label_smoothing, params.vocab_size)
    loss = tf.reduce_sum(token_xentropy) / tf.reduce_sum(token_weights)

    # Fixed graph name ("model/cross_entropy") so the logging hook finds it.
    tf.identity(loss, "cross_entropy")

    if mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss, predictions={"predictions": model_output},
          eval_metric_ops=metrics.get_eval_metrics(
              model_output, labels, params))

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=get_train_op(loss, params))


def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
  """Returns the learning rate for the current global step.

  rate = learning_rate * hidden_size**-0.5
         * min(1, step / warmup_steps)          (linear warmup)
         * rsqrt(max(step, warmup_steps))        (rsqrt decay)

  Args:
    learning_rate: Base learning rate.
    hidden_size: Model hidden size used to scale the base rate.
    learning_rate_warmup_steps: Number of linear warmup steps.

  Returns:
    A float tensor holding the scheduled learning rate.
  """
  with tf.name_scope("learning_rate"):
    warmup_steps = tf.to_float(learning_rate_warmup_steps)
    step = tf.to_float(tf.train.get_or_create_global_step())

    scaled_rate = learning_rate * hidden_size ** -0.5
    warmup_factor = tf.minimum(1.0, step / warmup_steps)
    decay_factor = tf.rsqrt(tf.maximum(step, warmup_steps))
    learning_rate = scaled_rate * warmup_factor * decay_factor

    # Named tensor for the logging hook. Including the enclosing variable and
    # name scopes, the full name is model/get_train_op/learning_rate/learning_rate
    tf.identity(learning_rate, "learning_rate")
    # Also export the value as a TensorBoard summary.
    tf.summary.scalar("learning_rate", learning_rate)

  return learning_rate


def get_train_op(loss, params):
  """Builds the op that applies one optimizer update for `loss`.

  Uses LazyAdamOptimizer (TF contrib) driven by get_learning_rate, advances the
  global step, and records the global gradient norm as a summary.

  Args:
    loss: Scalar loss tensor to minimize.
    params: Hyperparameters providing learning_rate, hidden_size,
      learning_rate_warmup_steps and the optimizer_adam_* settings.

  Returns:
    The training op.
  """
  with tf.variable_scope("get_train_op"):
    # LazyAdamOptimizer is used because it is faster than core Adam here.
    optimizer = tf.contrib.opt.LazyAdamOptimizer(
        get_learning_rate(params.learning_rate, params.hidden_size,
                          params.learning_rate_warmup_steps),
        beta1=params.optimizer_adam_beta1,
        beta2=params.optimizer_adam_beta2,
        epsilon=params.optimizer_adam_epsilon)

    global_step = tf.train.get_global_step()
    grads_and_vars = optimizer.compute_gradients(
        loss, tf.trainable_variables(), colocate_gradients_with_ops=True)
    train_op = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name="train")

    # Global norm over all gradients, for TensorBoard.
    grads = [grad for grad, _ in grads_and_vars]
    tf.summary.scalar("global_norm/gradient_norm", tf.global_norm(grads))

  return train_op


def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
  """Translate file and report the cased and uncased bleu scores.

  Args:
    estimator: tf.Estimator used to translate `bleu_source`.
    subtokenizer: Subtokenizer passed through to translate.translate_file.
    bleu_source: Path of the file containing text to translate.
    bleu_ref: Path of the file containing reference translations.

  Returns:
    Tuple (uncased_score, cased_score) from compute_bleu.bleu_wrapper.
  """
  # Reserve a unique temporary path for the translations. The handle is closed
  # immediately because translate_file only receives the path. Keeping the
  # handle open would leak a file descriptor and, on Windows, block reopening.
  with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp_filename = tmp.name

  try:
    translate.translate_file(
        estimator, subtokenizer, bleu_source, output_file=tmp_filename,
        print_all_translations=False)

    # Compute uncased and cased bleu scores.
    uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
    cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
  finally:
    # Remove the temporary file even if translation or scoring fails.
    os.remove(tmp_filename)
  return uncased_score, cased_score


def get_global_step(estimator):
  """Return the step number of the estimator's latest checkpoint.

  The step is parsed as the integer after the last hyphen in the checkpoint
  path, e.g. ".../model.ckpt-1234" -> 1234.
  """
  checkpoint_path = estimator.latest_checkpoint()
  _, _, step_suffix = checkpoint_path.rpartition("-")
  return int(step_suffix)


177
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
  """Calculate and record the BLEU score.

  Args:
    estimator: tf.Estimator used to translate `bleu_source`.
    bleu_source: Path of the file containing text to translate.
    bleu_ref: Path of the file containing reference translations.
    vocab_file_path: Vocabulary file used to build the Subtokenizer.

  Returns:
    Tuple (uncased_score, cased_score).
  """
  subtokenizer = tokenizer.Subtokenizer(vocab_file_path)

  uncased_score, cased_score = translate_and_compute_bleu(
      estimator, subtokenizer, bleu_source, bleu_ref)

  # tf.logging.info %-formats its message with the extra args, so the message
  # needs a placeholder. Without one, logging reports a formatting error
  # instead of printing the score.
  tf.logging.info("Bleu score (uncased): %s", uncased_score)
  tf.logging.info("Bleu score (cased): %s", cased_score)
  return uncased_score, cased_score


def train_schedule(
    estimator, train_eval_iterations, single_iteration_train_steps=None,
    single_iteration_train_epochs=None, train_hooks=None, benchmark_logger=None,
    bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
  """Train and evaluate model, and optionally compute model's BLEU score.

  **Step vs. Epoch vs. Iteration**

  Steps and epochs are canonical terms used in TensorFlow and general machine
  learning. They are used to describe running a single process (train/eval):
    - Step refers to running the process through a single or batch of examples.
    - Epoch refers to running the process through an entire dataset.

  E.g. training a dataset with 100 examples. The dataset is
  divided into 20 batches with 5 examples per batch. A single training step
  trains the model on one batch. After 20 training steps, the model will have
  trained on every batch in the dataset, or, in other words, one epoch.

  Meanwhile, iteration is used in this implementation to describe running
  multiple processes (training and eval).
    - A single iteration:
      1. trains the model for a specific number of steps or epochs.
      2. evaluates the model.
      3. (if source and ref files are provided) compute BLEU score.

  This function runs through multiple train+eval+bleu iterations.

  Args:
    estimator: tf.Estimator containing model to train.
    train_eval_iterations: Number of times to repeat the train+eval iteration.
      Ignored (treated as unbounded) when BLEU is evaluated and
      bleu_threshold is set.
    single_iteration_train_steps: Number of steps to train in one iteration.
    single_iteration_train_epochs: Number of epochs to train in one iteration.
    train_hooks: List of hooks to pass to the estimator during training.
    benchmark_logger: Optional BenchmarkLogger that logs evaluation results and
      BLEU scores. Nothing is sent to it when None.
    bleu_source: File containing text to be translated for BLEU calculation.
    bleu_ref: File containing reference translations for BLEU calculation.
    bleu_threshold: minimum BLEU score before training is stopped.
    vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.

  Raises:
    ValueError: if both or none of single_iteration_train_steps and
      single_iteration_train_epochs were defined.
  """
  # Ensure that exactly one of single_iteration_train_steps and
  # single_iteration_train_epochs is defined.
  if single_iteration_train_steps is None:
    if single_iteration_train_epochs is None:
      raise ValueError(
          "Exactly one of single_iteration_train_steps or "
          "single_iteration_train_epochs must be defined. Both were none.")
  else:
    if single_iteration_train_epochs is not None:
      raise ValueError(
          "Exactly one of single_iteration_train_steps or "
          "single_iteration_train_epochs must be defined. Both were defined.")

  evaluate_bleu = bleu_source is not None and bleu_ref is not None

  # Print details of training schedule.
  tf.logging.info("Training schedule:")
  if single_iteration_train_epochs is not None:
    tf.logging.info("\t1. Train for %d epochs." % single_iteration_train_epochs)
  else:
    tf.logging.info("\t1. Train for %d steps." % single_iteration_train_steps)
  tf.logging.info("\t2. Evaluate model.")
  if evaluate_bleu:
    tf.logging.info("\t3. Compute BLEU score.")
    if bleu_threshold is not None:
      tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
                      bleu_threshold)
  if not evaluate_bleu or bleu_threshold is None:
    tf.logging.info("Repeat above steps %d times." % train_eval_iterations)

  bleu_writer = None
  if evaluate_bleu:
    # Create summary writer to log bleu score (values can be displayed in
    # Tensorboard).
    bleu_writer = tf.summary.FileWriter(
        os.path.join(estimator.model_dir, BLEU_DIR))
    if bleu_threshold is not None:
      # Change loop stopping condition if bleu_threshold is defined.
      train_eval_iterations = INF

  try:
    # Loop training/evaluation/bleu cycles
    for i in xrange(train_eval_iterations):
      tf.logging.info("Starting iteration %d" % (i + 1))

      # Train the model for single_iteration_train_steps or until the input fn
      # runs out of examples (if single_iteration_train_steps is None).
      estimator.train(
          dataset.train_input_fn, steps=single_iteration_train_steps,
          hooks=train_hooks)

      eval_results = estimator.evaluate(dataset.eval_input_fn)
      tf.logging.info("Evaluation results (iter %d/%d):" %
                      (i + 1, train_eval_iterations))
      tf.logging.info(eval_results)
      if benchmark_logger is not None:
        benchmark_logger.log_evaluation_result(eval_results)

      # estimator.evaluate() measures an approximate translation that uses the
      # golden target values. The real BLEU score comes from the
      # estimator.predict() path, whose translations do not see the targets,
      # compared against the reference file.
      if evaluate_bleu:
        uncased_score, cased_score = evaluate_and_log_bleu(
            estimator, bleu_source, bleu_ref, vocab_file_path)

        # Write actual bleu scores using summary writer and benchmark logger
        global_step = get_global_step(estimator)
        summary = tf.Summary(value=[
            tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
            tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
        ])
        bleu_writer.add_summary(summary, global_step)
        bleu_writer.flush()
        if benchmark_logger is not None:
          benchmark_logger.log_metric(
              "bleu_uncased", uncased_score, global_step=global_step)
          benchmark_logger.log_metric(
              "bleu_cased", cased_score, global_step=global_step)

        # Stop training if bleu stopping threshold is met.
        if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
          break
  finally:
    # Close the BLEU writer on every exit path. This covers reaching the
    # threshold, finishing all iterations (e.g. no threshold given), and an
    # exception.
    if bleu_writer is not None:
      bleu_writer.close()


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
def define_transformer_flags():
  """Add flags and flag validators for running transformer_main.

  Defines the common flags from flags_core (base, performance, benchmark) and
  the transformer-specific flags (--param_set, --train_steps,
  --steps_between_evals, --bleu_source, --bleu_ref, --vocab_file). It also
  overrides several flags_core defaults and registers validators for
  conflicting training limits and missing BLEU/vocab files.
  """
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=False
  )
  flags_core.define_benchmark()

  # Set flags from the flags_core module as "key flags" so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name="param_set", short_name="mp", default="big",
      enum_values=["base", "big"],
      help=flags_core.help_wrap(
          "Parameter set to use when creating and training the model. The "
          "parameters define the input shape (batch size and max length), "
          "model configuration (size of embedding, # of hidden layers, etc.), "
          "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name="train_steps", short_name="ts", default=None,
      help=flags_core.help_wrap("The number of steps used to train."))
  flags.DEFINE_integer(
      name="steps_between_evals", short_name="sbe", default=1000,
      help=flags_core.help_wrap(
          "The Number of training steps to run between evaluations. This is "
          "used if --train_steps is defined."))

  # BLEU score computation
  flags.DEFINE_string(
      name="bleu_source", short_name="bls", default=None,
      help=flags_core.help_wrap(
          "Path to source file containing text to translate when calculating "
          "the official BLEU score. --bleu_source, --bleu_ref, and "
          "--vocab_file must be set. Use the flag --stop_threshold to stop the "
          "script based on the uncased BLEU score."))
  flags.DEFINE_string(
      name="bleu_ref", short_name="blr", default=None,
      help=flags_core.help_wrap(
          "Path to file containing the reference translation for calculating "
          "the official BLEU score. --bleu_source, --bleu_ref, and "
          "--vocab_file must be set. Use the flag --stop_threshold to stop the "
          "script based on the uncased BLEU score."))
  flags.DEFINE_string(
      name="vocab_file", short_name="vf", default=VOCAB_FILE,
      help=flags_core.help_wrap(
          "Name of vocabulary file containing subtokens for subtokenizing the "
          "bleu_source file. This file is expected to be in the directory "
          "defined by --data_dir."))

  flags_core.set_defaults(data_dir="/tmp/translate_ende",
                          model_dir="/tmp/transformer_model",
                          batch_size=None,
                          train_epochs=None)

  @flags.multi_flags_validator(
      ["train_epochs", "train_steps"],
      message="Both --train_steps and --train_epochs were set. Only one may be "
              "defined.")
  def _check_train_limits(flag_dict):
    """Allow at most one of --train_epochs and --train_steps."""
    return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

  @flags.multi_flags_validator(
      ["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
      message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
              "Please ensure that the file paths are correct.")
  def _check_bleu_files(flags_dict):
    """Validate files when bleu_source and bleu_ref are defined."""
    if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
      return True
    # Ensure that bleu_source, bleu_ref, and vocab files exist. The vocab file
    # is resolved relative to --data_dir.
    vocab_file_path = os.path.join(
        flags_dict["data_dir"], flags_dict["vocab_file"])
    return all([
        tf.gfile.Exists(flags_dict["bleu_source"]),
        tf.gfile.Exists(flags_dict["bleu_ref"]),
        tf.gfile.Exists(vocab_file_path)])


def _split_schedule(total, per_iteration):
  """Split a training budget into (num_iterations, units_per_iteration).

  `total` and `per_iteration` are both steps or both epochs. The per-iteration
  amount is capped at `total`, so a budget smaller than `per_iteration` still
  yields one iteration instead of zero. Units left over after the last whole
  iteration are not trained.
  """
  if total <= 0:
    return 0, per_iteration
  units = min(per_iteration, total)
  return total // units, units


def run_transformer(flags_obj):
  """Create tf.Estimator to train and evaluate transformer model.

  Args:
    flags_obj: Object containing parsed flag values.
  """
  # Determine training schedule based on flags. Exactly one of
  # single_iteration_train_steps / single_iteration_train_epochs is set.
  if flags_obj.train_steps is not None:
    train_eval_iterations, single_iteration_train_steps = _split_schedule(
        flags_obj.train_steps, flags_obj.steps_between_evals)
    single_iteration_train_epochs = None
  else:
    train_epochs = flags_obj.train_epochs or DEFAULT_TRAIN_EPOCHS
    train_eval_iterations, single_iteration_train_epochs = _split_schedule(
        train_epochs, flags_obj.epochs_between_evals)
    single_iteration_train_steps = None

  # Add flag-defined parameters to params object.
  # NOTE(review): the PARAMS_MAP entry is modified in place; if it is a class
  # rather than an instance, these writes are process-global. Confirm.
  params = PARAMS_MAP[flags_obj.param_set]
  params.data_dir = flags_obj.data_dir
  params.num_parallel_calls = flags_obj.num_parallel_calls
  params.epochs_between_evals = flags_obj.epochs_between_evals
  params.repeat_dataset = single_iteration_train_epochs
  params.batch_size = flags_obj.batch_size or params.batch_size

  # Create hooks that log information about the training and metric values
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=params.batch_size  # for ExamplesPerSecondHook
  )
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
      run_params=params.__dict__,
      test_id=flags_obj.benchmark_test_id)

  # Train and evaluate transformer model
  estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
  train_schedule(
      estimator=estimator,
      # Training arguments
      train_eval_iterations=train_eval_iterations,
      single_iteration_train_steps=single_iteration_train_steps,
      single_iteration_train_epochs=single_iteration_train_epochs,
      train_hooks=train_hooks,
      benchmark_logger=benchmark_logger,
      # BLEU calculation arguments
      bleu_source=flags_obj.bleu_source,
      bleu_ref=flags_obj.bleu_ref,
      bleu_threshold=flags_obj.stop_threshold,
      vocab_file_path=os.path.join(flags_obj.data_dir, flags_obj.vocab_file))
Katherine Wu's avatar
Katherine Wu committed
463
464


465
def main(_):
  """absl entry point: run training/evaluation with the parsed flags."""
  flags_obj = flags.FLAGS
  with logger.benchmark_context(flags_obj):
    run_transformer(flags_obj)
Katherine Wu's avatar
Katherine Wu committed
468
469


470
471
472
473
if __name__ == "__main__":
  # Show INFO-level TF logs (training schedule, evaluation results, BLEU).
  tf.logging.set_verbosity(tf.logging.INFO)
  # Register flags before absl_app.run parses the command line and calls main.
  define_transformer_flags()
  absl_app.run(main)