Commit e7bbe6a3 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352711257
parent bb35d42e
......@@ -18,27 +18,26 @@ from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config
from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Routine parameters for a single task in a multi-task setup.

  Attributes:
    task_name: String name keying this task in the multi-task mappings.
    task_config: The task's own TaskConfig; None until provided.
    mixing_steps: Per-task mixing step count (default 1) — presumably the
      number of consecutive steps the task runs when tasks are interleaved;
      confirm against the trainer.
    task_weight: Weight for this task; applied during loss summation in a
      joint backward step, or used to sample among tasks in interleaved
      backward steps. Defaults to 1.0.
    eval_steps: Number of evaluation steps for this task, or None to use the
      default.
  """
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  mixing_steps: int = 1
  eval_steps: Optional[int] = None
  task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Top-level config for a multi-task model and its task routines.

  Attributes:
    init_checkpoint: Path to a checkpoint for initialization; empty string
      means no checkpoint.
    model: The shared model config; None until provided.
    task_routines: A tuple of TaskRoutine, one entry per task.
  """
  init_checkpoint: str = ""
  model: hyperparams.Config = None
  task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class MultiEvalExperimentConfig(base_config.Config):
class MultiEvalExperimentConfig(hyperparams.Config):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
......
......@@ -32,16 +32,16 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_mixing_steps: Optional[Dict[str, int]] = None,
task_weights: Optional[Dict[str, float]] = None,
task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None):
"""MultiTask initialization.
Args:
tasks: a list or a flat dict of Task.
task_mixing_steps: a dict of (task, mixing steps).
task_weights: a dict of (task, loss weight).
task_weights: a dict of (task, task weight), task weight can be applied
directly during loss summation in a joint backward step, or it can be
used to sample task among interleaved backward step.
task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object.
"""
......@@ -62,31 +62,24 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self._task_mixing_steps = task_mixing_steps or {}
self._task_mixing_steps = dict([
(name, self._task_mixing_steps.get(name, 1)) for name in self.tasks
])
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, None)) for name in self.tasks
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
])
@classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
  """Instantiates a MultiTask from a MultiTaskConfig.

  Args:
    config: a MultiTaskConfig whose `task_routines` describe each task.
    logging_dir: optional directory forwarded to each task factory call.

  Returns:
    A `cls` instance holding one Task per routine together with the
    per-task eval-step, mixing-step and weight mappings.
  """
  routines = list(config.task_routines)
  # Build each Task via the factory; iteration order follows task_routines.
  tasks = {
      r.task_name: task_factory.get_task(r.task_config,
                                         logging_dir=logging_dir)
      for r in routines
  }
  return cls(
      tasks,
      task_mixing_steps={r.task_name: r.mixing_steps for r in routines},
      task_eval_steps={r.task_name: r.eval_steps for r in routines},
      task_weights={r.task_name: r.task_weight for r in routines})
......@@ -97,12 +90,13 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def task_eval_steps(self, task_name):
  """Returns the configured eval step count for `task_name` (may be None)."""
  steps_by_task = self._task_eval_steps
  return steps_by_task[task_name]
def task_mixing_steps(self, task_name):
  """Returns the mixing step count configured for `task_name`."""
  mixing_by_task = self._task_mixing_steps
  return mixing_by_task[task_name]
def task_weight(self, task_name):
  """Returns the weight configured for `task_name`."""
  weights_by_task = self._task_weights
  return weights_by_task[task_name]
@property
def task_weights(self):
  """The full mapping from task name to task weight."""
  weights = self._task_weights
  return weights
@classmethod
def create_optimizer(cls,
optimizer_config: OptimizationConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment