common.py 2.4 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14
15

# Lint as: python3
Abdullah Rashwan's avatar
Abdullah Rashwan committed
16
17
"""Common configurations."""

18
from typing import Optional
Abdullah Rashwan's avatar
Abdullah Rashwan committed
19
# Import libraries
Fan Yang's avatar
Fan Yang committed
20

Abdullah Rashwan's avatar
Abdullah Rashwan committed
21
22
import dataclasses

Fan Yang's avatar
Fan Yang committed
23
from official.core import config_definitions as cfg
Abdullah Rashwan's avatar
Abdullah Rashwan committed
24
25
26
from official.modeling import hyperparams


27
28
29
30
31
32
33
@dataclasses.dataclass
class RandAugment(hyperparams.Config):
  """Configuration for RandAugment.

  Field order is part of the generated `__init__` signature and must not be
  changed. How each value is interpreted is decided by the code that consumes
  this config, which is not visible from this class.

  Attributes:
    num_layers: integer setting, defaults to 2.
    magnitude: float setting, defaults to 10.
    cutout_const: float setting, defaults to 40.
    translate_const: float setting, defaults to 10.
    prob_to_apply: optional float; `None` (the default) leaves it unset.
      NOTE(review): presumably the probability of applying each op — confirm
      against the augmentation implementation.
  """
  num_layers: int = 2
  magnitude: float = 10
  cutout_const: float = 40
  translate_const: float = 10
  prob_to_apply: Optional[float] = None
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


@dataclasses.dataclass
class AutoAugment(hyperparams.Config):
  """Configuration for AutoAugment.

  Declares three typed fields with defaults. How each value is consumed is
  decided by the code that reads this config, which is not visible here.
  """
  # NOTE(review): presumably the name of a predefined policy set — confirm
  # against the augmentation implementation.
  augmentation_name: str = 'v0'
  # Both constants are annotated as float but defaulted with int literals.
  cutout_const: float = 100
  translate_const: float = 250


@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
  """Configuration for input data augmentation.

  Attributes:
    type: 'str', type of augmentation to be used, one of the fields below.
    randaug: RandAugment config.
    autoaug: AutoAugment config.
  """
  type: Optional[str] = None
  # Use default_factory rather than a shared instance as the default: dataclass
  # instances are unhashable, and Python 3.11+ rejects unhashable field
  # defaults with `ValueError: mutable default ... use default_factory`.
  randaug: RandAugment = dataclasses.field(default_factory=RandAugment)
  autoaug: AutoAugment = dataclasses.field(default_factory=AutoAugment)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
59
60
61
@dataclasses.dataclass
class NormActivation(hyperparams.Config):
  """Common normalization and activation settings.

  Field order is part of the generated `__init__` signature and must not be
  changed.

  Attributes:
    activation: name of the activation, defaults to 'relu'.
    use_sync_bn: boolean flag, defaults to True. NOTE(review): presumably
      selects synchronized batch normalization — confirm against the layers
      that read this config.
    norm_momentum: float, defaults to 0.99.
    norm_epsilon: float, defaults to 0.001.
  """
  activation: str = 'relu'
  use_sync_bn: bool = True
  norm_momentum: float = 0.99
  norm_epsilon: float = 0.001
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
65
66
67


@dataclasses.dataclass
class PseudoLabelDataConfig(cfg.DataConfig):
  """Pseudo Label input config for training.

  Overrides or adds fields on top of `cfg.DataConfig`. Field order is part of
  the generated `__init__` signature and must not be changed. The last two
  fields are retained only for backward compatibility.
  """
  input_path: str = ''
  data_ratio: float = 1.0  # Per-batch ratio of pseudo-labeled to labeled data.
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 10000
  cycle_length: int = 10
  aug_rand_hflip: bool = True
  # Choose from AutoAugment and RandAugment.
  aug_type: Optional[Augmentation] = None
  file_type: str = 'tfrecord'

  # Keep for backward compatibility.
  aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
  randaug_magnitude: Optional[int] = 10