Commit 8e00ce42 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405693354
parent 0d3a13bf
...@@ -166,7 +166,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta): ...@@ -166,7 +166,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
del training del training
return [] return []
def process_metrics(self, metrics, labels, model_outputs): def process_metrics(self, metrics, labels, model_outputs, **kwargs):
"""Process and update metrics. """Process and update metrics.
Called when using custom training loop API. Called when using custom training loop API.
...@@ -177,6 +177,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta): ...@@ -177,6 +177,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
labels: a tensor or a nested structure of tensors. labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example, model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model. output of the keras model built by self.build_model.
**kwargs: other args.
""" """
for metric in metrics: for metric in metrics:
metric.update_state(labels, model_outputs) metric.update_state(labels, model_outputs)
......
...@@ -99,7 +99,8 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -99,7 +99,8 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def joint_train_step(self, task_inputs, def joint_train_step(self, task_inputs,
multi_task_model: base_model.MultiTaskBaseModel, multi_task_model: base_model.MultiTaskBaseModel,
optimizer: tf.keras.optimizers.Optimizer, task_metrics): optimizer: tf.keras.optimizers.Optimizer, task_metrics,
**kwargs):
"""The joint train step. """The joint train step.
Args: Args:
...@@ -107,6 +108,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -107,6 +108,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
multi_task_model: a MultiTaskBaseModel instance. multi_task_model: a MultiTaskBaseModel instance.
optimizer: a tf.optimizers.Optimizer. optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics. task_metrics: a dictionary of task names and per-task metrics.
**kwargs: other arguments to pass through.
Returns: Returns:
A dictionary of losses, inculding per-task losses and their weighted sum. A dictionary of losses, inculding per-task losses and their weighted sum.
...@@ -129,7 +131,8 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta): ...@@ -129,7 +131,8 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_weight = self.task_weight(name) task_weight = self.task_weight(name)
total_loss += task_weight * task_loss total_loss += task_weight * task_loss
losses[name] = task_loss losses[name] = task_loss
self.tasks[name].process_metrics(task_metrics[name], labels, outputs) self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
**kwargs)
# Scales loss as the default gradients allreduce performs sum inside # Scales loss as the default gradients allreduce performs sum inside
# the optimizer. # the optimizer.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment