Commit 57c08e2f authored by A. Unique TensorFlower

Make function argument names consistent in core.base_task.Task

PiperOrigin-RevId: 316513485
parent ee3cc115
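
Note for subclass authors: Task.train_step and Task.validation_step now pass the loss inputs by keyword (labels=...), so overrides of build_losses must adopt the new argument name. Below is a minimal sketch of an updated override; the subclass and the loss choice are illustrative only, not part of this change:

import tensorflow as tf
from official.core import base_task  # module named in the commit title

class MyClassificationTask(base_task.Task):

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    # `model_outputs` are assumed to be raw logits here.
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True))
    if aux_losses:
      loss = loss + tf.add_n(aux_losses)
    return loss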
@@ -114,18 +114,18 @@ class Task(tf.Module):
     """
     pass
 
-  def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     """Standard interface to compute losses.
 
     Args:
-      features: optional feature/labels tensors.
+      labels: optional label tensors.
       model_outputs: a nested structure of output tensors.
       aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
 
     Returns:
       The total loss tensor.
     """
-    del model_outputs, features
+    del model_outputs, labels
     if aux_losses is None:
       losses = [tf.constant(0.0, dtype=tf.float32)]
@@ -139,29 +139,29 @@ class Task(tf.Module):
     del training
     return []
 
-  def process_metrics(self, metrics, labels, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     """Process and update metrics. Called when using custom training loop API.
 
     Args:
       metrics: a nested structure of metrics objects.
         The return of function self.build_metrics.
       labels: a tensor or a nested structure of tensors.
-      outputs: a tensor or a nested structure of tensors.
+      model_outputs: a tensor or a nested structure of tensors.
         For example, output of the keras model built by self.build_model.
     """
     for metric in metrics:
-      metric.update_state(labels, outputs)
+      metric.update_state(labels, model_outputs)
 
-  def process_compiled_metrics(self, compiled_metrics, labels, outputs):
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     """Process and update compiled_metrics. Called when using compile/fit API.
 
     Args:
       compiled_metrics: the compiled metrics (model.compiled_metrics).
       labels: a tensor or a nested structure of tensors.
-      outputs: a tensor or a nested structure of tensors.
+      model_outputs: a tensor or a nested structure of tensors.
         For example, output of the keras model built by self.build_model.
     """
-    compiled_metrics.update_state(labels, outputs)
+    compiled_metrics.update_state(labels, model_outputs)
 
   def train_step(self,
                  inputs,
@@ -187,7 +187,7 @@ class Task(tf.Module):
       outputs = model(features, training=True)
       # Computes per-replica loss.
       loss = self.build_losses(
-          features=labels, model_outputs=outputs, aux_losses=model.losses)
+          labels=labels, model_outputs=outputs, aux_losses=model.losses)
       # Scales loss as the default gradients allreduce performs sum inside the
       # optimizer.
       scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
@@ -231,7 +231,7 @@ class Task(tf.Module):
       features, labels = inputs, inputs
     outputs = self.inference_step(features, model)
     loss = self.build_losses(
-        features=labels, model_outputs=outputs, aux_losses=model.losses)
+        labels=labels, model_outputs=outputs, aux_losses=model.losses)
     logs = {self.loss: loss}
     if metrics:
       self.process_metrics(metrics, labels, outputs)
...
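
Because the base-class call sites above now pass labels= by keyword, an out-of-tree subclass that kept the old parameter name fails immediately with a TypeError instead of silently misbinding arguments. A standalone illustration with a hypothetical stale subclass (not from this commit):

class StaleTask:
  # Override still using the pre-rename parameter name.
  def build_losses(self, features, model_outputs, aux_losses=None):
    del features, model_outputs, aux_losses
    return None

try:
  StaleTask().build_losses(labels=None, model_outputs=None)
except TypeError as err:
  print(err)  # ... got an unexpected keyword argument 'labels'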
@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
     return bert.instantiate_from_cfg(self.task_config.network)
 
   def build_losses(self,
-                   features,
+                   labels,
                    model_outputs,
                    metrics,
                    aux_losses=None) -> tf.Tensor:
     metrics = dict([(metric.name, metric) for metric in metrics])
     lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
     mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
-        labels=features['masked_lm_ids'],
+        labels=labels['masked_lm_ids'],
         predictions=lm_output,
-        weights=features['masked_lm_weights'])
+        weights=labels['masked_lm_weights'])
     metrics['lm_example_loss'].update_state(mlm_loss)
-    if 'next_sentence_labels' in features:
+    if 'next_sentence_labels' in labels:
       policy = tf.keras.mixed_precision.experimental.global_policy()
       if policy.name == 'mixed_bfloat16':  # b/158514794: bf16 is not stable.
         policy = tf.float32
       predictions = tf.keras.layers.Activation(
           tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
-      sentence_labels = features['next_sentence_labels']
+      sentence_labels = labels['next_sentence_labels']
       sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
           labels=sentence_labels,
           predictions=predictions)
@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
       metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
     return metrics
 
-  def process_metrics(self, metrics, inputs, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     metrics = dict([(metric.name, metric) for metric in metrics])
     if 'masked_lm_accuracy' in metrics:
-      metrics['masked_lm_accuracy'].update_state(inputs['masked_lm_ids'],
-                                                 outputs['lm_output'],
-                                                 inputs['masked_lm_weights'])
+      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
+                                                 model_outputs['lm_output'],
+                                                 labels['masked_lm_weights'])
     if 'next_sentence_accuracy' in metrics:
       metrics['next_sentence_accuracy'].update_state(
-          inputs['next_sentence_labels'], outputs['next_sentence'])
+          labels['next_sentence_labels'], model_outputs['next_sentence'])
 
   def train_step(self, inputs, model: tf.keras.Model,
                  optimizer: tf.keras.optimizers.Optimizer, metrics):
@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
       outputs = model(inputs, training=True)
       # Computes per-replica loss.
       loss = self.build_losses(
-          features=inputs,
+          labels=inputs,
           model_outputs=outputs,
           metrics=metrics,
           aux_losses=model.losses)
@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
     """
     outputs = self.inference_step(inputs, model)
     loss = self.build_losses(
-        features=inputs,
+        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
...
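
In MaskedLMTask the renamed labels argument is the whole feature dictionary produced by the pretraining input pipeline, not a single tensor. A toy example of the structure implied by the keys used above; shapes and values are invented for illustration:

import tensorflow as tf

# Batch of 2 sequences, 3 masked positions each.
labels = {
    'masked_lm_ids': tf.constant([[5, 2, 0], [3, 7, 0]]),            # target token ids
    'masked_lm_weights': tf.constant([[1., 1., 0.], [1., 1., 0.]]),  # 0 marks padding
    'next_sentence_labels': tf.constant([[1], [0]]),
}
# With this commit the dict is passed as, e.g.:
#   task.build_losses(labels=labels, model_outputs=outputs,
#                     metrics=metrics, aux_losses=model.losses)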
@@ -79,8 +79,7 @@ class SentencePredictionTask(base_task.Task):
     else:
       return bert.instantiate_from_cfg(self.task_config.network)
 
-  def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
-    labels = features
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
         labels=labels,
         predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
@@ -118,12 +117,12 @@ class SentencePredictionTask(base_task.Task):
     ]
     return metrics
 
-  def process_metrics(self, metrics, labels, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, outputs['sentence_prediction'])
+      metric.update_state(labels, model_outputs['sentence_prediction'])
 
-  def process_compiled_metrics(self, compiled_metrics, labels, outputs):
-    compiled_metrics.update_state(labels, outputs['sentence_prediction'])
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
 
   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
...
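
For reference, weighted_sparse_categorical_crossentropy_loss is fed log-probabilities (note the tf.nn.log_softmax calls above), not raw logits. A self-contained sketch of that computation for the simple rank-1 case; it mirrors, but is not, the loss_lib implementation:

import tensorflow as tf

def weighted_sparse_ce(labels, log_probs, weights=None):
  # Negative log-probability of each example's true class.
  per_example = -tf.gather(log_probs, labels, batch_dims=1)
  if weights is None:
    return tf.reduce_mean(per_example)
  weights = tf.cast(weights, per_example.dtype)
  return tf.math.divide_no_nan(
      tf.reduce_sum(per_example * weights), tf.reduce_sum(weights))

logits = tf.random.normal([4, 3])  # batch of 4, 3 classes
labels = tf.constant([0, 2, 1, 2])
loss = weighted_sparse_ce(labels, tf.nn.log_softmax(logits, axis=-1))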