"src/vscode:/vscode.git/clone" did not exist on "63a0c9e5f7a56d49dc142e643b7237fc9082ff59"
Commit d495e481 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Exclude the label_id used for padding in the label space for tagging task.

PiperOrigin-RevId: 318149287
parent aca51294
......@@ -35,16 +35,37 @@ class TaggingConfig(cfg.TaskConfig):
hub_module_url: str = ''
model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
# The number of real labels. Note that a word may be tokenized into
# multiple word_pieces tokens, and we asssume the real label id (non-negative)
# is assigned to the first token of the word, and a negative label id is
# assigned to the remaining tokens. The negative label id will not contribute
# to loss and metrics.
num_classes: int = 0
# The ignored label id will not contribute to loss.
# A word may be tokenized into multiple word_pieces tokens, and we usually
# assign the real label id for the first token of the word, and
# `ignore_label_id` for the remaining tokens.
ignore_label_id: int = 0
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
def _masked_labels_and_weights(y_true):
"""Masks negative values from token level labels.
Args:
y_true: Token labels, typically shape (batch_size, seq_len), where tokens
with negative labels should be ignored during loss/accuracy calculation.
Returns:
(masked_y_true, masked_weights) where `masked_y_true` is the input
with each negative label replaced with zero and `masked_weights` is 0.0
where negative labels were replaced and 1.0 for original labels.
"""
# Ignore the classes of tokens with negative values.
mask = tf.greater_equal(y_true, 0)
# Replace negative labels, which are out of bounds for some loss functions,
# with zero.
masked_y_true = tf.where(mask, y_true, 0)
return masked_y_true, tf.cast(mask, tf.float32)
@base_task.register_task_cls(TaggingConfig)
class TaggingTask(base_task.Task):
"""Task object for tagging (e.g., NER or POS)."""
......@@ -79,14 +100,11 @@ class TaggingTask(base_task.Task):
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
model_outputs = tf.cast(model_outputs, tf.float32)
masked_labels, masked_weights = _masked_labels_and_weights(labels)
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
# `ignore_label_id` will not contribute to loss.
label_weights = tf.cast(
tf.not_equal(labels, self.task_config.ignore_label_id),
dtype=tf.float32)
numerator_loss = tf.reduce_sum(loss * label_weights)
denominator_loss = tf.reduce_sum(label_weights)
masked_labels, model_outputs, from_logits=True)
numerator_loss = tf.reduce_sum(loss * masked_weights)
denominator_loss = tf.reduce_sum(masked_weights)
loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
return loss
......@@ -100,7 +118,13 @@ class TaggingTask(base_task.Task):
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = tf.ones((1, params.seq_length), dtype=tf.int32)
# Include some label_id as -1, which will be ignored in loss/metrics.
y = tf.random.uniform(
shape=(1, params.seq_length),
minval=-1,
maxval=self.task_config.num_classes,
dtype=tf.dtypes.int32)
return (x, y)
dataset = tf.data.Dataset.range(1)
......@@ -118,19 +142,13 @@ class TaggingTask(base_task.Task):
return [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
def process_metrics(self, metrics, labels, model_outputs):
# `ignore_label_id` will not contribute to metrics.
sample_weight = tf.cast(
tf.not_equal(labels, self.task_config.ignore_label_id),
dtype=tf.float32)
masked_labels, masked_weights = _masked_labels_and_weights(labels)
for metric in metrics:
metric.update_state(labels, model_outputs, sample_weight)
metric.update_state(masked_labels, model_outputs, masked_weights)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
# `ignore_label_id` will not contribute to metrics.
sample_weight = tf.cast(
tf.not_equal(labels, self.task_config.ignore_label_id),
dtype=tf.float32)
compiled_metrics.update_state(labels, model_outputs, sample_weight)
masked_labels, masked_weights = _masked_labels_and_weights(labels)
compiled_metrics.update_state(masked_labels, model_outputs, masked_weights)
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment