Commit b2e422b0 authored by Maxim Neumann's avatar Maxim Neumann Committed by A. Unique TensorFlower
Browse files

Adjust run_classification to support fine-tuning regression tasks.

PiperOrigin-RevId: 314607393
parent 4bb13e61
...@@ -154,13 +154,14 @@ def create_classifier_dataset(file_path, ...@@ -154,13 +154,14 @@ def create_classifier_dataset(file_path,
seq_length, seq_length,
batch_size, batch_size,
is_training=True, is_training=True,
input_pipeline_context=None): input_pipeline_context=None,
label_type=tf.int64):
"""Creates input dataset from (tf)records files for train/eval.""" """Creates input dataset from (tf)records files for train/eval."""
name_to_features = { name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64), 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64), 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64), 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64), 'label_ids': tf.io.FixedLenFeature([], label_type),
} }
dataset = single_file_dataset(file_path, name_to_features) dataset = single_file_dataset(file_path, name_to_features)
......
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""BERT classification finetuning runner in TF 2.x.""" """BERT classification or regression finetuning runner in TF 2.x."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import json import json
import math import math
import os import os
...@@ -60,6 +61,8 @@ common_flags.define_common_bert_flags() ...@@ -60,6 +61,8 @@ common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
def get_loss_fn(num_classes): def get_loss_fn(num_classes):
"""Gets the classification loss function.""" """Gets the classification loss function."""
...@@ -77,8 +80,20 @@ def get_loss_fn(num_classes): ...@@ -77,8 +80,20 @@ def get_loss_fn(num_classes):
return classification_loss_fn return classification_loss_fn
def get_regression_loss_fn():
"""Gets the regression loss function."""
def regression_loss_fn(labels, logits):
"""Regression loss."""
labels = tf.cast(labels, dtype=tf.float32)
per_example_loss = tf.math.squared_difference(labels, logits)
return tf.reduce_mean(per_example_loss)
return regression_loss_fn
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training): is_training, label_type=tf.int64):
"""Gets a closure to create a dataset.""" """Gets a closure to create a dataset."""
def _dataset_fn(ctx=None): def _dataset_fn(ctx=None):
...@@ -90,7 +105,8 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, ...@@ -90,7 +105,8 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
max_seq_length, max_seq_length,
batch_size, batch_size,
is_training=is_training, is_training=is_training,
input_pipeline_context=ctx) input_pipeline_context=ctx,
label_type=label_type)
return dataset return dataset
return _dataset_fn return _dataset_fn
...@@ -113,7 +129,8 @@ def run_bert_classifier(strategy, ...@@ -113,7 +129,8 @@ def run_bert_classifier(strategy,
custom_callbacks=None): custom_callbacks=None):
"""Run BERT classifier training using low-level API.""" """Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels'] num_classes = input_meta_data.get('num_labels', 1)
is_regression = num_classes == 1
def _get_classifier_model(): def _get_classifier_model():
"""Gets a classifier model.""" """Gets a classifier model."""
...@@ -134,13 +151,17 @@ def run_bert_classifier(strategy, ...@@ -134,13 +151,17 @@ def run_bert_classifier(strategy,
use_graph_rewrite=common_flags.use_graph_rewrite()) use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model return classifier_model, core_model
loss_fn = get_loss_fn(num_classes) loss_fn = (get_regression_loss_fn() if is_regression
else get_loss_fn(num_classes))
# Defines evaluation metrics function, which will create metrics in the # Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope. # correct device and strategy scope.
def metric_fn(): if is_regression:
return tf.keras.metrics.SparseCategoricalAccuracy( metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
'accuracy', dtype=tf.float32) 'mean_squared_error', dtype=tf.float32)
else:
metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
'accuracy', dtype=tf.float32)
# Start training using Keras compile/fit API. # Start training using Keras compile/fit API.
logging.info('Training using TF 2.x Keras compile/fit API with ' logging.info('Training using TF 2.x Keras compile/fit API with '
...@@ -310,7 +331,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config, ...@@ -310,7 +331,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
# Export uses float32 for now, even if training uses mixed precision. # Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32') tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model( classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'])[0] bert_config, input_meta_data.get('num_labels', 1))[0]
model_saving_utils.export_bert_model( model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=model_dir) model_export_path, model=classifier_model, checkpoint_dir=model_dir)
...@@ -371,7 +392,7 @@ def run_bert(strategy, ...@@ -371,7 +392,7 @@ def run_bert(strategy,
def custom_main(custom_callbacks=None): def custom_main(custom_callbacks=None):
"""Run classification. """Run classification or regression.
Args: Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop. custom_callbacks: list of tf.keras.Callbacks passed to training loop.
...@@ -380,6 +401,7 @@ def custom_main(custom_callbacks=None): ...@@ -380,6 +401,7 @@ def custom_main(custom_callbacks=None):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
if not FLAGS.model_dir: if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/' FLAGS.model_dir = '/tmp/bert20/'
...@@ -399,7 +421,8 @@ def custom_main(custom_callbacks=None): ...@@ -399,7 +421,8 @@ def custom_main(custom_callbacks=None):
FLAGS.eval_data_path, FLAGS.eval_data_path,
input_meta_data['max_seq_length'], input_meta_data['max_seq_length'],
FLAGS.eval_batch_size, FLAGS.eval_batch_size,
is_training=False) is_training=False,
label_type=label_type)
if FLAGS.mode == 'predict': if FLAGS.mode == 'predict':
with strategy.scope(): with strategy.scope():
...@@ -432,7 +455,8 @@ def custom_main(custom_callbacks=None): ...@@ -432,7 +455,8 @@ def custom_main(custom_callbacks=None):
FLAGS.train_data_path, FLAGS.train_data_path,
input_meta_data['max_seq_length'], input_meta_data['max_seq_length'],
FLAGS.train_batch_size, FLAGS.train_batch_size,
is_training=True) is_training=True,
label_type=label_type)
run_bert( run_bert(
strategy, strategy,
input_meta_data, input_meta_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment