"examples/tensorflow/rgcn/model.py" did not exist on "b98dc92c59911cf72e27f87df05c74ef1894bc43"
Commit 2b68aa95 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368535232
parent 8a16208b
...@@ -21,6 +21,7 @@ from official.core import base_task ...@@ -21,6 +21,7 @@ from official.core import base_task
from official.core import config_definitions from official.core import config_definitions
from official.core import task_factory from official.core import task_factory
from official.modeling import optimization from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs from official.modeling.multitask import configs
OptimizationConfig = optimization.OptimizationConfig OptimizationConfig = optimization.OptimizationConfig
...@@ -79,9 +80,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -79,9 +80,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_eval_steps[task_name] = task_routine.eval_steps task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight task_weights[task_name] = task_routine.task_weight
return cls( return cls(
tasks, tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)
task_eval_steps=task_eval_steps,
task_weights=task_weights)
@property @property
def tasks(self): def tasks(self):
...@@ -104,15 +103,17 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -104,15 +103,17 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
return base_task.Task.create_optimizer( return base_task.Task.create_optimizer(
optimizer_config=optimizer_config, runtime_config=runtime_config) optimizer_config=optimizer_config, runtime_config=runtime_config)
def joint_train_step(self, task_inputs, multi_task_model, optimizer, def joint_train_step(self, task_inputs,
task_metrics): multi_task_model: base_model.MultiTaskBaseModel,
optimizer: tf.keras.optimizers.Optimizer, task_metrics):
"""The joint train step. """The joint train step.
Args: Args:
task_inputs: a dictionary of task names and per-task features. task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskModel instance. multi_task_model: a MultiTaskBaseModel instance.
optimizer: a tf.optimizers.Optimizer. optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics. task_metrics: a dictionary of task names and per-task metrics.
Returns: Returns:
A dictionary of losses, inculding per-task losses and their weighted sum. A dictionary of losses, inculding per-task losses and their weighted sum.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment