Commit 2dfd1e63 authored by Maxim Neumann, committed by A. Unique TensorFlower

Support using sample weights for finetuning regression.

Also replace the regression loss with a Keras loss object.

PiperOrigin-RevId: 316062800
parent 5a4c4e18
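The change is opt-in via the input meta data: when it sets has_sample_weights, each input TF Example is expected to carry a scalar float 'weight' feature next to the usual BERT features. A minimal sketch of such a record, assuming the feature names from the diff below ('input_mask' is not visible in the diff and is assumed from the standard BERT pipeline; all values are made up):

import tensorflow as tf

# A toy record carrying the per-example 'weight' feature that
# create_classifier_dataset parses when include_sample_weights=True.
example = tf.train.Example(features=tf.train.Features(feature={
    'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=[101, 7592, 102, 0])),
    'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 1, 1, 0])),  # assumed feature
    'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 0, 0, 0])),
    'label_ids': tf.train.Feature(float_list=tf.train.FloatList(value=[0.8])),  # regression target
    'weight': tf.train.Feature(float_list=tf.train.FloatList(value=[2.0])),  # sample weight
}))
serialized = example.SerializeToString()  # what the TFRecord file stores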
@@ -155,7 +155,8 @@ def create_classifier_dataset(file_path,
                               batch_size,
                               is_training=True,
                               input_pipeline_context=None,
-                              label_type=tf.int64):
+                              label_type=tf.int64,
+                              include_sample_weights=False):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -163,6 +164,8 @@ def create_classifier_dataset(file_path,
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'label_ids': tf.io.FixedLenFeature([], label_type),
   }
+  if include_sample_weights:
+    name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
   dataset = single_file_dataset(file_path, name_to_features)
 
   # The dataset is always sharded by number of hosts.
@@ -178,6 +181,9 @@ def create_classifier_dataset(file_path,
         'input_type_ids': record['segment_ids']
     }
     y = record['label_ids']
+    if include_sample_weights:
+      w = record['weight']
+      return (x, y, w)
     return (x, y)
 
   if is_training:
...
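With include_sample_weights=True the dataset now yields (x, y, w) 3-tuples instead of (x, y) pairs. Keras treats a third tuple element coming out of a tf.data.Dataset as per-example sample weights, so the triple can be consumed directly by standard training code. A self-contained sketch with toy data (not from this commit):

import tensorflow as tf

# Toy stand-ins for the parsed features: Keras interprets the third element
# of each dataset tuple as per-example sample weights.
xs = tf.random.normal([8, 4])
ys = tf.random.normal([8, 1])
ws = tf.constant([1.0, 1.0, 0.5, 0.5, 2.0, 2.0, 1.0, 1.0])
dataset = tf.data.Dataset.from_tensor_slices((xs, ys, ws)).batch(4)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError())
model.fit(dataset, epochs=1)  # per-example losses are scaled by `ws`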
@@ -86,23 +86,12 @@ def get_loss_fn(num_classes):
   return classification_loss_fn
 
 
-def get_regression_loss_fn():
-  """Gets the regression loss function."""
-
-  def regression_loss_fn(labels, logits):
-    """Regression loss."""
-    labels = tf.cast(labels, dtype=tf.float32)
-    per_example_loss = tf.math.squared_difference(labels, logits)
-    return tf.reduce_mean(per_example_loss)
-
-  return regression_loss_fn
-
-
 def get_dataset_fn(input_file_pattern,
                    max_seq_length,
                    global_batch_size,
                    is_training,
-                   label_type=tf.int64):
+                   label_type=tf.int64,
+                   include_sample_weights=False):
   """Gets a closure to create a dataset."""
 
   def _dataset_fn(ctx=None):
@@ -115,7 +104,8 @@ def get_dataset_fn(input_file_pattern,
         batch_size,
         is_training=is_training,
         input_pipeline_context=ctx,
-        label_type=label_type)
+        label_type=label_type,
+        include_sample_weights=include_sample_weights)
     return dataset
 
   return _dataset_fn
@@ -160,8 +150,12 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
 
-  loss_fn = (
-      get_regression_loss_fn() if is_regression else get_loss_fn(num_classes))
+  # tf.keras.losses objects accept an optional sample_weight argument (e.g.
+  # coming from the dataset) to compute a weighted loss, as used for the
+  # regression tasks. The classification tasks, which use the custom
+  # get_loss_fn, don't accept sample weights.
+  loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
+             else get_loss_fn(num_classes))
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
@@ -416,6 +410,7 @@ def custom_main(custom_callbacks=None):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
   label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
+  include_sample_weights = input_meta_data.get('has_sample_weights', False)
 
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
@@ -436,7 +431,8 @@ def custom_main(custom_callbacks=None):
         input_meta_data['max_seq_length'],
         FLAGS.eval_batch_size,
         is_training=False,
-        label_type=label_type)
+        label_type=label_type,
+        include_sample_weights=include_sample_weights)
 
   if FLAGS.mode == 'predict':
     with strategy.scope():
@@ -470,7 +466,8 @@ def custom_main(custom_callbacks=None):
         input_meta_data['max_seq_length'],
         FLAGS.train_batch_size,
         is_training=True,
-        label_type=label_type)
+        label_type=label_type,
+        include_sample_weights=include_sample_weights)
     run_bert(
         strategy,
         input_meta_data,
...
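The swap from the custom regression_loss_fn to tf.keras.losses.MeanSquaredError is what makes the weights usable: the old function reduced a plain mean and took no weight argument, while Keras loss objects accept an optional sample_weight. A quick illustration with made-up values:

import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()
labels = tf.constant([[1.0], [2.0], [3.0]])
preds = tf.constant([[1.5], [2.0], [2.0]])

# Unweighted: mean of the per-example squared errors [0.25, 0.0, 1.0] = 0.4167.
print(mse(labels, preds).numpy())
# Weighted: each squared error is scaled by its weight before the default
# sum-over-batch-size reduction: (0.0 + 0.0 + 2.0) / 3 = 0.6667.
print(mse(labels, preds, sample_weight=tf.constant([0.0, 1.0, 2.0])).numpy())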