ModelZoo / ResNet50_tensorflow

Commit 296d7494, authored Apr 20, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 369459880
Parent: b51b9342

Showing 2 changed files with 140 additions and 21 deletions:

  official/nlp/modeling/layers/text_layers.py        +72  -1
  official/nlp/modeling/layers/text_layers_test.py   +68  -20
official/nlp/modeling/layers/text_layers.py

@@ -33,6 +33,70 @@ def _check_if_tf_text_installed():
         "'tensorflow-text-nightly'."
     )
 
 
+def _iterative_vectorized_fair_share(capacity: tf.Tensor,
+                                     limit: Union[int, tf.Tensor]):
+  """Iterative algorithm for max-min fairness.
+
+  Reference: https://en.wikipedia.org/wiki/Max-min_fairness
+  The idea is that, for each example with some number of segments and a limit
+  on the total segment length allowed, we grant each segment a fair share of
+  the limit. For example, if every segment has the same length, there is no
+  work to do. If one segment is shorter than the average, its unused share is
+  split fairly among the others. In this way, the longest segment will be the
+  shortest among all potential capacity assignments.
+
+  Args:
+    capacity: A rank-2 Tensor of #Segments x Batch.
+    limit: The largest permissible number of tokens in total across one
+      example.
+
+  Returns:
+    A rank-2 Tensor with the new segment capacity assignment such that the
+      total number of tokens in each example does not exceed the `limit`.
+  """
+  # First, calculate the lower bound of the capacity assignment.
+  per_seg_limit = limit // capacity.shape[0]
+  limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
+  lower_bound = tf.minimum(capacity, limit_mask)
+  # This step makes up the capacity that already satisfies the capacity limit.
+  remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
+  remaining_cap_mat = capacity - lower_bound
+  new_cap = lower_bound + remaining_cap_mat * tf.cast(
+      tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
+      tf.int64)
+  # Process iteratively. This step is O(#segments), see analysis below.
+  while True:
+    remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
+    remaining_cap = capacity - new_cap
+    masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
+    remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
+    masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
+                                     tf.int64) * remaining_limit
+    # The total remaining segment limit is different for each example.
+    per_seg_limit = masked_remaining_limit // (
+        tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
+        remaining_cap_col_slots)  # +1 to make sure 0/0 = 0.
+    # Note that on each step, at least one more segment is fulfilled, or the
+    # loop finishes. The idea is: if the remaining per-example limit exceeds
+    # the smallest remaining segment ask, that smallest ask is fulfilled.
+    # Otherwise, all remaining segments are truncated and the assignment is
+    # finished.
+    if tf.math.reduce_sum(per_seg_limit) > 0:
+      remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
+      new_cap = new_cap + remaining_slots_mat * per_seg_limit
+    else:
+      # Leftover assignment of limit that is smaller than #slots.
+      new_remained_assignment_mask = tf.cast(
+          (tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
+          & (masked_remaining_slots > 0), tf.int64)
+      new_cap = new_cap + new_remained_assignment_mask
+      break
+  return new_cap
+
+
 def round_robin_truncate_inputs(
     inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
     limit: Union[int, tf.Tensor],
@@ -74,7 +138,14 @@ def round_robin_truncate_inputs(
     return [_truncate_row_lengths(inputs[0], quota_a),
             _truncate_row_lengths(inputs[1], quota_b)]
   else:
-    raise ValueError("Must pass 1 or 2 inputs")
+    # Note that we don't merge this with the 2-input case because the full
+    # algorithm is more expensive.
+    capacity = tf.stack([rt.row_lengths() for rt in inputs])  # #Segments x B
+    new_capacity = _iterative_vectorized_fair_share(capacity, limit)
+    return [
+        _truncate_row_lengths(inputs[i], new_capacity[i])
+        for i in range(capacity.shape[0])
+    ]
 
 
 def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
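To make the fair-share assignment added above easier to check by hand, here is a small reference sketch. It is not part of this commit: `fair_share_1d` is a hypothetical helper that restates max-min fairness in plain Python for a single example, so the column-wise results of `_iterative_vectorized_fair_share` can be verified against it.

# Hypothetical reference sketch (not in the commit or the text_layers module):
# max-min fair shares for one example, where `capacity` lists the segment
# lengths and `limit` caps the total number of tokens kept.
def fair_share_1d(capacity, limit):
  share = [0] * len(capacity)
  remaining = limit
  unsatisfied = set(range(len(capacity)))
  while unsatisfied and remaining > 0:
    per_seg = remaining // len(unsatisfied)
    if per_seg == 0:
      # Fewer leftover tokens than hungry segments: hand out one each, in order.
      for i in sorted(unsatisfied)[:remaining]:
        share[i] += 1
      break
    fulfilled_any = False
    for i in sorted(unsatisfied):
      grant = min(per_seg, capacity[i] - share[i])
      share[i] += grant
      remaining -= grant
      if share[i] == capacity[i]:
        fulfilled_any = True
    unsatisfied = {i for i in unsatisfied if share[i] < capacity[i]}
    if not fulfilled_any:
      # Every remaining segment took a full per-segment share and still wants
      # more; spread the final remainder one token at a time and stop.
      for i in sorted(unsatisfied)[:remaining]:
        share[i] += 1
      break
  return share

# Example: a short first segment frees budget that is split between the others.
print(fair_share_1d([2, 9, 7], limit=12))   # -> [2, 5, 5]
# Example matching the last rows of expected_a/b/c in the new test below.
print(fair_share_1d([1, 8, 10], limit=8))   # -> [1, 4, 3]

The vectorized implementation in the diff computes the same assignment for every batch column at once, in O(#segments) iterations.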
official/nlp/modeling/layers/text_layers_test.py
@@ -26,15 +26,15 @@ from official.nlp.modeling.layers import text_layers
 class RoundRobinTruncatorTest(tf.test.TestCase):
 
-  def test_correct_outputs(self):
-
-    def test_input(start, lengths):
-      return tf.ragged.constant([[start + 10 * j + i
-                                  for i in range(length)]
-                                 for j, length in enumerate(lengths)],
-                                dtype=tf.int32)
-
+  def _test_input(self, start, lengths):
+    return tf.ragged.constant([[start + 10 * j + i
+                                for i in range(length)]
+                               for j, length in enumerate(lengths)],
+                              dtype=tf.int32)
+
+  def test_single_segment(self):
     # Single segment.
-    single_input = test_input(11, [4, 5, 6])
+    single_input = self._test_input(11, [4, 5, 6])
     expected_single_output = tf.ragged.constant(
         [[11, 12, 13, 14],
          [21, 22, 23, 24, 25],
@@ -50,9 +50,9 @@ class RoundRobinTruncatorTest(tf.test.TestCase):
     self.assertIsInstance(actual_single_list_output, list)
     self.assertAllEqual(expected_single_output, actual_single_list_output[0])
 
-    # Two segments.
-    input_a = test_input(111, [1, 2, 2, 3, 4, 5])
-    input_b = test_input(211, [1, 3, 4, 2, 2, 5])
+  def test_two_segments(self):
+    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
+    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
     expected_a = tf.ragged.constant(
         [[111],
          [121, 122],
@@ -74,6 +74,51 @@ class RoundRobinTruncatorTest(tf.test.TestCase):
     self.assertAllEqual(expected_a, actual_a)
     self.assertAllEqual(expected_b, actual_b)
 
+  def test_three_segments(self):
+    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
+    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
+    input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
+    seg_limit = 8
+    expected_a = tf.ragged.constant([
+        [111],
+        [121, 122],
+        [131, 132],
+        [141, 142, 143],
+        [151, 152, 153, 154],
+        [161, 162, 163],  # Truncated
+        [171]
+    ])
+    expected_b = tf.ragged.constant([
+        [211],
+        [221, 222, 223],
+        [231, 232, 233],  # Truncated
+        [241, 242],
+        [251, 252],
+        [261, 262, 263],  # Truncated
+        [271, 272, 273, 274]  # Truncated
+    ])
+    expected_c = tf.ragged.constant([
+        [311],
+        [321, 322, 323],
+        [331, 332, 333],  # Truncated
+        [341, 342],
+        [351, 352],
+        [361, 362],  # Truncated
+        [371, 372, 373]  # Truncated
+    ])
+    actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
+        [input_a, input_b, input_c], limit=seg_limit)
+    self.assertAllEqual(expected_a, actual_a)
+    self.assertAllEqual(expected_b, actual_b)
+    self.assertAllEqual(expected_c, actual_c)
+    input_cap = tf.math.reduce_sum(
+        tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
+        axis=0)
+    per_example_usage = tf.math.reduce_sum(
+        tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
+        axis=0)
+    self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
+
+
 # This test covers the in-process behavior of a BertTokenizer layer.
 # For saving, restoring, and the restored behavior (incl. shape inference),
@@ -397,16 +442,19 @@ class BertPackInputsTest(tf.test.TestCase):
         tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))
 
-    # Three inputs has not been supported for round_robin so far.
-    with self.assertRaisesRegex(ValueError, "Must pass 1 or 2 inputs"):
-      bert_inputs = bpi([
-          tf.ragged.constant([[[111], [112, 113]],
-                              [[121, 122, 123], [124, 125, 126], [127, 128]]]),
-          tf.ragged.constant([[[211, 212], [213]],
-                              [[221, 222], [223, 224, 225], [226, 227, 228]]]),
-          tf.ragged.constant([[[311, 312], [313]],
-                              [[321, 322], [323, 324, 325], [326, 327, 328]]])
-      ])
+    # Three inputs. rank 3.
+    bert_inputs = bpi([
+        tf.ragged.constant([[[111], [112, 113]],
+                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
+        tf.ragged.constant([[[211, 212], [213]],
+                            [[221, 222], [223, 224, 225], [226, 227, 228]]]),
+        tf.ragged.constant([[[311, 312], [313]],
+                            [[321, 322], [323, 324, 325], [326, 327, 328]]])
+    ])
+    self.assertAllEqual(
+        bert_inputs["input_word_ids"],
+        tf.constant([[1001, 111, 112, 1002, 211, 212, 1002, 311, 312, 1002],
+                     [1001, 121, 122, 1002, 221, 222, 1002, 321, 322, 1002]]))
 
   def test_waterfall_correct_outputs(self):
     bpi = text_layers.BertPackInputs(
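For completeness, here is a minimal usage sketch of the generalized truncation path exercised by the tests above. It is not part of the commit and assumes TensorFlow 2.x with the Model Garden `official` package importable (plus `tensorflow_text`, which `text_layers` checks for); the example values are made up.

import tensorflow as tf

from official.nlp.modeling.layers import text_layers

# Three ragged segments for a batch of two examples (values are arbitrary).
seg_a = tf.ragged.constant([[1, 2, 3, 4], [10, 11]])
seg_b = tf.ragged.constant([[5, 6, 7], [12, 13, 14, 15]])
seg_c = tf.ragged.constant([[8, 9], [16, 17, 18]])

# With a per-example budget of 7 tokens, the max-min fair assignment keeps
# 3 + 2 + 2 tokens for the first example and 2 + 3 + 2 for the second.
trunc_a, trunc_b, trunc_c = text_layers.round_robin_truncate_inputs(
    [seg_a, seg_b, seg_c], limit=7)
print(trunc_a.row_lengths().numpy(),  # [3 2]
      trunc_b.row_lengths().numpy(),  # [2 3]
      trunc_c.row_lengths().numpy())  # [2 2]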