Commit 296d7494 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 369459880
parent b51b9342
......@@ -33,6 +33,70 @@ def _check_if_tf_text_installed():
"'tensorflow-text-nightly'.")
def _iterative_vectorized_fair_share(capacity: tf.Tensor,
limit: Union[int, tf.Tensor]):
"""Iterative algorithm for max min fairness algorithm.
Reference: https://en.wikipedia.org/wiki/Max-min_fairness
The idea is for each example with some number of segments and a limit of
total segment length allowed, we grant each segment a fair share of the
limit. For example, if every segment has the same length, no work to do.
If one segment has below average length, its share will be spilt to others
fairly. In this way, the longest segment will be the shortest among all
potential capacity assignments.
Args:
capacity: A rank-2 Tensor of #Segments x Batch.
limit: The largest permissible number of tokens in total across one example.
Returns:
A rank-2 Tensor with new segment capacity assignment such that
the total number of tokens in each example does not exceed the `limit`.
"""
# Firstly, we calculate the lower bound of the capacity assignment.
per_seg_limit = limit // capacity.shape[0]
limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
lower_bound = tf.minimum(capacity, limit_mask)
# This step makes up the capacity that already statisfy the capacity limit.
remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
remaining_cap_mat = capacity - lower_bound
new_cap = lower_bound + remaining_cap_mat * tf.cast(
tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
tf.int64)
# Process iteratively. This step is O(#segments), see analysis below.
while True:
remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
remaining_cap = capacity - new_cap
masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
tf.int64) * remaining_limit
# Total remaining segment limit is different for each example.
per_seg_limit = masked_remaining_limit // (
tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
remaining_cap_col_slots) # +1 to make sure 0/0 = 0
# Note that for each step, there is at least one more segment being
# fulfilled or the loop is finished.
# The idea is, if remaining per example limit > smallest among segments,
# the smallest segment ask is fullfilled. Otherwise, all remaining segments
# are truncated, the assignment is finished.
if tf.math.reduce_sum(per_seg_limit) > 0:
remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
new_cap = new_cap + remaining_slots_mat * per_seg_limit
else:
# Leftover assignment of limit that is smaller than #slots.
new_remained_assignment_mask = tf.cast(
(tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
& (masked_remaining_slots > 0), tf.int64)
new_cap = new_cap + new_remained_assignment_mask
break
return new_cap
def round_robin_truncate_inputs(
inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
limit: Union[int, tf.Tensor],
......@@ -74,7 +138,14 @@ def round_robin_truncate_inputs(
return [_truncate_row_lengths(inputs[0], quota_a),
_truncate_row_lengths(inputs[1], quota_b)]
else:
raise ValueError("Must pass 1 or 2 inputs")
# Note that we don't merge with the 2 input case because the full algorithm
# is more expensive.
capacity = tf.stack([rt.row_lengths() for rt in inputs]) # #Segments x B
new_capacity = _iterative_vectorized_fair_share(capacity, limit)
return [
_truncate_row_lengths(inputs[i], new_capacity[i])
for i in range(capacity.shape[0])
]
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
......
......@@ -26,15 +26,15 @@ from official.nlp.modeling.layers import text_layers
class RoundRobinTruncatorTest(tf.test.TestCase):
def test_correct_outputs(self):
def test_input(start, lengths):
return tf.ragged.constant([[start + 10*j + i for i in range(length)]
def _test_input(self, start, lengths):
return tf.ragged.constant([[start + 10 * j + i
for i in range(length)]
for j, length in enumerate(lengths)],
dtype=tf.int32)
def test_single_segment(self):
# Single segment.
single_input = test_input(11, [4, 5, 6])
single_input = self._test_input(11, [4, 5, 6])
expected_single_output = tf.ragged.constant(
[[11, 12, 13, 14],
[21, 22, 23, 24, 25],
......@@ -50,9 +50,9 @@ class RoundRobinTruncatorTest(tf.test.TestCase):
self.assertIsInstance(actual_single_list_output, list)
self.assertAllEqual(expected_single_output, actual_single_list_output[0])
# Two segments.
input_a = test_input(111, [1, 2, 2, 3, 4, 5])
input_b = test_input(211, [1, 3, 4, 2, 2, 5])
def test_two_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
expected_a = tf.ragged.constant(
[[111],
[121, 122],
......@@ -74,6 +74,51 @@ class RoundRobinTruncatorTest(tf.test.TestCase):
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
def test_three_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
seg_limit = 8
expected_a = tf.ragged.constant([
[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153, 154],
[161, 162, 163], # Truncated
[171]
])
expected_b = tf.ragged.constant([
[211],
[221, 222, 223],
[231, 232, 233], # Truncated
[241, 242],
[251, 252],
[261, 262, 263], # Truncated
[271, 272, 273, 274] # Truncated
])
expected_c = tf.ragged.constant([
[311],
[321, 322, 323],
[331, 332, 333], # Truncated
[341, 342],
[351, 352],
[361, 362], # Truncated
[371, 372, 373] # Truncated
])
actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
[input_a, input_b, input_c], limit=seg_limit)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
self.assertAllEqual(expected_c, actual_c)
input_cap = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
axis=0)
per_example_usage = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
axis=0)
self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
......@@ -397,8 +442,7 @@ class BertPackInputsTest(tf.test.TestCase):
tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))
# Three inputs has not been supported for round_robin so far.
with self.assertRaisesRegex(ValueError, "Must pass 1 or 2 inputs"):
# Three inputs. rank 3.
bert_inputs = bpi([
tf.ragged.constant([[[111], [112, 113]],
[[121, 122, 123], [124, 125, 126], [127, 128]]]),
......@@ -407,6 +451,10 @@ class BertPackInputsTest(tf.test.TestCase):
tf.ragged.constant([[[311, 312], [313]],
[[321, 322], [323, 324, 325], [326, 327, 328]]])
])
self.assertAllEqual(
bert_inputs["input_word_ids"],
tf.constant([[1001, 111, 112, 1002, 211, 212, 1002, 311, 312, 1002],
[1001, 121, 122, 1002, 221, 222, 1002, 321, 322, 1002]]))
def test_waterfall_correct_outputs(self):
bpi = text_layers.BertPackInputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment