params.py 6.12 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datastructures for all the configurations for MobileBERT-EdgeTPU training."""
import dataclasses
from typing import Optional

from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader

# Re-exported aliases: this module's callers configure the pretraining data
# pipeline and the pretrainer model through these names rather than importing
# the originating modules directly.
DatasetParams = pretrain_dataloader.BertPretrainDataConfig
PretrainerModelParams = bert.PretrainerConfig


@dataclasses.dataclass
class OrbitParams(base_config.Config):
  """Configuration for the Orbit training/evaluation controller.

  Attributes:
    mode: Controller mode; one of 'train', 'train_and_evaluate', or
      'evaluate'.
    steps_per_loop: Number of steps executed in each inner training loop.
    total_steps: Global step count at which training stops.
    eval_steps: Steps to run per evaluation; -1 means evaluate over the
      whole evaluation dataset.
    eval_interval: Training steps between evaluations. When set, training
      always pauses every `eval_interval` steps even if that makes an inner
      loop shorter than `steps_per_loop`; when None, evaluation happens only
      once training has finished.
  """
  mode: str = 'train'
  steps_per_loop: int = 1000
  total_steps: int = 1000000
  eval_steps: int = -1
  eval_interval: Optional[int] = None


@dataclasses.dataclass
class OptimizerParams(optimization.OptimizationConfig):
  """Optimizer parameters for MobileBERT-EdgeTPU.

  Defaults: AdamW with 0.01 weight-decay rate (LayerNorm/bias variables
  excluded), polynomial learning-rate decay from 1e-4 to 0.0 over 1M steps,
  and a 10k-step polynomial warmup.
  """
  # Mutable config defaults are wrapped in `dataclasses.field(default_factory=...)`
  # so every OptimizerParams instance gets its own config objects. A bare
  # `Config(...)` class-level default would be shared across all instances and
  # is rejected outright by `dataclasses` on Python 3.11+ (unhashable default).
  optimizer: optimization.OptimizerConfig = dataclasses.field(
      default_factory=lambda: optimization.OptimizerConfig(
          type='adamw',
          adamw=optimization.AdamWeightDecayConfig(
              weight_decay_rate=0.01,
              exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])))
  learning_rate: optimization.LrConfig = dataclasses.field(
      default_factory=lambda: optimization.LrConfig(
          type='polynomial',
          polynomial=optimization.PolynomialLrConfig(
              initial_learning_rate=1e-4,
              decay_steps=1000000,
              end_learning_rate=0.0)))
  warmup: optimization.WarmupConfig = dataclasses.field(
      default_factory=lambda: optimization.WarmupConfig(
          type='polynomial',
          polynomial=optimization.PolynomialWarmupConfig(warmup_steps=10000)))


@dataclasses.dataclass
class RuntimeParams(base_config.Config):
  """Parameters that set up the training runtime.

  TODO(longy): Can reuse the Runtime Config in:
  official/core/config_definitions.py

  Attributes:
    distribution_strategy: Keras distribution strategy ('off' disables it).
    num_gpus: Number of gpus to use for training.
    all_reduce_alg: Optional all-reduce algorithm to use for multi-device
      gradient aggregation; None selects the default.
    num_workers: Number of parallel workers.
    tpu_address: The bns address of the TPU to use.
    use_gpu: Whether to use GPU.
    use_tpu: Whether to use TPU.
  """
  distribution_strategy: str = 'off'
  num_gpus: Optional[int] = 0
  all_reduce_alg: Optional[str] = None
  num_workers: int = 1
  tpu_address: str = ''
  use_gpu: Optional[bool] = None
  use_tpu: Optional[bool] = None


@dataclasses.dataclass
class LayerWiseDistillationParams(base_config.Config):
  """Configuration for the optional layer-wise distillation phase.

  In this phase knowledge is transferred layer by layer across all
  transformer layers. When `num_steps` is nonzero, the end-to-end
  distillation phase runs after this layer-wise phase completes.

  Attributes:
    num_steps: Total steps of layer-wise distillation (0 skips the phase).
    warmup_steps: Learning-rate warmup steps.
    initial_learning_rate: Learning rate at the start of decay.
    end_learning_rate: Learning rate at the end of decay.
    decay_steps: Steps over which the learning rate decays.
    hidden_distill_factor: Loss weight for hidden-state distillation.
    beta_distill_factor: Loss weight for the beta distillation term.
    gamma_distill_factor: Loss weight for the gamma distillation term.
    attention_distill_factor: Loss weight for attention distillation.
  """
  num_steps: int = 10000
  warmup_steps: int = 10000
  initial_learning_rate: float = 1.5e-3
  end_learning_rate: float = 1.5e-3
  decay_steps: int = 10000
  hidden_distill_factor: float = 100.0
  beta_distill_factor: float = 5000.0
  gamma_distill_factor: float = 5.0
  attention_distill_factor: float = 1.0


@dataclasses.dataclass
class EndToEndDistillationParams(base_config.Config):
  """Configuration for the end-to-end pretrainer distillation phase.

  Attributes:
    num_steps: Total steps of end-to-end distillation.
    warmup_steps: Learning-rate warmup steps.
    initial_learning_rate: Learning rate at the start of decay.
    end_learning_rate: Learning rate at the end of decay.
    decay_steps: Steps over which the learning rate decays.
    distill_ground_truth_ratio: Mixing ratio between the distillation
      (teacher) signal and the ground-truth labels in the loss.
  """
  num_steps: int = 580000
  warmup_steps: int = 20000
  initial_learning_rate: float = 1.5e-3
  end_learning_rate: float = 1.5e-7
  decay_steps: int = 580000
  distill_ground_truth_ratio: float = 0.5


@dataclasses.dataclass
class EdgeTPUBERTCustomParams(base_config.Config):
  """EdgeTPU-BERT custom params.

  Attributes:
    train_datasest: An instance of the DatasetParams for training.
      TODO: field name carries a historical typo ('datasest'); renaming it
      would break existing configs/callers, so it is kept as-is.
    eval_dataset: An instance of the DatasetParams for evaluation.
    teacher_model: An instance of the PretrainerModelParams. If None, then the
      student model is trained independently without distillation.
    student_model: An instance of the PretrainerModelParams.
    teacher_model_init_checkpoint: Path for the teacher model init checkpoint.
    student_model_init_checkpoint: Path for the student model init checkpoint.
    layer_wise_distillation: Distillation config for the layer-wise step.
    end_to_end_distillation: Distillation config for the end2end step (its
      `distill_ground_truth_ratio` controls the distillation/ground-truth mix).
    optimizer: An instance of the OptimizerParams.
    runtime: An instance of the RuntimeParams.
    orbit_config: An instance of the OrbitParams.
  """
  # Mutable config defaults use `dataclasses.field(default_factory=...)` so
  # each EdgeTPUBERTCustomParams instance owns fresh sub-configs; a shared
  # class-level default instance would alias state across instances and is
  # rejected by `dataclasses` on Python 3.11+ (unhashable default).
  train_datasest: DatasetParams = dataclasses.field(
      default_factory=DatasetParams)
  eval_dataset: DatasetParams = dataclasses.field(
      default_factory=DatasetParams)
  teacher_model: Optional[PretrainerModelParams] = dataclasses.field(
      default_factory=PretrainerModelParams)
  student_model: PretrainerModelParams = dataclasses.field(
      default_factory=PretrainerModelParams)
  teacher_model_init_checkpoint: str = ''
  student_model_init_checkpoint: str = ''
  layer_wise_distillation: LayerWiseDistillationParams = dataclasses.field(
      default_factory=LayerWiseDistillationParams)
  end_to_end_distillation: EndToEndDistillationParams = dataclasses.field(
      default_factory=EndToEndDistillationParams)
  optimizer: OptimizerParams = dataclasses.field(
      default_factory=OptimizerParams)
  runtime: RuntimeParams = dataclasses.field(default_factory=RuntimeParams)
  orbit_config: OrbitParams = dataclasses.field(default_factory=OrbitParams)