"vscode:/vscode.git/clone" did not exist on "69e5340ccdbb858a3c2d0ed0b307a64150a3da87"
Commit 1bdf9bf4 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Add ops for extracting labels from sentences (get_*_labels() functions). This...

Add ops for extracting labels from sentences (get_*_labels() functions). This CL includes the ops for extracting labels for:
- BERT's next sentence prediction task
- ALBERT's sentence order prediction task

PiperOrigin-RevId: 326087418
parent 52188963
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Module for extracting segments from sentences in documents."""
import tensorflow as tf
# Build a RaggedTensor of uniform random floats shaped like `positions`; used
# below to make per-sentence random decisions.
def _get_random(positions, random_fn):
  """Returns a `RaggedTensor` like `positions` filled with floats in [0, 1)."""
  values = random_fn(
      shape=tf.shape(positions.flat_values),
      minval=0,
      maxval=1,
      dtype=tf.float32)
  return positions.with_flat_values(values)
def _random_int_up_to(maxval, random_fn):
  """Samples, per element, a uniform random integer in [0, maxval).

  Args:
    maxval: an integer `Tensor` of exclusive upper bounds.
    random_fn: an op with the calling convention of `tf.random.uniform`, used
      to draw the underlying float samples.

  Returns:
    An integer `Tensor` shaped like `maxval` with values in [0, maxval).
  """
  # The integer kernel for `uniform` doesn't support broadcast, so draw floats
  # in [0, maxval) and truncate back down to `maxval`'s integer dtype.
  float_bounds = tf.cast(maxval, tf.float32)
  samples = random_fn(
      shape=tf.shape(maxval),
      minval=tf.zeros_like(float_bounds),
      maxval=float_bounds)
  return tf.cast(samples, dtype=maxval.dtype)
def _random_int_from_range(minval, maxval, random_fn):
  """Samples, per element, a uniform random integer in [minval, maxval)."""
  # Cast both bounds to float because the integer kernel for `uniform` doesn't
  # support broadcast; truncation on the cast back keeps results in range.
  lo = tf.cast(minval, tf.float32)
  hi = tf.cast(maxval, tf.float32)
  draws = random_fn(tf.shape(maxval), minval=lo, maxval=hi)
  return tf.cast(draws, maxval.dtype)
def _sample_from_other_batch(sentences, random_fn):
  """Samples, for every sentence, a (row, sentence) index from another row.

  Args:
    sentences: a `RaggedTensor` of shape [batch, (num_sentences)].
    random_fn: an op with the calling convention of `tf.random.uniform`.

  Returns:
    A `RaggedTensor` like `sentences` whose values are [row, sentence] index
    pairs, with each row index different from the item's own row.
  """
  # Draw a row index in [0, nrows - 1), then bump any draw that lands at or
  # above the item's own row by one; this skips the item's own row while
  # keeping the draw uniform over the remaining rows.
  row_choice = random_fn(
      shape=[tf.size(sentences)],
      minval=0,
      maxval=sentences.nrows() - 1,
      dtype=tf.int64)
  row_choice = row_choice + tf.cast(
      row_choice >= sentences.value_rowids(), tf.int64)
  # Pick a sentence position within each chosen row.
  sentence_choice = _random_int_up_to(
      tf.gather(sentences.row_lengths(), row_choice), random_fn)
  return sentences.with_values(
      tf.stack([row_choice, sentence_choice], axis=1))
def get_sentence_order_labels(sentences,
                              random_threshold=0.5,
                              random_next_threshold=0.5,
                              random_fn=tf.random.uniform):
  """Extract segments and labels for sentence order prediction (SOP) task.

  Extracts the segment and labels for the sentence order prediction task
  defined in "ALBERT: A Lite BERT for Self-Supervised Learning of Language
  Representations" (https://arxiv.org/pdf/1909.11942.pdf)

  Args:
    sentences: a `RaggedTensor` of shape [batch, (num_sentences)] with string
      dtype.
    random_threshold: (optional) A float threshold between 0 and 1, used to
      determine whether to extract a random, out-of-batch sentence or a
      succeeding sentence. Higher value favors succeeding sentence.
    random_next_threshold: (optional) A float threshold between 0 and 1, used
      to determine whether to extract either a random, out-of-batch, or
      succeeding sentence or a preceding sentence. Higher value favors
      preceding sentences.
    random_fn: (optional) An op used to generate random float values.

  Returns:
    a tuple of (preceding_or_random_next, is_succeeding_or_random) where:
      preceding_or_random_next: a `RaggedTensor` of strings with the same
        shape as `sentences` and contains either a preceding, succeeding, or
        random out-of-batch sentence respective to its counterpart in
        `sentences` and dependent on its label in the second tuple element.
      is_succeeding_or_random: a `RaggedTensor` of bool values with the
        same shape as `sentences` and is True if its corresponding sentence in
        the first tuple element is a random or succeeding sentence, False
        otherwise.
  """
  # Create a RaggedTensor in the same shape as sentences ([doc, (sentences)])
  # whose values are index positions.
  positions = tf.ragged.range(sentences.row_lengths())
  # Broadcast each row's length to every position in that row; `+ 0 *
  # positions` is a ragged broadcasting trick.
  row_lengths_broadcasted = tf.expand_dims(positions.row_lengths(),
                                           -1) + 0 * positions
  row_lengths_broadcasted_flat = row_lengths_broadcasted.flat_values
  # Generate candidate indices for all preceding, succeeding and random picks.
  # For every position j in a row, sample a position preceding j, i.e. a
  # position in [0, j-1].
  all_preceding = tf.ragged.map_flat_values(_random_int_up_to, positions,
                                            random_fn)
  # For every position j, sample a position following j, i.e. a position in
  # [j+1, row_length).
  all_succeeding = positions.with_flat_values(
      tf.ragged.map_flat_values(_random_int_from_range,
                                positions.flat_values + 1,
                                row_lengths_broadcasted_flat, random_fn))
  # Convert to [row, position] pairs convenient for `gather_nd`.
  rows_broadcasted = tf.expand_dims(tf.range(sentences.nrows()),
                                    -1) + 0 * positions
  all_preceding_nd = tf.stack([rows_broadcasted, all_preceding], -1)
  all_succeeding_nd = tf.stack([rows_broadcasted, all_succeeding], -1)
  all_random_nd = _sample_from_other_batch(positions, random_fn)
  # There's a few spots where there is no "preceding" or "succeeding" item
  # (e.g. first and last sentences in a document). Mark where these are and we
  # will patch them up to grab a random sentence from another document later.
  all_zeros = tf.zeros_like(positions)
  all_ones = tf.ones_like(positions)
  # First position of each row has no preceding sentence.
  valid_preceding_mask = tf.cast(
      tf.concat([all_zeros[:, :1], all_ones[:, 1:]], -1), tf.bool)
  # Last position of each row has no succeeding sentence.
  valid_succeeding_mask = tf.cast(
      tf.concat([all_ones[:, :-1], all_zeros[:, -1:]], -1), tf.bool)
  # Decide what to use for the segment: (1) random, out-of-batch, (2)
  # preceding item, or (3) succeeding item.
  # Should get out-of-batch instead of succeeding item; also forced wherever
  # no valid succeeding item exists.
  should_get_random = ((_get_random(positions, random_fn) > random_threshold)
                       | tf.logical_not(valid_succeeding_mask))
  random_or_succeeding_nd = tf.compat.v1.where(should_get_random,
                                               all_random_nd,
                                               all_succeeding_nd)
  # Choose which items should get a random or succeeding item. Force positions
  # that don't have a valid preceding item to get a random/succeeding item.
  should_get_random_or_succeeding = (
      (_get_random(positions, random_fn) > random_next_threshold)
      | tf.logical_not(valid_preceding_mask))
  gather_indices = tf.compat.v1.where(should_get_random_or_succeeding,
                                      random_or_succeeding_nd,
                                      all_preceding_nd)
  return (tf.gather_nd(sentences,
                       gather_indices), should_get_random_or_succeeding)
def get_next_sentence_labels(sentences,
                             random_threshold=0.5,
                             random_fn=tf.random.uniform):
  """Extracts the next sentence label from sentences.

  Builds the segment and label for BERT's next sentence prediction (NSP)
  task: each sentence is paired with either its true next sentence or a
  random sentence drawn from a different document.

  Args:
    sentences: A `RaggedTensor` of strings w/ shape [batch, (num_sentences)].
    random_threshold: (optional) A float threshold between 0 and 1, used to
      determine whether to extract a random sentence or the immediate next
      sentence. Higher value favors next sentence.
    random_fn: (optional) An op used to generate random float values.

  Returns:
    A tuple of (next_sentence_or_random, is_next_sentence) where:
      next_sentence_or_random: A `Tensor` with shape [num_sentences] that
        contains either the subsequent sentence of `segment_a` or a randomly
        injected sentence.
      is_next_sentence: A `Tensor` of bool w/ shape [num_sentences]
        that contains whether or not `next_sentence_or_random` is truly a
        subsequent sentence or not.
  """
  # shift everyone to get the next sentence predictions positions
  positions = tf.ragged.range(sentences.row_lengths())
  # Shift every position down to the right; the modulo wraps the last
  # position of each row to 0 (those entries are masked out below).
  next_sentences_pos = (positions + 1) % tf.expand_dims(
      sentences.row_lengths(), 1)
  # Ragged broadcast of each row index to every position in that row.
  rows_broadcasted = tf.expand_dims(tf.range(sentences.nrows()),
                                    -1) + 0 * positions
  # [row, position] pairs in the format `gather_nd` expects.
  next_sentences_pos_nd = tf.stack([rows_broadcasted, next_sentences_pos], -1)
  all_random_nd = _sample_from_other_batch(positions, random_fn)
  # Mark the items that don't have a next sentence (e.g. the last
  # sentences in the document). We will patch these up and force them to grab
  # a random sentence from a random document.
  valid_next_sentences = tf.cast(
      tf.concat([
          tf.ones_like(positions)[:, :-1],
          tf.zeros([positions.nrows(), 1], dtype=tf.int64)
      ], -1), tf.bool)
  # Randomly pick random-vs-next per sentence; force random where there is no
  # valid next sentence.
  is_random = ((_get_random(positions, random_fn) > random_threshold)
               | tf.logical_not(valid_next_sentences))
  gather_indices = tf.compat.v1.where(is_random, all_random_nd,
                                      next_sentences_pos_nd)
  return tf.gather_nd(sentences, gather_indices), tf.logical_not(is_random)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# encoding=utf-8
# Lint as: python3
"""Tests for sentence prediction labels."""
import functools
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.modeling.ops import segment_extractor
class NextSentencePredictionTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `segment_extractor.get_next_sentence_labels`."""

  @parameterized.parameters([
      dict(
          test_description="all random",
          sentences=[[b"Hello there.", b"La la la.", b"Such is life."],
                     [b"Who let the dogs out?", b"Who?."]],
          expected_segment=[[
              b"Who let the dogs out?", b"Who?.", b"Who let the dogs out?"
          ], [b"Hello there.", b"Hello there."]],
          expected_labels=[
              [False, False, False],
              [False, False],
          ],
          random_threshold=0.0,
      ),
      dict(
          test_description="all next",
          sentences=[[b"Hello there.", b"La la la.", b"Such is life."],
                     [b"Who let the dogs out?", b"Who?."]],
          expected_segment=[
              [b"La la la.", b"Such is life.", b"Who let the dogs out?"],
              [b"Who?.", b"Hello there."],
          ],
          expected_labels=[
              [True, True, False],
              [True, False],
          ],
          random_threshold=1.0,
      ),
  ])
  def testNextSentencePrediction(self,
                                 sentences,
                                 expected_segment,
                                 expected_labels,
                                 random_threshold=0.5,
                                 test_description=""):
    ragged_sentences = tf.ragged.constant(sentences)
    # Pin the random op to a fixed-seed stateless uniform so the sampled
    # segments and labels are deterministic across runs.
    deterministic_uniform = functools.partial(
        tf.random.stateless_uniform, seed=(2, 3))
    actual_segment, actual_labels = (
        segment_extractor.get_next_sentence_labels(
            ragged_sentences, random_threshold,
            random_fn=deterministic_uniform))
    self.assertAllEqual(expected_segment, actual_segment)
    self.assertAllEqual(expected_labels, actual_labels)
class SentenceOrderLabelsTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `segment_extractor.get_sentence_order_labels`."""

  @parameterized.parameters([
      dict(
          test_description="all random",
          sentences=[[b"Hello there.", b"La la la.", b"Such is life."],
                     [b"Who let the dogs out?", b"Who?."]],
          expected_segment=[[
              b"Who let the dogs out?", b"Who?.", b"Who let the dogs out?"
          ], [b"Hello there.", b"Hello there."]],
          expected_labels=[[True, True, True], [True, True]],
          random_threshold=0.0,
          random_next_threshold=0.0,
      ),
      dict(
          test_description="all next",
          sentences=[[b"Hello there.", b"La la la.", b"Such is life."],
                     [b"Who let the dogs out?", b"Who?."]],
          expected_segment=[[
              b"La la la.", b"Such is life.", b"Who let the dogs out?"
          ], [b"Who?.", b"Hello there."]],
          expected_labels=[[True, True, True], [True, True]],
          random_threshold=1.0,
          random_next_threshold=0.0,
      ),
      dict(
          test_description="all preceeding",
          sentences=[[b"Hello there.", b"La la la.", b"Such is life."],
                     [b"Who let the dogs out?", b"Who?."]],
          expected_segment=[
              [b"La la la.", b"Hello there.", b"Hello there."],
              [b"Who?.", b"Who let the dogs out?"],
          ],
          expected_labels=[
              [True, False, False],
              [True, False],
          ],
          random_threshold=1.0,
          random_next_threshold=1.0,
      ),
  ])
  def testSentenceOrderPrediction(self,
                                  sentences,
                                  expected_segment,
                                  expected_labels,
                                  random_threshold=0.5,
                                  random_next_threshold=0.5,
                                  test_description=""):
    ragged_sentences = tf.ragged.constant(sentences)
    # Pin the random op to a fixed-seed stateless uniform so the sampled
    # segments and labels are deterministic across runs.
    deterministic_uniform = functools.partial(
        tf.random.stateless_uniform, seed=(2, 3))
    actual_segment, actual_labels = (
        segment_extractor.get_sentence_order_labels(
            ragged_sentences,
            random_threshold=random_threshold,
            random_next_threshold=random_next_threshold,
            random_fn=deterministic_uniform))
    self.assertAllEqual(expected_segment, actual_segment)
    self.assertAllEqual(expected_labels, actual_labels)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment