config_definitions.py 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
17

18
from typing import Optional, Sequence, Union
19
20
21
22

import dataclasses

from official.modeling.hyperparams import base_config
23
from official.modeling.optimization.configs import optimization_config
24

25
OptimizationConfig = optimization_config.OptimizationConfig
26
27
28
29
30
31
32


@dataclasses.dataclass
class DataConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    input_path: The path to the input. It can be either (1) a str indicating
      a file path/pattern, or (2) a str indicating multiple file paths/patterns
      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
      (3) a list of str, each of which is a file path/pattern or multiple file
      paths/patterns separated by comma.
      It should not be specified when the following `tfds_name` is specified.
    tfds_name: The name of the tensorflow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
      is required when above `tfds_name` is specified.
    global_batch_size: The global batch size across all replicas.
    is_training: Whether this data is used for training or not. `None` (the
      default) means the caller has not set it yet.
    drop_remainder: Whether the last batch should be dropped in the case it has
      fewer than `global_batch_size` elements.
    shuffle_buffer_size: The buffer size used for shuffling training data.
    cache: Whether to cache dataset examples. Can be used to avoid re-reading
      from disk on the second epoch. Requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
      element before cycling to another input element when interleaving files.
    deterministic: A boolean controlling whether determinism should be enforced.
    sharding: Whether sharding is used in the input pipeline.
    enable_tf_data_service: A boolean indicating whether to enable tf.data
      service for the input pipeline.
    tf_data_service_address: The URI of a tf.data service to offload
      preprocessing onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
      overridden by `FLAGS.tf_data_service` flag in the binary.
    tf_data_service_job_name: The name of the tf.data service job. This
      argument makes it possible for multiple datasets to share the same job.
      The default behavior is that the dataset creates anonymous, exclusively
      owned jobs.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_download: A bool to indicate whether to download data using TFDS.
    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
      according to builder.info.supervised_keys; if False, the default, the
      returned tf.data.Dataset will have a dictionary with all the features.
    tfds_skip_decoding_feature: A str to indicate which features are skipped for
      decoding when loading dataset from TFDS. Use comma to separate multiple
      features. The main use case is to skip the image/video decoding for better
      performance.
  """
  input_path: Union[Sequence[str], str] = ""
  tfds_name: str = ""
  tfds_split: str = ""
  global_batch_size: int = 0
  # Annotation widened to Optional[bool]: the field defaults to None so that
  # an unset value is distinguishable from an explicit True/False.
  is_training: Optional[bool] = None
  drop_remainder: bool = True
  shuffle_buffer_size: int = 100
  cache: bool = False
  cycle_length: Optional[int] = None
  block_length: int = 1
  deterministic: Optional[bool] = None
  sharding: bool = True
  enable_tf_data_service: bool = False
  tf_data_service_address: Optional[str] = None
  tf_data_service_job_name: Optional[str] = None
  tfds_data_dir: str = ""
  tfds_download: bool = False
  tfds_as_supervised: bool = False
  tfds_skip_decoding_feature: str = ""
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119


@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
  """High-level configurations for Runtime.

  These include parameters that are not directly related to the experiment,
  e.g. directories, accelerator type, etc.

  Attributes:
    distribution_strategy: e.g. 'mirrored', 'tpu', etc.
    enable_xla: Whether or not to enable XLA.
    per_gpu_thread_count: thread count per GPU.
    gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
    dataset_num_private_threads: Number of threads for a private threadpool
      created for all datasets computation.
    tpu: The address of the TPU to use, if any.
    num_gpus: The number of GPUs to use, if any.
    worker_hosts: comma-separated list of worker ip:port pairs for running
      multi-worker models with DistributionStrategy.
    task_index: If multi-worker training, the task index of this worker.
    all_reduce_alg: Defines the algorithm for performing all-reduce.
    num_packs: Sets `num_packs` in the cross device ops used in
      MirroredStrategy.  For details, see tf.distribute.NcclAllReduce.
    mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
      'float16', or 'bfloat16'.
    loss_scale: The type of loss scale, or 'float' value. This is used when
      setting the mixed precision policy.
    run_eagerly: Whether or not to run the experiment eagerly.
    batchnorm_spatial_persistent: Whether or not to enable the spatial
      persistent mode for CuDNN batch norm kernel for improved GPU performance.
    num_cores_per_replica: Model-parallelism setting; presumably the number of
      accelerator cores that make up one logical replica (1 = no model
      parallelism) — confirm against the consumer of `model_parallelism()`.
    default_shard_dim: Model-parallelism setting; presumably the default tensor
      dimension to shard across cores — confirm against the consumer of
      `model_parallelism()`.
  """
  distribution_strategy: str = "mirrored"
  enable_xla: bool = False
  gpu_thread_mode: Optional[str] = None
  dataset_num_private_threads: Optional[int] = None
  per_gpu_thread_count: int = 0
  tpu: Optional[str] = None
  num_gpus: int = 0
  worker_hosts: Optional[str] = None
  task_index: int = -1
  all_reduce_alg: Optional[str] = None
  num_packs: int = 1
  mixed_precision_dtype: Optional[str] = None
  loss_scale: Optional[Union[str, float]] = None
  run_eagerly: bool = False
  batchnorm_spatial_persistent: bool = False

  # Global model parallelism configurations.
  num_cores_per_replica: int = 1
  default_shard_dim: int = -1

  def model_parallelism(self):
    """Returns the model-parallelism fields bundled as a dict.

    The dict has keys `num_cores_per_replica` and `default_shard_dim`,
    mirroring the two fields above.
    """
    return dict(
        num_cores_per_replica=self.num_cores_per_replica,
        default_shard_dim=self.default_shard_dim)

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174

@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
  """Settings that control what is logged to Tensorboard.

  Attributes:
    track_lr: If True (the default), the learning rate is tracked in
      Tensorboard.
    write_model_weights: If True, the model weights are written to Tensorboard
      as images. Defaults to False.
  """
  track_lr: bool = True
  write_model_weights: bool = False


@dataclasses.dataclass
class CallbacksConfig(base_config.Config):
  """Configuration for Callbacks.

  Attributes:
    enable_checkpoint_and_export: Whether or not to enable checkpoints as a
      Callback. Defaults to True.
    enable_backup_and_restore: Whether or not to add BackupAndRestore
      callback. Defaults to False.
    enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
      Defaults to True.
    enable_time_history: Whether or not to enable TimeHistory Callbacks.
      Defaults to True.
  """
  enable_checkpoint_and_export: bool = True
  # Docstring fixed: the field default below is False, not True as the
  # original docstring claimed.
  enable_backup_and_restore: bool = False
  enable_tensorboard: bool = True
  enable_time_history: bool = True


@dataclasses.dataclass
class TrainerConfig(base_config.Config):
  """Configuration for trainer.

  Attributes:
    optimizer_config: optimizer config, it includes optimizer, learning rate,
      and warmup schedule configs.
    train_tf_while_loop: whether or not to use tf while loop.
    train_tf_function: whether or not to use tf_function for training loop.
    eval_tf_function: whether or not to use tf_function for eval.
    allow_tpu_summary: Whether to allow summary happen inside the XLA program
      runs on TPU through automatic outside compilation.
    steps_per_loop: number of steps per loop.
    summary_interval: number of steps between each summary.
    checkpoint_interval: number of steps between checkpoints.
    max_to_keep: max checkpoints to keep.
    continuous_eval_timeout: maximum number of seconds to wait between
      checkpoints, if set to None, continuous eval will wait indefinitely. This
      is only used in continuous_train_and_eval and continuous_eval modes.
      Default value is 1 hour.
    train_steps: number of train steps.
    validation_steps: number of eval steps. If `None`, the entire eval dataset
      is used.
    validation_interval: number of training steps to run between evaluations.
    best_checkpoint_export_subdir: if set, the trainer will keep track of the
      best evaluation metric, and export the corresponding best checkpoint under
      `model_dir/best_checkpoint_export_subdir`. Note that this only works if
      mode contains eval (such as `train_and_eval`, `continuous_eval`, and
      `continuous_train_and_eval`).
    best_checkpoint_eval_metric: for exporting the best checkpoint, which
      evaluation metric the trainer should monitor. This can be any evaluation
      metric appears on tensorboard.
    best_checkpoint_metric_comp: for exporting the best checkpoint, how the
      trainer should compare the evaluation metrics. This can be either `higher`
      (higher the better) or `lower` (lower the better).
  """
  # default_factory so each TrainerConfig instance gets its own
  # OptimizationConfig rather than sharing one class-level default instance.
  optimizer_config: OptimizationConfig = dataclasses.field(
      default_factory=OptimizationConfig)
  # Orbit settings.
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True
  allow_tpu_summary: bool = False
  # Trainer intervals.
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
  continuous_eval_timeout: int = 60 * 60
  # Train/Eval routines.
  train_steps: int = 0
  validation_steps: Optional[int] = None
  validation_interval: int = 1000
  # Best checkpoint export.
  best_checkpoint_export_subdir: str = ""
  best_checkpoint_eval_metric: str = ""
  best_checkpoint_metric_comp: str = "higher"
245
246
247
248


@dataclasses.dataclass
class TaskConfig(base_config.Config):
  """The base configuration for a task.

  Attributes:
    init_checkpoint: A path from which to initialize model weights. Empty
      string means no checkpoint initialization.
    model: The model configuration for this task; `None` (the default) until a
      concrete task sets it. Annotation is Optional to match the None default.
    train_data: The configuration of the dataset used for training.
    validation_data: The configuration of the dataset used for validation.
  """
  init_checkpoint: str = ""
  model: Optional[base_config.Config] = None
  # default_factory so each TaskConfig gets fresh DataConfig instances instead
  # of sharing mutable class-level defaults.
  train_data: DataConfig = dataclasses.field(default_factory=DataConfig)
  validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)


@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
  """Top-level configuration.

  Attributes:
    task: The task configuration (model, data, checkpoint).
    trainer: The trainer configuration (optimizer, intervals, steps).
    runtime: The runtime configuration (accelerators, distribution strategy).
  """
  # default_factory so each ExperimentConfig gets its own sub-config instances
  # rather than sharing mutable class-level defaults.
  task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
  trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
  runtime: RuntimeConfig = dataclasses.field(default_factory=RuntimeConfig)