# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes benchmark testing for 3D Unet model."""
# pylint: disable=line-too-long
from __future__ import print_function

import functools
import os
import time
from typing import Optional
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.benchmark import benchmark_wrappers
from official.benchmark import keras_benchmark
from official.benchmark import owner_utils
from official.vision.segmentation import unet_main as unet_training_lib
from official.vision.segmentation import unet_model as unet_model_lib

# Default accuracy bounds for the benchmark; forwarded to `_report_benchmark`
# as `top_1_min` / `top_1_max`.
UNET3D_MIN_ACCURACY = 0.90
UNET3D_MAX_ACCURACY = 0.98

# Fixed GCS locations used by the benchmark: training/eval file patterns and
# the YAML model config (assigned to FLAGS.training_file_pattern,
# FLAGS.eval_file_pattern and FLAGS.config_file respectively).
UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'
UNET_EVAL_FILES = 'gs://mlcompass-data/unet3d/eval_data/*'
UNET_MODEL_CONFIG_FILE = 'gs://mlcompass-data/unet3d/config/unet_config.yaml'

FLAGS = flags.FLAGS


class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
  """Benchmark accuracy tests for UNet3D model in Keras."""

  def __init__(self,
               output_dir: Optional[str] = None,
               root_data_dir: Optional[str] = None,
               **kwargs):
    """A benchmark class.

    Args:
      output_dir: directory where to output e.g. log files
      root_data_dir: directory under which to look for dataset. Currently
        unused: data locations come from the module-level UNET_* constants.
      **kwargs: arbitrary named arguments. This is needed to make the
        constructor forward compatible in case PerfZero provides more named
        arguments before updating the constructor.
    """

    flag_methods = [unet_training_lib.define_unet3d_flags]

    # Dataset locations are fixed GCS patterns rather than being derived
    # from `root_data_dir`.
    self.training_file_pattern = UNET_TRAINING_FILES
    self.eval_file_pattern = UNET_EVAL_FILES

    # TODO(hongjunchoi): Create and use shared config file instead.
    self.config_file = UNET_MODEL_CONFIG_FILE
    super(Unet3DAccuracyBenchmark, self).__init__(
        output_dir=output_dir, flag_methods=flag_methods)

  def _set_benchmark_parameters(self, experiment_name):
    """Overrides training parameters for benchmark tests.

    Writes the benchmark's fixed settings into the global absl FLAGS, which
    `_run_and_report_benchmark` later turns into params via
    `unet_training_lib.extract_params(FLAGS)`.

    Args:
      experiment_name: name of the run; used as the model_dir sub-folder
        under `self.output_dir`.
    """
    FLAGS.model_dir = self._get_model_dir(experiment_name)
    FLAGS.mode = 'train'
    FLAGS.training_file_pattern = self.training_file_pattern
    FLAGS.eval_file_pattern = self.eval_file_pattern
    FLAGS.config_file = self.config_file
    FLAGS.lr_init_value = 0.00005
    FLAGS.lr_decay_rate = 0.5
    FLAGS.epochs = 3

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                experiment_name: str,
                                min_accuracy: float = UNET3D_MIN_ACCURACY,
                                max_accuracy: float = UNET3D_MAX_ACCURACY,
                                distribution_strategy: str = 'tpu',
                                epochs: int = 10,
                                steps: int = 0,
                                epochs_between_evals: int = 1,
                                dtype: str = 'float32',
                                enable_xla: bool = False,
                                run_eagerly: bool = False):
    """Runs and reports the benchmark given the provided configuration.

    Builds the UNet3D model inside the distribution strategy scope, trains
    and evaluates it, then reports the final validation and training
    accuracy together with the wall time.

    Args:
      experiment_name: name of the run (not read in this method body).
      min_accuracy: lower accuracy bound passed as `top_1_min`.
      max_accuracy: upper accuracy bound passed as `top_1_max`.
      distribution_strategy: not read in this method body.
      epochs: not read in this method body.
      steps: not read in this method body.
      epochs_between_evals: not read in this method body.
      dtype: not read in this method body; precision comes from
        `params.dtype` instead.
      enable_xla: not read in this method body.
      run_eagerly: not read in this method body.

    NOTE(review): only `min_accuracy` / `max_accuracy` influence the run. The
    other keyword arguments are accepted but ignored, so e.g. passing
    dtype='bfloat16' does not by itself select bfloat16. Confirm that the
    config/FLAGS set the intended dtype and strategy.
    """
    params = unet_training_lib.extract_params(FLAGS)
    strategy = unet_training_lib.create_distribution_strategy(params)

    # Mixed precision is selected from the extracted params.
    input_dtype = params.dtype
    if input_dtype in ('float16', 'bfloat16'):
      policy = tf.keras.mixed_precision.experimental.Policy(
          'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
      # NOTE(review): this sets a global Keras policy that is not reset here;
      # later benchmarks in the same process may inherit it.
      tf.keras.mixed_precision.experimental.set_policy(policy)

    stats = {}
    # Wall time covers model construction as well as training/evaluation.
    start_time_sec = time.time()
    with strategy.scope():
      unet_model = unet_model_lib.build_unet_model(params)
      history = unet_training_lib.train(
          params, strategy, unet_model,
          functools.partial(unet_training_lib.get_train_dataset, params),
          functools.partial(unet_training_lib.get_eval_dataset, params))

      # Report the metrics from the last recorded epoch.
      stats['accuracy_top_1'] = history.history['val_metric_accuracy'][-1]
      stats['training_accuracy_top_1'] = history.history['metric_accuracy'][-1]
    wall_time_sec = time.time() - start_time_sec

    super(Unet3DAccuracyBenchmark, self)._report_benchmark(
        stats,
        wall_time_sec,
        top_1_min=min_accuracy,
        top_1_max=max_accuracy,
        total_batch_size=params.train_batch_size)

  def _get_model_dir(self, folder_name):
    """Returns `folder_name` joined onto the benchmark's output directory."""
    return os.path.join(self.output_dir, folder_name)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_4x4_tpu_bf16(self):
    """Test Keras model with 4x4 TPU, bf16."""
    # NOTE(review): the experiment name says fp16 although bfloat16 is
    # requested. It is kept unchanged because it determines the model_dir path.
    experiment_name = 'benchmark_4x4_tpu_fp16'
    self._setup()
    self._set_benchmark_parameters(experiment_name)
    self._run_and_report_benchmark(
        experiment_name=experiment_name,
        dtype='bfloat16',
        distribution_strategy='tpu')

  @owner_utils.Owner('tf-graph-compiler')
  def benchmark_4x4_tpu_bf16_mlir(self):
    """Test Keras model with 4x4 TPU, bf16 and MLIR enabled."""
    # NOTE(review): the experiment name says fp16 although bfloat16 is
    # requested. It is kept unchanged because it determines the model_dir path.
    experiment_name = 'benchmark_4x4_tpu_fp16_mlir'
    # NOTE(review): this appears to be global TF configuration and is not
    # disabled again in this method. Confirm isolation between benchmarks.
    tf.config.experimental.enable_mlir_bridge()
    self._setup()
    self._set_benchmark_parameters(experiment_name)
    self._run_and_report_benchmark(
        experiment_name=experiment_name,
        dtype='bfloat16',
        distribution_strategy='tpu')


if __name__ == '__main__':
  # Run the benchmark classes in this module through TensorFlow's test runner.
  tf.test.main()