Commit 2b68aa95 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368535232
parent 8a16208b
......@@ -21,6 +21,7 @@ from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs
OptimizationConfig = optimization.OptimizationConfig
......@@ -79,9 +80,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight
return cls(
tasks,
task_eval_steps=task_eval_steps,
task_weights=task_weights)
tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)
@property
def tasks(self):
......@@ -104,15 +103,17 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
return base_task.Task.create_optimizer(
optimizer_config=optimizer_config, runtime_config=runtime_config)
def joint_train_step(self, task_inputs, multi_task_model, optimizer,
task_metrics):
def joint_train_step(self, task_inputs,
multi_task_model: base_model.MultiTaskBaseModel,
optimizer: tf.keras.optimizers.Optimizer, task_metrics):
"""The joint train step.
Args:
task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskModel instance.
multi_task_model: a MultiTaskBaseModel instance.
optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics.
Returns:
A dictionary of losses, inculding per-task losses and their weighted sum.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment