Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d4dd827f
Commit
d4dd827f
authored
Feb 26, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 359773626
parent
52531231
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
9 deletions
+28
-9
official/nlp/tools/export_tfhub_lib_test.py
official/nlp/tools/export_tfhub_lib_test.py
+28
-9
No files found.
official/nlp/tools/export_tfhub_lib_test.py
View file @
d4dd827f
...
...
@@ -766,12 +766,15 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def
test_preprocessing_for_mlm
(
self
,
use_bert
):
"""Combines both SavedModel types and TF.text helpers for MLM."""
# Create the preprocessing SavedModel with a [MASK] token.
non_special_tokens
=
[
"hello"
,
"world"
,
"nice"
,
"movie"
,
"great"
,
"actors"
,
"quick"
,
"fox"
,
"lazy"
,
"dog"
]
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
]
,
do_lower_case
=
True
,
non_special_tokens
,
do_lower_case
=
True
,
tokenize_with_offsets
=
use_bert
,
# TODO(b/149576200): drop this.
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
add_mask_token
=
True
,
use_sp_model
=
not
use_bert
))
vocab_size
=
4
+
5
if
use_bert
else
4
+
7
vocab_size
=
len
(
non_special_tokens
)
+
(
5
if
use_bert
else
7
)
# Create the encoder SavedModel with an .mlm subobject.
hidden_size
=
16
...
...
@@ -811,12 +814,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
mask_id
,
4
)
# A batch of 3 segment pairs.
raw_segments
=
[
tf
.
constant
([
"hello"
,
"nice movie"
,
"quick
brown
fox"
]),
raw_segments
=
[
tf
.
constant
([
"hello"
,
"nice movie"
,
"quick fox"
]),
tf
.
constant
([
"world"
,
"great actors"
,
"lazy dog"
])]
batch_size
=
3
# Misc hyperparameters.
seq_length
=
1
2
seq_length
=
1
0
max_selections_per_seq
=
2
# Tokenize inputs.
...
...
@@ -836,12 +839,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
input_ids
=
input_ids
,
item_selector
=
text
.
RandomItemSelector
(
max_selections_per_seq
,
selection_rate
=
0.
1
5
,
selection_rate
=
0.5
,
# Adjusted for the short test examples.
unselectable_ids
=
[
start_of_sequence_id
,
end_of_segment_id
]),
mask_values_chooser
=
text
.
MaskValuesChooser
(
vocab_size
=
vocab_size
,
mask_token
=
mask_id
,
mask_token_rate
=
0.8
,
random_token_rate
=
0.
1
))
mask_values_chooser
=
text
.
MaskValuesChooser
(
vocab_size
=
vocab_size
,
mask_token
=
mask_id
,
# Always put [MASK] to have a predictable result.
mask_token_rate
=
1.0
,
random_token_rate
=
0.
0
))
# Pad to fixed-length Transformer encoder inputs.
input_word_ids
,
_
=
text
.
pad_model_inputs
(
masked_input_ids
,
seq_length
,
...
...
@@ -854,6 +857,22 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
masked_lm_positions
=
tf
.
cast
(
masked_lm_positions
,
tf
.
int32
)
num_predictions
=
int
(
tf
.
shape
(
masked_lm_positions
)[
1
])
# Test transformer inputs.
self
.
assertEqual
(
num_predictions
,
max_selections_per_seq
)
expected_word_ids
=
np
.
array
([
# [CLS] hello [SEP] world [SEP]
[
2
,
5
,
3
,
6
,
3
,
0
,
0
,
0
,
0
,
0
],
# [CLS] nice movie [SEP] great actors [SEP]
[
2
,
7
,
8
,
3
,
9
,
10
,
3
,
0
,
0
,
0
],
# [CLS] quick fox [SEP] lazy dog [SEP]
[
2
,
11
,
12
,
3
,
13
,
14
,
3
,
0
,
0
,
0
]])
for
i
in
range
(
batch_size
):
for
j
in
range
(
num_predictions
):
k
=
int
(
masked_lm_positions
[
i
,
j
])
if
k
!=
0
:
expected_word_ids
[
i
,
k
]
=
4
# [MASK]
self
.
assertAllEqual
(
input_word_ids
,
expected_word_ids
)
# Call the MLM head of the Transformer encoder.
mlm_inputs
=
dict
(
input_word_ids
=
input_word_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment