Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aa870ff4
Commit
aa870ff4
authored
Feb 27, 2022
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Feb 27, 2022
Browse files
Fix export_tfhub module with BertV2.
PiperOrigin-RevId: 431236080
parent
34a93745
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
303 additions
and
199 deletions
+303
-199
official/nlp/tools/export_tfhub_lib.py
official/nlp/tools/export_tfhub_lib.py
+47
-30
official/nlp/tools/export_tfhub_lib_test.py
official/nlp/tools/export_tfhub_lib_test.py
+256
-169
No files found.
official/nlp/tools/export_tfhub_lib.py
View file @
aa870ff4
...
...
@@ -84,13 +84,13 @@ def _create_model(
"""Creates the model to export and the model to restore the checkpoint.
Args:
bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
Exactly
one of encoder_config and bert_config must be set.
bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
Exactly
one of encoder_config and bert_config must be set.
encoder_config: An `EncoderConfig` to create an encoder of the configured
type (`BertEncoder` or other).
with_mlm: A bool to control the second component of the result.
If True,
will create a `BertPretrainerV2` object; otherwise, will
create a
`BertEncoder` object.
with_mlm: A bool to control the second component of the result.
If True,
will create a `BertPretrainerV2` object; otherwise, will
create a
`BertEncoder` object.
Returns:
A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
...
...
@@ -110,7 +110,11 @@ def _create_model(
# Convert from list of named inputs to dict of inputs keyed by name.
# Only the latter accepts a dict of inputs after restoring from SavedModel.
encoder_inputs_dict
=
{
x
.
name
:
x
for
x
in
encoder
.
inputs
}
if
isinstance
(
encoder
.
inputs
,
list
)
or
isinstance
(
encoder
.
inputs
,
tuple
):
encoder_inputs_dict
=
{
x
.
name
:
x
for
x
in
encoder
.
inputs
}
else
:
# encoder.inputs by default is dict for BertEncoderV2.
encoder_inputs_dict
=
encoder
.
inputs
encoder_output_dict
=
encoder
(
encoder_inputs_dict
)
# For interchangeability with other text representations,
# add "default" as an alias for BERT's whole-input reptesentations.
...
...
@@ -206,26 +210,28 @@ def export_model(export_path: Text,
encoder_config: An optional `encoders.EncoderConfig` object.
model_checkpoint_path: The path to the checkpoint.
with_mlm: Whether to export the additional mlm sub-object.
copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
used
in the next sentence prediction task to the encoder.
copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
used
in the next sentence prediction task to the encoder.
vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
do_lower_case: Whether to lower-case text before tokenization.
"""
if
with_mlm
:
core_model
,
pretrainer
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
core_model
,
pretrainer
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
encoder
=
pretrainer
.
encoder_network
# It supports both the new pretrainer checkpoint produced by TF-NLP and
# the checkpoint converted from TF1 (original BERT, SmallBERTs).
checkpoint_items
=
pretrainer
.
checkpoint_items
checkpoint
=
tf
.
train
.
Checkpoint
(
**
checkpoint_items
)
else
:
core_model
,
encoder
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
core_model
,
encoder
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
encoder
,
# Legacy checkpoints.
encoder
=
encoder
)
...
...
@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
# overridable. Having this dynamically determined default argument
# requires self.__call__ to be defined in this indirect way.
default_seq_length
=
bert_pack_inputs
.
seq_length
@
tf
.
function
(
autograph
=
False
)
def
call
(
inputs
,
seq_length
=
default_seq_length
):
return
layers
.
BertPackInputs
.
bert_pack_inputs
(
inputs
,
seq_length
=
seq_length
,
inputs
,
seq_length
=
seq_length
,
start_of_sequence_id
=
bert_pack_inputs
.
start_of_sequence_id
,
end_of_segment_id
=
bert_pack_inputs
.
end_of_segment_id
,
padding_id
=
bert_pack_inputs
.
padding_id
)
self
.
__call__
=
call
for
ragged_rank
in
range
(
1
,
3
):
for
num_segments
in
range
(
1
,
3
):
_
=
self
.
__call__
.
get_concrete_function
(
[
tf
.
RaggedTensorSpec
([
None
]
*
(
ragged_rank
+
1
),
dtype
=
tf
.
int32
)
for
_
in
range
(
num_segments
)],
seq_length
=
tf
.
TensorSpec
([],
tf
.
int32
))
_
=
self
.
__call__
.
get_concrete_function
([
tf
.
RaggedTensorSpec
([
None
]
*
(
ragged_rank
+
1
),
dtype
=
tf
.
int32
)
for
_
in
range
(
num_segments
)
],
seq_length
=
tf
.
TensorSpec
(
[],
tf
.
int32
))
def
create_preprocessing
(
*
,
...
...
@@ -311,14 +322,14 @@ def create_preprocessing(*,
Args:
vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
This determines the type
of tokenzer that is used.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
This determines the type
of tokenzer that is used.
do_lower_case: Whether to do lower case.
tokenize_with_offsets: Whether to include the .tokenize_with_offsets
subobject.
default_seq_length: The sequence length of preprocessing results from
root
callable. This is also the default sequence length for the
default_seq_length: The sequence length of preprocessing results from
root
callable. This is also the default sequence length for the
bert_pack_inputs subobject.
Returns:
...
...
@@ -378,7 +389,8 @@ def create_preprocessing(*,
def
_move_to_tmpdir
(
file_path
:
Optional
[
Text
],
tmpdir
:
Text
)
->
Optional
[
Text
]:
"""Returns new path with same basename and hash of original path."""
if
file_path
is
None
:
return
None
if
file_path
is
None
:
return
None
olddir
,
filename
=
os
.
path
.
split
(
file_path
)
hasher
=
hashlib
.
sha1
()
hasher
.
update
(
olddir
.
encode
(
"utf-8"
))
...
...
@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path):
assert_nodes
=
[]
graph_def
=
saved_model
.
meta_graphs
[
0
].
graph_def
assert_nodes
+=
[
"node '{}' in global graph"
.
format
(
n
.
name
)
for
n
in
graph_def
.
node
if
n
.
op
==
"Assert"
]
assert_nodes
+=
[
"node '{}' in global graph"
.
format
(
n
.
name
)
for
n
in
graph_def
.
node
if
n
.
op
==
"Assert"
]
for
fdef
in
graph_def
.
library
.
function
:
assert_nodes
+=
[
"node '{}' in function '{}'"
.
format
(
n
.
name
,
fdef
.
signature
.
name
)
for
n
in
fdef
.
node_def
if
n
.
op
==
"Assert"
]
for
n
in
fdef
.
node_def
if
n
.
op
==
"Assert"
]
if
assert_nodes
:
raise
AssertionError
(
"Internal tool error: "
...
...
official/nlp/tools/export_tfhub_lib_test.py
View file @
aa870ff4
...
...
@@ -32,9 +32,26 @@ from official.nlp.modeling import models
from
official.nlp.tools
import
export_tfhub_lib
def
_get_bert_config_or_encoder_config
(
use_bert_config
,
hidden_size
,
num_hidden_layers
,
vocab_size
=
100
):
"""Returns config args for export_tfhub_lib._create_model()."""
def
_get_bert_config_or_encoder_config
(
use_bert_config
,
hidden_size
,
num_hidden_layers
,
encoder_type
=
"albert"
,
vocab_size
=
100
):
"""Generates config args for export_tfhub_lib._create_model().
Args:
use_bert_config: bool. If True, returns legacy BertConfig.
hidden_size: int.
num_hidden_layers: int.
encoder_type: str. Can be ['albert', 'bert', 'bert_v2']. If use_bert_config
== True, then model_type is not used.
vocab_size: int.
Returns:
bert_config, encoder_config. Only one is not None. If
`use_bert_config` == True, the first config is valid. Otherwise
`bert_config` == None.
"""
if
use_bert_config
:
bert_config
=
configs
.
BertConfig
(
vocab_size
=
vocab_size
,
...
...
@@ -46,17 +63,31 @@ def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
encoder_config
=
None
else
:
bert_config
=
None
encoder_config
=
encoders
.
EncoderConfig
(
type
=
"albert"
,
albert
=
encoders
.
AlbertEncoderConfig
(
vocab_size
=
vocab_size
,
embedding_width
=
16
,
hidden_size
=
hidden_size
,
intermediate_size
=
32
,
max_position_embeddings
=
128
,
num_attention_heads
=
2
,
num_layers
=
num_hidden_layers
,
dropout_rate
=
0.1
))
if
encoder_type
==
"albert"
:
encoder_config
=
encoders
.
EncoderConfig
(
type
=
"albert"
,
albert
=
encoders
.
AlbertEncoderConfig
(
vocab_size
=
vocab_size
,
embedding_width
=
16
,
hidden_size
=
hidden_size
,
intermediate_size
=
32
,
max_position_embeddings
=
128
,
num_attention_heads
=
2
,
num_layers
=
num_hidden_layers
,
dropout_rate
=
0.1
))
else
:
# encoder_type can be 'bert' or 'bert_v2'.
model_config
=
encoders
.
BertEncoderConfig
(
vocab_size
=
vocab_size
,
embedding_size
=
16
,
hidden_size
=
hidden_size
,
intermediate_size
=
32
,
max_position_embeddings
=
128
,
num_attention_heads
=
2
,
num_layers
=
num_hidden_layers
,
dropout_rate
=
0.1
)
kwargs
=
{
"type"
:
encoder_type
,
encoder_type
:
model_config
}
encoder_config
=
encoders
.
EncoderConfig
(
**
kwargs
)
return
bert_config
,
encoder_config
...
...
@@ -105,13 +136,18 @@ class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
alternative to BertTokenizer).
"""
@
parameterized
.
named_parameters
((
"Bert"
,
True
),
(
"Albert"
,
False
))
def
test_export_model
(
self
,
use_bert
):
@
parameterized
.
named_parameters
(
(
"Bert_Legacy"
,
True
,
None
),
(
"Albert"
,
False
,
"albert"
),
(
"BertEncoder"
,
False
,
"bert"
),
(
"BertEncoderV2"
,
False
,
"bert_v2"
))
def
test_export_model
(
self
,
use_bert
,
encoder_type
):
# Create the encoder and export it.
hidden_size
=
16
num_hidden_layers
=
1
bert_config
,
encoder_config
=
_get_bert_config_or_encoder_config
(
use_bert
,
hidden_size
,
num_hidden_layers
)
use_bert
,
hidden_size
=
hidden_size
,
num_hidden_layers
=
num_hidden_layers
,
encoder_type
=
encoder_type
)
bert_model
,
encoder
=
export_tfhub_lib
.
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
False
)
self
.
assertEmpty
(
...
...
@@ -151,8 +187,8 @@ class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
_read_asset
(
hub_layer
.
resolved_object
.
sp_model_file
))
# Check restored weights.
self
.
assertEqual
(
len
(
bert_model
.
trainable_weights
),
len
(
hub_layer
.
trainable_weights
))
self
.
assertEqual
(
len
(
bert_model
.
trainable_weights
),
len
(
hub_layer
.
trainable_weights
))
for
source_weight
,
hub_weight
in
zip
(
bert_model
.
trainable_weights
,
hub_layer
.
trainable_weights
):
self
.
assertAllClose
(
source_weight
.
numpy
(),
hub_weight
.
numpy
())
...
...
@@ -334,8 +370,8 @@ class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
# Note that we set `_auto_track_sub_layers` to False when exporting the
# SavedModel, so hub_layer has the same number of weights as bert_model;
# otherwise, hub_layer will have extra weights from its `mlm` subobject.
self
.
assertEqual
(
len
(
bert_model
.
trainable_weights
),
len
(
hub_layer
.
trainable_weights
))
self
.
assertEqual
(
len
(
bert_model
.
trainable_weights
),
len
(
hub_layer
.
trainable_weights
))
for
source_weight
,
hub_weight
in
zip
(
bert_model
.
trainable_weights
,
hub_layer
.
trainable_weights
):
self
.
assertAllClose
(
source_weight
,
hub_weight
)
...
...
@@ -473,10 +509,11 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
The absolute filename of the created vocab file.
"""
full_vocab
=
[
"[PAD]"
,
"[UNK]"
,
"[CLS]"
,
"[SEP]"
]
+
[
"[MASK]"
]
*
add_mask_token
+
vocab
]
+
[
"[MASK]"
]
*
add_mask_token
+
vocab
path
=
os
.
path
.
join
(
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
(),
# New subdir each time.
prefix
=
_STRING_NOT_TO_LEAK
),
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
(),
# New subdir each time.
prefix
=
_STRING_NOT_TO_LEAK
),
filename
)
with
tf
.
io
.
gfile
.
GFile
(
path
,
"w"
)
as
f
:
f
.
write
(
"
\n
"
.
join
(
full_vocab
+
[
""
]))
...
...
@@ -522,22 +559,30 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
model_prefix
=
model_prefix
,
model_type
=
"word"
,
input
=
input_file
,
pad_id
=
0
,
unk_id
=
1
,
control_symbols
=
control_symbols
,
pad_id
=
0
,
unk_id
=
1
,
control_symbols
=
control_symbols
,
vocab_size
=
full_vocab_size
,
bos_id
=
full_vocab_size
-
2
,
eos_id
=
full_vocab_size
-
1
)
SentencePieceTrainer
.
Train
(
" "
.
join
([
"--{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
flags
.
items
()]))
bos_id
=
full_vocab_size
-
2
,
eos_id
=
full_vocab_size
-
1
)
SentencePieceTrainer
.
Train
(
" "
.
join
(
[
"--{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
flags
.
items
()]))
return
model_prefix
+
".model"
def
_do_export
(
self
,
vocab
,
do_lower_case
,
default_seq_length
=
128
,
tokenize_with_offsets
=
True
,
use_sp_model
=
False
,
experimental_disable_assert
=
False
,
add_mask_token
=
False
):
def
_do_export
(
self
,
vocab
,
do_lower_case
,
default_seq_length
=
128
,
tokenize_with_offsets
=
True
,
use_sp_model
=
False
,
experimental_disable_assert
=
False
,
add_mask_token
=
False
):
"""Runs SavedModel export and returns the export_path."""
export_path
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
vocab_file
=
sp_model_file
=
None
if
use_sp_model
:
sp_model_file
=
self
.
_make_sp_model_file
(
vocab
,
add_mask_token
=
add_mask_token
)
sp_model_file
=
self
.
_make_sp_model_file
(
vocab
,
add_mask_token
=
add_mask_token
)
else
:
vocab_file
=
self
.
_make_vocab_file
(
vocab
,
add_mask_token
=
add_mask_token
)
export_tfhub_lib
.
export_preprocessing
(
...
...
@@ -554,19 +599,24 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def
test_no_leaks
(
self
):
"""Tests not leaking the path to the original vocab file."""
path
=
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
False
)
path
=
self
.
_do_export
([
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
False
)
with
tf
.
io
.
gfile
.
GFile
(
os
.
path
.
join
(
path
,
"saved_model.pb"
),
"rb"
)
as
f
:
self
.
assertFalse
(
# pylint: disable=g-generic-assert
_STRING_NOT_TO_LEAK
.
encode
(
"ascii"
)
in
f
.
read
())
@
parameterized
.
named_parameters
((
"Bert"
,
False
),
(
"Sentencepiece"
,
True
))
def
test_exported_callables
(
self
,
use_sp_model
):
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
tokenize_with_offsets
=
not
use_sp_model
,
# TODO(b/181866850): drop this.
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
use_sp_model
=
use_sp_model
))
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
# TODO(b/181866850): drop this.
tokenize_with_offsets
=
not
use_sp_model
,
# TODO(b/175369555): drop this.
experimental_disable_assert
=
True
,
use_sp_model
=
use_sp_model
))
def
fold_dim
(
rt
):
"""Removes the word/subword distinction of BertTokenizer."""
...
...
@@ -575,18 +625,20 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
# .tokenize()
inputs
=
tf
.
constant
([
"abc d ef"
,
"ABC D EF d"
])
token_ids
=
preprocess
.
tokenize
(
inputs
)
self
.
assertAllEqual
(
fold_dim
(
token_ids
),
tf
.
ragged
.
constant
([[
6
,
4
,
5
],
[
6
,
4
,
5
,
4
]]))
self
.
assertAllEqual
(
fold_dim
(
token_ids
),
tf
.
ragged
.
constant
([[
6
,
4
,
5
],
[
6
,
4
,
5
,
4
]]))
special_tokens_dict
=
{
k
:
v
.
numpy
().
item
()
# Expecting eager Tensor, converting to Python.
for
k
,
v
in
preprocess
.
tokenize
.
get_special_tokens_dict
().
items
()}
self
.
assertDictEqual
(
special_tokens_dict
,
dict
(
padding_id
=
0
,
start_of_sequence_id
=
2
,
end_of_segment_id
=
3
,
vocab_size
=
4
+
6
if
use_sp_model
else
4
+
4
))
for
k
,
v
in
preprocess
.
tokenize
.
get_special_tokens_dict
().
items
()
}
self
.
assertDictEqual
(
special_tokens_dict
,
dict
(
padding_id
=
0
,
start_of_sequence_id
=
2
,
end_of_segment_id
=
3
,
vocab_size
=
4
+
6
if
use_sp_model
else
4
+
4
))
# .tokenize_with_offsets()
if
use_sp_model
:
...
...
@@ -595,92 +647,104 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
else
:
token_ids
,
start_offsets
,
limit_offsets
=
(
preprocess
.
tokenize_with_offsets
(
inputs
))
self
.
assertAllEqual
(
fold_dim
(
token_ids
),
tf
.
ragged
.
constant
([[
6
,
4
,
5
],
[
6
,
4
,
5
,
4
]]))
self
.
assertAllEqual
(
fold_dim
(
start_offsets
),
tf
.
ragged
.
constant
([[
0
,
4
,
6
],
[
0
,
4
,
6
,
9
]]))
self
.
assertAllEqual
(
fold_dim
(
limit_offsets
),
tf
.
ragged
.
constant
([[
3
,
5
,
8
],
[
3
,
5
,
8
,
10
]]))
self
.
assertAllEqual
(
fold_dim
(
token_ids
),
tf
.
ragged
.
constant
([[
6
,
4
,
5
],
[
6
,
4
,
5
,
4
]]))
self
.
assertAllEqual
(
fold_dim
(
start_offsets
),
tf
.
ragged
.
constant
([[
0
,
4
,
6
],
[
0
,
4
,
6
,
9
]]))
self
.
assertAllEqual
(
fold_dim
(
limit_offsets
),
tf
.
ragged
.
constant
([[
3
,
5
,
8
],
[
3
,
5
,
8
,
10
]]))
self
.
assertIs
(
preprocess
.
tokenize
.
get_special_tokens_dict
,
preprocess
.
tokenize_with_offsets
.
get_special_tokens_dict
)
# Root callable.
bert_inputs
=
preprocess
(
inputs
)
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
].
shape
.
as_list
(),
[
2
,
128
])
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
][:,
:
10
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
][:,
:
10
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
].
shape
.
as_list
(),
[
2
,
128
])
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
][:,
:
10
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
][:,
:
10
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
].
shape
.
as_list
(),
[
2
,
128
])
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
][:,
:
10
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
][:,
:
10
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
# .bert_pack_inputs()
inputs_2
=
tf
.
constant
([
"d xy"
,
"xy abc"
])
token_ids_2
=
preprocess
.
tokenize
(
inputs_2
)
bert_inputs
=
preprocess
.
bert_pack_inputs
(
[
token_ids
,
token_ids_2
],
seq_length
=
256
)
bert_inputs
=
preprocess
.
bert_pack_inputs
(
[
token_ids
,
token_ids_2
],
seq_length
=
256
)
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
].
shape
.
as_list
(),
[
2
,
256
])
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
][:,
:
10
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
4
,
7
,
3
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
7
,
6
,
3
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
][:,
:
10
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
4
,
7
,
3
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
7
,
6
,
3
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
].
shape
.
as_list
(),
[
2
,
256
])
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
][:,
:
10
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
][:,
:
10
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
].
shape
.
as_list
(),
[
2
,
256
])
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
][:,
:
10
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
][:,
:
10
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
]]))
# For BertTokenizer only: repeat relevant parts for do_lower_case=False,
# default_seq_length=10, experimental_disable_assert=False,
# tokenize_with_offsets=False, and without folding the word/subword dimension.
def
test_cased_length10
(
self
):
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"d"
,
"##ef"
,
"abc"
,
"ABC"
],
do_lower_case
=
False
,
default_seq_length
=
10
,
tokenize_with_offsets
=
False
,
use_sp_model
=
False
,
experimental_disable_assert
=
False
))
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
([
"d"
,
"##ef"
,
"abc"
,
"ABC"
],
do_lower_case
=
False
,
default_seq_length
=
10
,
tokenize_with_offsets
=
False
,
use_sp_model
=
False
,
experimental_disable_assert
=
False
))
inputs
=
tf
.
constant
([
"abc def"
,
"ABC DEF"
])
token_ids
=
preprocess
.
tokenize
(
inputs
)
self
.
assertAllEqual
(
token_ids
,
tf
.
ragged
.
constant
([[[
6
],
[
4
,
5
]],
[[
7
],
[
1
]]]))
self
.
assertAllEqual
(
token_ids
,
tf
.
ragged
.
constant
([[[
6
],
[
4
,
5
]],
[[
7
],
[
1
]]]))
self
.
assertFalse
(
hasattr
(
preprocess
,
"tokenize_with_offsets"
))
bert_inputs
=
preprocess
(
inputs
)
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
7
,
1
,
3
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
7
,
1
,
3
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
inputs_2
=
tf
.
constant
([
"d ABC"
,
"ABC abc"
])
token_ids_2
=
preprocess
.
tokenize
(
inputs_2
)
bert_inputs
=
preprocess
.
bert_pack_inputs
([
token_ids
,
token_ids_2
])
# Test default seq_length=10.
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
4
,
7
,
3
,
0
,
0
],
[
2
,
7
,
1
,
3
,
7
,
6
,
3
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
4
,
7
,
3
,
0
,
0
],
[
2
,
7
,
1
,
3
,
7
,
6
,
3
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
],
[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
]]))
# XLA requires fixed shapes for tensors found in graph mode.
# Statically known shapes in Python are a particularly firm way to
...
...
@@ -689,16 +753,21 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
# inference when applied to fully or partially known input shapes.
@
parameterized
.
named_parameters
((
"Bert"
,
False
),
(
"Sentencepiece"
,
True
))
def
test_shapes
(
self
,
use_sp_model
):
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"abc"
,
"def"
],
do_lower_case
=
True
,
tokenize_with_offsets
=
not
use_sp_model
,
# TODO(b/181866850): drop this.
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
use_sp_model
=
use_sp_model
))
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
[
"abc"
,
"def"
],
do_lower_case
=
True
,
# TODO(b/181866850): drop this.
tokenize_with_offsets
=
not
use_sp_model
,
# TODO(b/175369555): drop this.
experimental_disable_assert
=
True
,
use_sp_model
=
use_sp_model
))
def
expected_bert_input_shapes
(
batch_size
,
seq_length
):
return
dict
(
input_word_ids
=
[
batch_size
,
seq_length
],
input_mask
=
[
batch_size
,
seq_length
],
input_type_ids
=
[
batch_size
,
seq_length
])
return
dict
(
input_word_ids
=
[
batch_size
,
seq_length
],
input_mask
=
[
batch_size
,
seq_length
],
input_type_ids
=
[
batch_size
,
seq_length
])
for
batch_size
in
[
7
,
None
]:
if
use_sp_model
:
...
...
@@ -706,11 +775,9 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
else
:
token_out_shape
=
[
batch_size
,
None
,
None
]
self
.
assertEqual
(
_result_shapes_in_tf_function
(
preprocess
.
tokenize
,
tf
.
TensorSpec
([
batch_size
],
tf
.
string
)),
token_out_shape
,
"with batch_size=%s"
%
batch_size
)
_result_shapes_in_tf_function
(
preprocess
.
tokenize
,
tf
.
TensorSpec
([
batch_size
],
tf
.
string
)),
token_out_shape
,
"with batch_size=%s"
%
batch_size
)
# TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
if
use_sp_model
:
self
.
assertFalse
(
hasattr
(
preprocess
,
"tokenize_with_offsets"
))
...
...
@@ -718,8 +785,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
_result_shapes_in_tf_function
(
preprocess
.
tokenize_with_offsets
,
tf
.
TensorSpec
([
batch_size
],
tf
.
string
)),
[
token_out_shape
]
*
3
,
tf
.
TensorSpec
([
batch_size
],
tf
.
string
)),
[
token_out_shape
]
*
3
,
"with batch_size=%s"
%
batch_size
)
self
.
assertEqual
(
_result_shapes_in_tf_function
(
...
...
@@ -737,7 +803,9 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def
test_reexport
(
self
,
use_sp_model
):
"""Test that preprocess keeps working after another save/load cycle."""
path1
=
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
default_seq_length
=
10
,
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
default_seq_length
=
10
,
tokenize_with_offsets
=
False
,
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
use_sp_model
=
use_sp_model
)
...
...
@@ -752,35 +820,46 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
inputs
=
tf
.
constant
([
"abc d ef"
,
"ABC D EF d"
])
bert_inputs
=
model2
(
inputs
)
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_word_ids"
],
tf
.
constant
([[
2
,
6
,
4
,
5
,
3
,
0
,
0
,
0
,
0
,
0
],
[
2
,
6
,
4
,
5
,
4
,
3
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_mask"
],
tf
.
constant
([[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
]]))
self
.
assertAllEqual
(
bert_inputs
[
"input_type_ids"
],
tf
.
constant
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]))
@
parameterized
.
named_parameters
((
"Bert"
,
True
),
(
"Albert"
,
False
))
def
test_preprocessing_for_mlm
(
self
,
use_bert
):
"""Combines both SavedModel types and TF.text helpers for MLM."""
# Create the preprocessing SavedModel with a [MASK] token.
non_special_tokens
=
[
"hello"
,
"world"
,
"nice"
,
"movie"
,
"great"
,
"actors"
,
"quick"
,
"fox"
,
"lazy"
,
"dog"
]
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
non_special_tokens
,
do_lower_case
=
True
,
tokenize_with_offsets
=
use_bert
,
# TODO(b/181866850): drop this.
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
add_mask_token
=
True
,
use_sp_model
=
not
use_bert
))
non_special_tokens
=
[
"hello"
,
"world"
,
"nice"
,
"movie"
,
"great"
,
"actors"
,
"quick"
,
"fox"
,
"lazy"
,
"dog"
]
preprocess
=
tf
.
saved_model
.
load
(
self
.
_do_export
(
non_special_tokens
,
do_lower_case
=
True
,
tokenize_with_offsets
=
use_bert
,
# TODO(b/181866850): drop this.
experimental_disable_assert
=
True
,
# TODO(b/175369555): drop this.
add_mask_token
=
True
,
use_sp_model
=
not
use_bert
))
vocab_size
=
len
(
non_special_tokens
)
+
(
5
if
use_bert
else
7
)
# Create the encoder SavedModel with an .mlm subobject.
hidden_size
=
16
num_hidden_layers
=
2
bert_config
,
encoder_config
=
_get_bert_config_or_encoder_config
(
use_bert
,
hidden_size
,
num_hidden_layers
,
vocab_size
)
use_bert_config
=
use_bert
,
hidden_size
=
hidden_size
,
num_hidden_layers
=
num_hidden_layers
,
vocab_size
=
vocab_size
)
_
,
pretrainer
=
export_tfhub_lib
.
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
True
)
model_checkpoint_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"checkpoint"
)
...
...
@@ -814,8 +893,10 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
mask_id
,
4
)
# A batch of 3 segment pairs.
raw_segments
=
[
tf
.
constant
([
"hello"
,
"nice movie"
,
"quick fox"
]),
tf
.
constant
([
"world"
,
"great actors"
,
"lazy dog"
])]
raw_segments
=
[
tf
.
constant
([
"hello"
,
"nice movie"
,
"quick fox"
]),
tf
.
constant
([
"world"
,
"great actors"
,
"lazy dog"
])
]
batch_size
=
3
# Misc hyperparameters.
...
...
@@ -842,18 +923,18 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
selection_rate
=
0.5
,
# Adjusted for the short test examples.
unselectable_ids
=
[
start_of_sequence_id
,
end_of_segment_id
]),
mask_values_chooser
=
text
.
MaskValuesChooser
(
vocab_size
=
vocab_size
,
mask_token
=
mask_id
,
vocab_size
=
vocab_size
,
mask_token
=
mask_id
,
# Always put [MASK] to have a predictable result.
mask_token_rate
=
1.0
,
random_token_rate
=
0.0
))
mask_token_rate
=
1.0
,
random_token_rate
=
0.0
))
# Pad to fixed-length Transformer encoder inputs.
input_word_ids
,
_
=
text
.
pad_model_inputs
(
masked_input_ids
,
seq_length
,
pad_value
=
padding_id
)
input_type_ids
,
input_mask
=
text
.
pad_model_inputs
(
segment_ids
,
seq_length
,
pad_value
=
0
)
masked_lm_positions
,
_
=
text
.
pad_model_inputs
(
masked_lm_positions
,
max_selections_per_seq
,
pad_value
=
0
)
input_word_ids
,
_
=
text
.
pad_model_inputs
(
masked_input_ids
,
seq_length
,
pad_value
=
padding_id
)
input_type_ids
,
input_mask
=
text
.
pad_model_inputs
(
segment_ids
,
seq_length
,
pad_value
=
0
)
masked_lm_positions
,
_
=
text
.
pad_model_inputs
(
masked_lm_positions
,
max_selections_per_seq
,
pad_value
=
0
)
masked_lm_positions
=
tf
.
cast
(
masked_lm_positions
,
tf
.
int32
)
num_predictions
=
int
(
tf
.
shape
(
masked_lm_positions
)[
1
])
...
...
@@ -865,7 +946,8 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
# [CLS] nice movie [SEP] great actors [SEP]
[
2
,
7
,
8
,
3
,
9
,
10
,
3
,
0
,
0
,
0
],
# [CLS] brown fox [SEP] lazy dog [SEP]
[
2
,
11
,
12
,
3
,
13
,
14
,
3
,
0
,
0
,
0
]])
[
2
,
11
,
12
,
3
,
13
,
14
,
3
,
0
,
0
,
0
]
])
for
i
in
range
(
batch_size
):
for
j
in
range
(
num_predictions
):
k
=
int
(
masked_lm_positions
[
i
,
j
])
...
...
@@ -896,15 +978,17 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
@
parameterized
.
named_parameters
((
"Bert"
,
False
),
(
"Sentencepiece"
,
True
))
def
test_special_tokens_in_estimator
(
self
,
use_sp_model
):
"""Tests getting special tokens without an Eager init context."""
preprocess_export_path
=
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
use_sp_model
,
tokenize_with_offsets
=
False
)
preprocess_export_path
=
self
.
_do_export
([
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
use_sp_model
,
tokenize_with_offsets
=
False
)
def
_get_special_tokens_dict
(
obj
):
"""Returns special tokens of restored tokenizer as Python values."""
if
tf
.
executing_eagerly
():
special_tokens_numpy
=
{
k
:
v
.
numpy
()
for
k
,
v
in
obj
.
get_special_tokens_dict
()}
special_tokens_numpy
=
{
k
:
v
.
numpy
()
for
k
,
v
in
obj
.
get_special_tokens_dict
()
}
else
:
with
tf
.
Graph
().
as_default
():
# This code expects `get_special_tokens_dict()` to be a tf.function
...
...
@@ -913,8 +997,10 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
special_tokens_tensors
=
obj
.
get_special_tokens_dict
()
with
tf
.
compat
.
v1
.
Session
()
as
sess
:
special_tokens_numpy
=
sess
.
run
(
special_tokens_tensors
)
return
{
k
:
v
.
item
()
# Numpy to Python.
for
k
,
v
in
special_tokens_numpy
.
items
()}
return
{
k
:
v
.
item
()
# Numpy to Python.
for
k
,
v
in
special_tokens_numpy
.
items
()
}
def
input_fn
():
self
.
assertFalse
(
tf
.
executing_eagerly
())
...
...
@@ -927,7 +1013,8 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIsInstance
(
v
,
int
,
"Unexpected type for {}"
.
format
(
k
))
tokens
=
tokenize
(
sentences
)
packed_inputs
=
layers
.
BertPackInputs
(
4
,
special_tokens_dict
=
special_tokens_dict
)(
tokens
)
4
,
special_tokens_dict
=
special_tokens_dict
)(
tokens
)
preprocessing
=
tf
.
keras
.
Model
(
sentences
,
packed_inputs
)
# Map the dataset.
ds
=
tf
.
data
.
Dataset
.
from_tensors
(
...
...
@@ -937,22 +1024,22 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def
model_fn
(
features
,
labels
,
mode
):
del
labels
# Unused.
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
features
[
"input_word_ids"
])
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
features
[
"input_word_ids"
])
estimator
=
tf
.
estimator
.
Estimator
(
model_fn
=
model_fn
)
outputs
=
list
(
estimator
.
predict
(
input_fn
))
self
.
assertAllEqual
(
outputs
,
np
.
array
([[
2
,
6
,
3
,
0
],
[
2
,
4
,
5
,
3
]]))
self
.
assertAllEqual
(
outputs
,
np
.
array
([[
2
,
6
,
3
,
0
],
[
2
,
4
,
5
,
3
]]))
# TODO(b/175369555): Remove that code and its test.
@
parameterized
.
named_parameters
((
"Bert"
,
False
),
(
"Sentencepiece"
,
True
))
def
test_check_no_assert
(
self
,
use_sp_model
):
"""Tests the self-check during export without assertions."""
preprocess_export_path
=
self
.
_do_export
(
[
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
use_sp_model
,
tokenize_with_offsets
=
False
,
experimental_disable_assert
=
False
)
preprocess_export_path
=
self
.
_do_export
([
"d"
,
"ef"
,
"abc"
,
"xy"
],
do_lower_case
=
True
,
use_sp_model
=
use_sp_model
,
tokenize_with_offsets
=
False
,
experimental_disable_assert
=
False
)
with
self
.
assertRaisesRegex
(
AssertionError
,
r
"failed to suppress \d+ Assert ops"
):
export_tfhub_lib
.
_check_no_assert
(
preprocess_export_path
)
...
...
@@ -963,8 +1050,8 @@ def _result_shapes_in_tf_function(fn, *args, **kwargs):
Args:
fn: A callable.
*args: TensorSpecs for Tensor-valued arguments and actual values
for
Python-valued arguments to fn.
*args: TensorSpecs for Tensor-valued arguments and actual values
for
Python-valued arguments to fn.
**kwargs: Same for keyword arguments.
Returns:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment