# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple

import dataclasses

from official.core import config_definitions as cfg
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
21
from official.modeling import hyperparams
Frederick Liu's avatar
Frederick Liu committed
22
from official.modeling.privacy import configs as dp_configs
Hongkun Yu's avatar
Hongkun Yu committed
23
24
25


@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Per-task entry used inside the multi-task configs of this module.

  Instances are held in `MultiTaskConfig.task_routines` and
  `MultiEvalExperimentConfig.eval_tasks`.

  Attributes:
    task_name: name of the task; defaults to the empty string. Slated for
      deprecation (see the TODO below).
    task_config: the `cfg.TaskConfig` for this task; defaults to None.
    eval_steps: optional integer, defaults to None. Presumably the number of
      evaluation steps for this task -- confirm against the consumer.
    task_weight: optional float, defaults to 1.0. Presumably the relative
      weight of this task -- confirm against the consumer.
  """
  # TODO(hongkuny): deprecate the task_name once we migrated client code.
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  eval_steps: Optional[int] = None
  task_weight: Optional[float] = 1.0
Hongkun Yu's avatar
Hongkun Yu committed
32
33
34


@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Task-level configuration for a multi-task setup.

  Attributes:
    init_checkpoint: checkpoint path string; defaults to the empty string.
    model: a `hyperparams.Config` describing the model; defaults to None.
    task_routines: tuple of `TaskRoutine` entries, one per task; defaults to
      the empty tuple.
    differential_privacy_config: optional
      `dp_configs.DifferentialPrivacyConfig`; defaults to None.
  """
  init_checkpoint: str = ""
  model: hyperparams.Config = None
  task_routines: Tuple[TaskRoutine, ...] = ()
  differential_privacy_config: Optional[
      dp_configs.DifferentialPrivacyConfig] = None
Hongkun Yu's avatar
Hongkun Yu committed
41
42


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Options for the `proportional` choice of `TaskSamplingConfig`.

  Attributes:
    alpha: float, defaults to 1.0. Its exact meaning is defined by the sampler
      that consumes this config (not visible in this module).
  """
  alpha: float = 1.0


@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Options for the `annealing` choice of `TaskSamplingConfig`.

  Attributes:
    steps_per_epoch: integer, defaults to 5.
    total_steps: integer, defaults to 20.
  """
  steps_per_epoch: int = 5
  total_steps: int = 20


@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """One-of configuration grouping the available task sampling options.

  Attributes:
    type: string naming the chosen option; defaults to the empty string.
      Presumably one of the field names below -- the selection semantics live
      in `hyperparams.OneOfConfig`.
    uniform: a plain `hyperparams.Config` (no options).
    proportional: a `ProportionalSampleConfig`.
    annealing: an `AnnealingSampleConfig`.
  """
  type: str = ""
  # Sub-config defaults are built per instance through `default_factory`.
  # Using a config *instance* as a dataclass default raises ValueError on
  # Python 3.11+ (unhashable defaults are rejected) and would otherwise share
  # one default object across the class.
  uniform: hyperparams.Config = dataclasses.field(
      default_factory=hyperparams.Config)
  proportional: ProportionalSampleConfig = dataclasses.field(
      default_factory=ProportionalSampleConfig)
  annealing: AnnealingSampleConfig = dataclasses.field(
      default_factory=AnnealingSampleConfig)


@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer configuration for multi-task training.

  Extends `cfg.TrainerConfig` with the two fields below.

  Attributes:
    trainer_type: string, defaults to "interleaving".
    task_sampler: a `TaskSamplingConfig`; the default is built with
      `type="proportional"`.
  """
  trainer_type: str = "interleaving"
  # Built per instance: a config instance used directly as a dataclass default
  # raises ValueError on Python 3.11+ (unhashable defaults are rejected).
  task_sampler: TaskSamplingConfig = dataclasses.field(
      default_factory=lambda: TaskSamplingConfig(type="proportional"))


@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation.

  Attributes:
    task: the `MultiTaskConfig` for the experiment.
    trainer: the `MultiTaskTrainerConfig` for the experiment.
    runtime: the `cfg.RuntimeConfig` for the experiment.
  """
  # Each sub-config default is created per instance via `default_factory`.
  # Config instances used directly as dataclass defaults raise ValueError on
  # Python 3.11+ (unhashable defaults are rejected).
  task: MultiTaskConfig = dataclasses.field(default_factory=MultiTaskConfig)
  trainer: MultiTaskTrainerConfig = dataclasses.field(
      default_factory=MultiTaskTrainerConfig)
  runtime: cfg.RuntimeConfig = dataclasses.field(
      default_factory=cfg.RuntimeConfig)


Hongkun Yu's avatar
Hongkun Yu committed
76
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Extends `cfg.ExperimentConfig` with a tuple of evaluation task entries.

  Attributes:
    eval_tasks: individual evaluation tasks, as a tuple of `TaskRoutine`;
      defaults to the empty tuple.
  """
  eval_tasks: Tuple[TaskRoutine, ...] = ()