bert_pretrain_benchmark.py 20 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes benchmark testing for bert pretraining."""
# pylint: disable=line-too-long
import json
import os
import time
from typing import Optional

from absl import flags
from absl import logging
import tensorflow as tf

from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils
from official.benchmark import owner_utils
from official.common import distribute_utils
from official.legacy.bert import run_pretraining
from official.utils.flags import core as flags_core

# Expected pretraining masked language modeling (MLM) accuracy range, used
# when the accuracy benchmark runs on TPU:
MIN_MLM_ACCURACY = 0.635
MAX_MLM_ACCURACY = 0.645

# Expected pretraining next sentence prediction (NSP) accuracy range, used
# when the accuracy benchmark runs on TPU:
MIN_NSP_ACCURACY = 0.94
MAX_NSP_ACCURACY = 0.96


# Expected pretraining masked language modeling (MLM) accuracy range, used
# for all non-TPU (GPU) accuracy benchmarks:
MIN_MLM_ACCURACY_GPU = 0.378
MAX_MLM_ACCURACY_GPU = 0.388

# Expected pretraining next sentence prediction (NSP) accuracy range, used
# for all non-TPU (GPU) accuracy benchmarks:
MIN_NSP_ACCURACY_GPU = 0.82
MAX_NSP_ACCURACY_GPU = 0.84


# Comma-separated glob patterns of the sequence-length-128 pretraining data
# (Wikipedia and Books TFRecord shards).
BERT_PRETRAIN_FILES_SEQ128 = 'gs://mlcompass-data/bert/pretraining_data/seq_128/wikipedia.tfrecord*,gs://mlcompass-data/bert/pretraining_data/seq_128/books.tfrecord*'
# BERT model config JSON shared by every benchmark (the path names the
# uncased_L-12_H-768_A-12 checkpoint).
BERT_BASE_CONFIG_FILE = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'

# Global absl flag values. Each benchmark method below assigns training
# settings on this object before launching pretraining.
FLAGS = flags.FLAGS


class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
  """Benchmark accuracy tests for BERT Pretraining."""

  def __init__(self,
               output_dir: Optional[str] = None,
               tpu: Optional[str] = None,
               **kwargs):
    """Inits BertPretrainAccuracyBenchmark class.

    Args:
      output_dir: Directory where to output e.g. log files
      tpu: TPU name to use in a TPU benchmark.
      **kwargs: Additional keyword arguments.
    """
    super(BertPretrainAccuracyBenchmark, self).__init__(
        output_dir=output_dir, tpu=tpu, **kwargs)

  def _uses_tpu(self, ds_type: str) -> bool:
    """Returns True if a run launched with `ds_type` executes on TPU.

    A configured `self.tpu` address forces TPU regardless of `ds_type`. Both
    strategy creation and accuracy-threshold selection use this single test,
    so they cannot disagree.
    """
    return bool(self.tpu) or ds_type == 'tpu'

  def _get_distribution_strategy(self, ds_type='mirrored'):
    """Gets the distribution strategy.

    Args:
      ds_type: String, the distribution strategy type to be used. Can be
        'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'. Ignored in
        favour of TPU when `self.tpu` is set.

    Returns:
      A `tf.distribute.Strategy` object.
    """
    if self._uses_tpu(ds_type):
      return distribute_utils.get_distribution_strategy(
          distribution_strategy='tpu', tpu_address=self.tpu)
    if ds_type == 'multi_worker_mirrored':
      # Configures cluster spec for multi-worker distribution strategy.
      _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)
    return distribute_utils.get_distribution_strategy(
        distribution_strategy=ds_type,
        num_gpus=FLAGS.num_gpus,
        all_reduce_alg=FLAGS.all_reduce_alg)

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool,
                                ds_type: str):
    """Runs and reports the benchmark given the provided configuration.

    Args:
      summary_path: Path of the JSON training summary read after training.
        Reporting is skipped when this file does not exist.
      report_accuracy: Whether to also report MLM/NSP accuracy metrics with
        their expected ranges.
      ds_type: Distribution strategy type used to build the strategy and to
        pick the accuracy thresholds.
    """
    distribution = self._get_distribution_strategy(ds_type=ds_type)
    logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
    start_time_sec = time.time()
    run_pretraining.run_bert_pretrain(
        strategy=distribution, custom_callbacks=self.timer_callback)
    wall_time_sec = time.time() - start_time_sec

    # For GPU multi-worker, the summary text file is only generated on chief
    # (metrics aggregated), so only chief has to report the result.
    if tf.io.gfile.exists(summary_path):
      with tf.io.gfile.GFile(summary_path, 'rb') as reader:
        summary = json.loads(reader.read().decode('utf-8'))
      self._report_benchmark(summary, start_time_sec, wall_time_sec,
                             report_accuracy, ds_type)

  def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
                        report_accuracy, ds_type):
    """Reports loss, throughput, startup time and optionally accuracy.

    Args:
      summary: Dict parsed from the training summary JSON. Must contain
        'train_loss' and 'total_training_steps', plus 'masked_lm_accuracy'
        and 'next_sentence_accuracy' when `report_accuracy` is True.
      start_time_sec: `time.time()` value taken right before training began.
      wall_time_sec: Elapsed wall time of the training call, in seconds.
      report_accuracy: Whether to add MLM/NSP accuracy metrics with ranges.
      ds_type: Distribution strategy type the run was launched with.
    """
    metrics = [{
        'name': 'train_loss',
        'value': summary['train_loss'],
    }, {
        'name':
            'exp_per_second',
        'value':
            self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
                                                     FLAGS.steps_per_loop)
    }, {
        'name': 'startup_time',
        'value': self.timer_callback.get_startup_time(start_time_sec)
    }]
    if report_accuracy:
      # Pick the accuracy window with the same TPU test that chose the
      # strategy, so the thresholds always match the hardware actually used.
      if self._uses_tpu(ds_type):
        min_mlm_acc, max_mlm_acc = MIN_MLM_ACCURACY, MAX_MLM_ACCURACY
        min_nsp_acc, max_nsp_acc = MIN_NSP_ACCURACY, MAX_NSP_ACCURACY
      else:
        min_mlm_acc, max_mlm_acc = MIN_MLM_ACCURACY_GPU, MAX_MLM_ACCURACY_GPU
        min_nsp_acc, max_nsp_acc = MIN_NSP_ACCURACY_GPU, MAX_NSP_ACCURACY_GPU
      metrics.extend([{
          'name': 'masked_lm_accuracy',
          'value': summary['masked_lm_accuracy'],
          'min_value': min_mlm_acc,
          'max_value': max_mlm_acc,
      }, {
          'name': 'next_sentence_accuracy',
          'value': summary['next_sentence_accuracy'],
          'min_value': min_nsp_acc,
          'max_value': max_nsp_acc,
      }])
    self.report_benchmark(
        iters=summary['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics,
        extras={'flags': flags_core.get_nondefault_flags_as_str()})

  def _run_named_benchmark(self, benchmark_name: str, report_accuracy: bool):
    """Runs the currently configured benchmark under `benchmark_name`.

    Points FLAGS.model_dir at the model directory for `benchmark_name`,
    derives the training-summary path inside it and delegates to
    `_run_and_report_benchmark` with FLAGS.distribution_strategy.

    Args:
      benchmark_name: Folder name passed to `_get_model_dir`; by convention
        the name of the calling benchmark method.
      report_accuracy: Whether to report accuracy metrics.
    """
    FLAGS.model_dir = self._get_model_dir(benchmark_name)
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=report_accuracy,
        ds_type=FLAGS.distribution_strategy)

  def _specify_common_flags(self):
    """Sets the model/data/optimizer flags shared by every benchmark."""
    FLAGS.bert_config_file = BERT_BASE_CONFIG_FILE
    FLAGS.learning_rate = 1e-4
    FLAGS.warmup_steps = 10000
    FLAGS.steps_per_loop = 10000
    FLAGS.input_files = BERT_PRETRAIN_FILES_SEQ128
    FLAGS.max_seq_length = 128
    FLAGS.max_predictions_per_seq = 20

  def _specify_tpu_common_flags(self):
    """Selects the TPU strategy with bf16 precision."""
    FLAGS.distribution_strategy = 'tpu'
    FLAGS.dtype = 'bf16'

  def _specify_gpu_common_flags(self):
    """Selects the mirrored GPU strategy with fp16 and dynamic loss scale."""
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.dtype = 'fp16'
    FLAGS.loss_scale = 'dynamic'

  @owner_utils.Owner('tf-model-garden')
  def benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps(self):
    """Test bert pretraining with 8x8 TPU for 500k steps."""
    # This is used for accuracy test.
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 512
    FLAGS.num_steps_per_epoch = 500000
    FLAGS.num_train_epochs = 1
    # Set train_summary_interval to -1 to disable training summary, because
    # writing summary to gcs may fail and summaries are not needed for this
    # accuracy benchmark test.
    FLAGS.train_summary_interval = -1
    self._run_named_benchmark(
        'benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps',
        report_accuracy=True)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
    """Test bert pretraining with 2x2 TPU for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    FLAGS.train_batch_size = 128
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
    """Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    FLAGS.train_batch_size = 128
    tf.config.experimental.enable_mlir_bridge()
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
    """Test bert pretraining with 4x4 TPU for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 512
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
    """Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 512
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    tf.config.experimental.enable_mlir_bridge()
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_4x4_tpu_bf16_seq128_1k_steps(self):
    """Test bert pretraining with 4x4 TPU for 1000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 512
    FLAGS.warmup_steps = 0
    FLAGS.num_steps_per_epoch = 1000
    FLAGS.num_train_epochs = 1
    FLAGS.steps_per_loop = 500
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_4x4_tpu_bf16_seq128_1k_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
    """Test bert pretraining with 8x8 TPU for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 512
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_8x8_tpu_bf16_seq128_10k_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_8x16_tpu_bf16_seq128_1k_steps(self):
    """Test bert pretraining with 8x16 TPU for 1000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 4096
    FLAGS.warmup_steps = 0
    FLAGS.num_steps_per_epoch = 1000
    FLAGS.num_train_epochs = 1
    FLAGS.steps_per_loop = 500
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_8x16_tpu_bf16_seq128_1k_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_accuracy_1x8_gpu_fp16_seq128_15k_steps(self):
    """Test bert pretraining with 8 GPU for 15k steps."""
    # This is used for accuracy test.
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_common_flags()
    FLAGS.num_gpus = 8
    FLAGS.train_batch_size = 96
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 3
    FLAGS.steps_per_loop = 5000
    # Set train_summary_interval to -1 to disable training summary, because
    # writing summary to gcs may fail and summaries are not needed for this
    # accuracy benchmark test.
    FLAGS.train_summary_interval = -1
    self._run_named_benchmark(
        'benchmark_accuracy_1x8_gpu_fp16_seq128_15k_steps',
        report_accuracy=True)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_perf_1x1_gpu_fp16_seq128_200_steps(self):
    """Test bert pretraining with 1 GPU for 200 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_common_flags()
    FLAGS.num_steps_per_epoch = 200
    FLAGS.num_train_epochs = 1
    FLAGS.num_gpus = 1
    FLAGS.train_batch_size = 12
    FLAGS.steps_per_loop = 100
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_1x1_gpu_fp16_seq128_200_steps',
        report_accuracy=False)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_perf_1x8_gpu_fp16_seq128_200_steps(self):
    """Test bert pretraining with 8 GPU for 200 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_common_flags()
    FLAGS.num_steps_per_epoch = 200
    FLAGS.num_train_epochs = 1
    FLAGS.num_gpus = 8
    FLAGS.train_batch_size = 96
    FLAGS.steps_per_loop = 100
    # Disable accuracy check.
    self._run_named_benchmark(
        'benchmark_perf_1x8_gpu_fp16_seq128_200_steps',
        report_accuracy=False)


class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
  """Bert pretrain distributed benchmark tests with multiple workers."""

  def __init__(self,
               output_dir: Optional[str] = None,
               tpu: Optional[str] = None,
               **kwargs):
    """Inits BertPretrainMultiWorkerBenchmark class.

    Args:
      output_dir: Directory where to output e.g. log files
      tpu: TPU name to use in a TPU benchmark.
      **kwargs: Additional keyword arguments.
    """
    super(BertPretrainMultiWorkerBenchmark, self).__init__(
        output_dir=output_dir, tpu=tpu, **kwargs)

  def _specify_gpu_mwms_flags(self):
    """Selects multi-worker mirrored training with NCCL all-reduce and fp16.

    Also enables dynamic loss scaling and sets num_gpus to 8.
    """
    FLAGS.distribution_strategy = 'multi_worker_mirrored'
    FLAGS.all_reduce_alg = 'nccl'
    FLAGS.dtype = 'fp16'
    FLAGS.loss_scale = 'dynamic'
    FLAGS.num_gpus = 8

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_accuracy_mwms_1x8_gpu_fp16_seq128_15k_steps(self):
    """Test bert pretraining with 8 GPU for 15k steps."""
    # This is used for accuracy test.
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_mwms_flags()
    FLAGS.train_batch_size = 96
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 3
    FLAGS.steps_per_loop = 5000
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_accuracy_mwms_1x8_gpu_fp16_seq128_15k_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Set train_summary_interval to -1 to disable training summary, because
    # writing summary to gcs may fail and summaries are not needed for this
    # accuracy benchmark test.
    FLAGS.train_summary_interval = -1
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=True,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_accuracy_mwms_2x8_gpu_fp16_seq128_15k_steps(self):
    """Test bert pretraining with 2x8 GPU for 15k steps."""
    # This is used for accuracy test.
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_mwms_flags()
    # Use the same global batch size as the accuracy_mwms_1x8 benchmark so
    # the two accuracy results are comparable.
    FLAGS.train_batch_size = 96
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 3
    FLAGS.steps_per_loop = 5000
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_accuracy_mwms_2x8_gpu_fp16_seq128_15k_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Set train_summary_interval to -1 to disable training summary, because
    # writing summary to gcs may fail and summaries are not needed for this
    # accuracy benchmark test.
    FLAGS.train_summary_interval = -1
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=True,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_perf_mwms_1x8_gpu_fp16_seq128_200_steps(self):
    """Test bert pretraining with 1x8 GPU for 200 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_mwms_flags()
    FLAGS.num_steps_per_epoch = 200
    FLAGS.num_train_epochs = 1
    # Batch size grows with the number of workers in the perf benchmarks.
    FLAGS.train_batch_size = 96 * 1
    FLAGS.steps_per_loop = 100
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_perf_mwms_1x8_gpu_fp16_seq128_200_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Disable accuracy check.
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=False,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_perf_mwms_2x8_gpu_fp16_seq128_200_steps(self):
    """Test bert pretraining with 2x8 GPU for 200 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_mwms_flags()
    FLAGS.num_steps_per_epoch = 200
    FLAGS.num_train_epochs = 1
    FLAGS.train_batch_size = 96 * 2
    FLAGS.steps_per_loop = 100
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_perf_mwms_2x8_gpu_fp16_seq128_200_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Disable accuracy check.
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=False,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_perf_mwms_8x8_gpu_fp16_seq128_200_steps(self):
    """Test bert pretraining with 8x8 GPU for 200 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_gpu_mwms_flags()
    FLAGS.num_steps_per_epoch = 200
    FLAGS.num_train_epochs = 1
    FLAGS.train_batch_size = 96 * 8
    FLAGS.steps_per_loop = 100
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_perf_mwms_8x8_gpu_fp16_seq128_200_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Disable accuracy check.
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=False,
        ds_type=FLAGS.distribution_strategy)


if __name__ == '__main__':
  # Script entry point: hand control to the TensorFlow test runner, which
  # handles the tests/benchmarks defined in this module.
  tf.test.main()