Commit 57c08e2f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Make function argument names consistent in core.base_task.Task

PiperOrigin-RevId: 316513485
parent ee3cc115
......@@ -114,18 +114,18 @@ class Task(tf.Module):
"""
pass
def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
features: optional feature/labels tensors.
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del model_outputs, features
del model_outputs, labels
if aux_losses is None:
losses = [tf.constant(0.0, dtype=tf.float32)]
......@@ -139,29 +139,29 @@ class Task(tf.Module):
del training
return []
def process_metrics(self, metrics, labels, model_outputs):
  """Updates each metric with the labels and model outputs.

  Invoked from the custom training loop after a step; the compile/fit
  path uses `process_compiled_metrics` instead.

  Args:
    metrics: a nested structure of metrics objects, as returned by
      `self.build_metrics`.
    labels: a tensor or a nested structure of tensors.
    model_outputs: a tensor or a nested structure of tensors, e.g. the
      output of the keras model built by `self.build_model`.
  """
  for m in metrics:
    m.update_state(labels, model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
  """Updates the compiled metrics; used with the Keras compile/fit API.

  Args:
    compiled_metrics: the compiled metrics (model.compiled_metrics).
    labels: a tensor or a nested structure of tensors.
    model_outputs: a tensor or a nested structure of tensors, e.g. the
      output of the keras model built by `self.build_model`.
  """
  compiled_metrics.update_state(labels, model_outputs)
def train_step(self,
inputs,
......@@ -187,7 +187,7 @@ class Task(tf.Module):
outputs = model(features, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
labels=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
......@@ -231,7 +231,7 @@ class Task(tf.Module):
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
......
......@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
return bert.instantiate_from_cfg(self.task_config.network)
def build_losses(self,
features,
labels,
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics])
lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=features['masked_lm_ids'],
labels=labels['masked_lm_ids'],
predictions=lm_output,
weights=features['masked_lm_weights'])
weights=labels['masked_lm_weights'])
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in features:
if 'next_sentence_labels' in labels:
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16': # b/158514794: bf16 is not stable.
policy = tf.float32
predictions = tf.keras.layers.Activation(
tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
sentence_labels = features['next_sentence_labels']
sentence_labels = labels['next_sentence_labels']
sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels,
predictions=predictions)
......@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
return metrics
def process_metrics(self, metrics, labels, model_outputs):
  """Updates masked-LM and next-sentence accuracy metrics, looked up by name.

  Metrics not present in `metrics` are silently skipped.
  """
  by_name = {m.name: m for m in metrics}
  mlm_acc = by_name.get('masked_lm_accuracy')
  if mlm_acc is not None:
    mlm_acc.update_state(labels['masked_lm_ids'],
                         model_outputs['lm_output'],
                         labels['masked_lm_weights'])
  nsp_acc = by_name.get('next_sentence_accuracy')
  if nsp_acc is not None:
    nsp_acc.update_state(labels['next_sentence_labels'],
                         model_outputs['next_sentence'])
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
......@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=inputs,
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
......@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
"""
outputs = self.inference_step(inputs, model)
loss = self.build_losses(
features=inputs,
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
......
......@@ -79,8 +79,7 @@ class SentencePredictionTask(base_task.Task):
else:
return bert.instantiate_from_cfg(self.task_config.network)
def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
labels = features
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels,
predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
......@@ -118,12 +117,12 @@ class SentencePredictionTask(base_task.Task):
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
  """Feeds the sentence-prediction head output and labels to every metric."""
  predictions = model_outputs['sentence_prediction']
  for m in metrics:
    m.update_state(labels, predictions)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
  """Updates the compiled metrics with the sentence-prediction head output."""
  compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment