Commit e30aa7d8 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 318935091
parent cc5a7980
...@@ -125,25 +125,26 @@ class SentencePredictionTask(base_task.Task): ...@@ -125,25 +125,26 @@ class SentencePredictionTask(base_task.Task):
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
loss = self.build_losses( loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses) labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if self.metric_type == 'matthews_corrcoef': if self.metric_type == 'matthews_corrcoef':
return { logs.update({
self.loss:
loss,
'sentence_prediction': 'sentence_prediction':
tf.expand_dims( tf.expand_dims(
tf.math.argmax(outputs['sentence_prediction'], axis=1), tf.math.argmax(outputs['sentence_prediction'], axis=1),
axis=0), axis=0),
'labels': 'labels':
labels, labels,
} })
if self.metric_type == 'pearson_spearman_corr': if self.metric_type == 'pearson_spearman_corr':
return { logs.update({
self.loss: loss,
'sentence_prediction': outputs['sentence_prediction'], 'sentence_prediction': outputs['sentence_prediction'],
'labels': labels, 'labels': labels,
} })
return logs
def aggregate_logs(self, state=None, step_outputs=None): def aggregate_logs(self, state=None, step_outputs=None):
if self.metric_type == 'accuracy':
return None
if state is None: if state is None:
state = {'sentence_prediction': [], 'labels': []} state = {'sentence_prediction': [], 'labels': []}
state['sentence_prediction'].append( state['sentence_prediction'].append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment