Commit b2e422b0 authored by Maxim Neumann's avatar Maxim Neumann Committed by A. Unique TensorFlower
Browse files

Adjust run_classification to support fine-tuning regression tasks.

PiperOrigin-RevId: 314607393
parent 4bb13e61
......@@ -154,13 +154,14 @@ def create_classifier_dataset(file_path,
seq_length,
batch_size,
is_training=True,
input_pipeline_context=None):
input_pipeline_context=None,
label_type=tf.int64):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'label_ids': tf.io.FixedLenFeature([], label_type),
}
dataset = single_file_dataset(file_path, name_to_features)
......
......@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in TF 2.x."""
"""BERT classification or regression finetuning runner in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
......@@ -60,6 +61,8 @@ common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
def get_loss_fn(num_classes):
"""Gets the classification loss function."""
......@@ -77,8 +80,20 @@ def get_loss_fn(num_classes):
return classification_loss_fn
def get_regression_loss_fn():
"""Gets the regression loss function."""
def regression_loss_fn(labels, logits):
"""Regression loss."""
labels = tf.cast(labels, dtype=tf.float32)
per_example_loss = tf.math.squared_difference(labels, logits)
return tf.reduce_mean(per_example_loss)
return regression_loss_fn
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
is_training, label_type=tf.int64):
"""Gets a closure to create a dataset."""
def _dataset_fn(ctx=None):
......@@ -90,7 +105,8 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
input_pipeline_context=ctx,
label_type=label_type)
return dataset
return _dataset_fn
......@@ -113,7 +129,8 @@ def run_bert_classifier(strategy,
custom_callbacks=None):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
num_classes = input_meta_data.get('num_labels', 1)
is_regression = num_classes == 1
def _get_classifier_model():
"""Gets a classifier model."""
......@@ -134,13 +151,17 @@ def run_bert_classifier(strategy,
use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model
loss_fn = get_loss_fn(num_classes)
loss_fn = (get_regression_loss_fn() if is_regression
else get_loss_fn(num_classes))
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy(
'accuracy', dtype=tf.float32)
if is_regression:
metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
'mean_squared_error', dtype=tf.float32)
else:
metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
'accuracy', dtype=tf.float32)
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.x Keras compile/fit API with '
......@@ -310,7 +331,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'])[0]
bert_config, input_meta_data.get('num_labels', 1))[0]
model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=model_dir)
......@@ -371,7 +392,7 @@ def run_bert(strategy,
def custom_main(custom_callbacks=None):
"""Run classification.
"""Run classification or regression.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
......@@ -380,6 +401,7 @@ def custom_main(custom_callbacks=None):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
......@@ -399,7 +421,8 @@ def custom_main(custom_callbacks=None):
FLAGS.eval_data_path,
input_meta_data['max_seq_length'],
FLAGS.eval_batch_size,
is_training=False)
is_training=False,
label_type=label_type)
if FLAGS.mode == 'predict':
with strategy.scope():
......@@ -432,7 +455,8 @@ def custom_main(custom_callbacks=None):
FLAGS.train_data_path,
input_meta_data['max_seq_length'],
FLAGS.train_batch_size,
is_training=True)
is_training=True,
label_type=label_type)
run_bert(
strategy,
input_meta_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment