ModelZoo / ResNet50_tensorflow / Commits / b3377b09

Commit b3377b09, authored Apr 20, 2020 by A. Unique TensorFlower

In BERT's export to TF Hub, fix shape propagation for seq_length.

PiperOrigin-RevId: 307425903
Parent: 0b0ca66b
Showing 2 changed files with 21 additions and 10 deletions (+21 -10):

  official/nlp/bert/export_tfhub_test.py               +16 -4
  official/nlp/modeling/layers/on_device_embedding.py   +5 -6
official/nlp/bert/export_tfhub_test.py

@@ -32,9 +32,10 @@ class ExportTfhubTest(tf.test.TestCase):
   def test_export_tfhub(self):
     # Exports a savedmodel for TF-Hub
+    hidden_size = 16
     bert_config = configs.BertConfig(
         vocab_size=100,
-        hidden_size=16,
+        hidden_size=hidden_size,
         intermediate_size=32,
         max_position_embeddings=128,
         num_attention_heads=2,

@@ -67,7 +68,8 @@ class ExportTfhubTest(tf.test.TestCase):
                                          hub_layer.trainable_weights):
       self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
-    dummy_ids = np.zeros((2, 10), dtype=np.int32)
+    seq_length = 10
+    dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
     hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
     source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])

@@ -75,13 +77,23 @@ class ExportTfhubTest(tf.test.TestCase):
     # while the outputs of encoder is in reversed order, i.e.,
     # "sequence_output" and "pooled_output".
     encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
-    self.assertEqual(hub_outputs[0].shape, (2, 16))
-    self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
+    self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
+    self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
     for source_output, hub_output, encoder_output in zip(
         source_outputs, hub_outputs, encoder_outputs):
       self.assertAllClose(source_output.numpy(), hub_output.numpy())
       self.assertAllClose(source_output.numpy(), encoder_output.numpy())
+
+    # Test propagation of seq_length in shape inference.
+    input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+    input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+    input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+    pooled_output, sequence_output = hub_layer(
+        [input_word_ids, input_mask, input_type_ids])
+    self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
+    self.assertEqual(sequence_output.shape.as_list(),
+                     [None, seq_length, hidden_size])

 if __name__ == "__main__":
   tf.test.main()
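The hub_layer used above is built earlier in the test, outside these hunks, presumably as a hub.KerasLayer wrapping the exported SavedModel; the new assertions then require that calling it on Keras Inputs with a fixed seq_length yields outputs whose static shapes keep that dimension. A hedged sketch of that consumption pattern, with a hypothetical export path:

import tensorflow as tf
import tensorflow_hub as hub

# Hypothetical path; the actual test writes the export to a temporary directory.
export_path = "/tmp/bert_tfhub_export"
hub_layer = hub.KerasLayer(export_path, trainable=True)

seq_length = 10
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
    [input_word_ids, input_mask, input_type_ids])

# The new assertions expect the static shapes
#   pooled_output.shape.as_list()   == [None, hidden_size]
#   sequence_output.shape.as_list() == [None, seq_length, hidden_size]
# i.e. seq_length survives Keras shape inference instead of collapsing to None.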
official/nlp/modeling/layers/on_device_embedding.py

@@ -21,8 +21,6 @@ from __future__ import print_function
 import tensorflow as tf

-from official.modeling import tf_utils

 @tf.keras.utils.register_keras_serializable(package="Text")
 class OnDeviceEmbedding(tf.keras.layers.Layer):

@@ -78,8 +76,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
     super(OnDeviceEmbedding, self).build(input_shape)

   def call(self, inputs):
-    input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
-    input_shape.append(self._embedding_width)
     flat_inputs = tf.reshape(inputs, [-1])
     if self._use_one_hot:
       one_hot_data = tf.one_hot(

@@ -87,6 +83,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
       embeddings = tf.matmul(one_hot_data, self.embeddings)
     else:
       embeddings = tf.gather(self.embeddings, flat_inputs)
-    embeddings = tf.reshape(embeddings, input_shape)
+    embeddings = tf.reshape(
+        embeddings,
+        # Work around b/142213824: prefer concat to shape over a Python list.
+        tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
+    embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
     return embeddings
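For context on the new call() body: the reshape target is computed dynamically with tf.shape(), so the layer works for any sequence length, and Tensor.set_shape() then merges the statically known dimensions of inputs (plus the embedding width) back onto the result, which is what keeps seq_length visible to downstream Keras shape inference in the exported TF Hub model. A minimal standalone sketch of the same pattern (illustrative class and variable names, not code from the repository):

import tensorflow as tf


class TinyEmbedding(tf.keras.layers.Layer):
  """Minimal stand-in for the reshape/set_shape pattern in OnDeviceEmbedding."""

  def __init__(self, vocab_size, width, **kwargs):
    super(TinyEmbedding, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self._width = width

  def build(self, input_shape):
    self.table = self.add_weight(
        "table", shape=[self._vocab_size, self._width])
    super(TinyEmbedding, self).build(input_shape)

  def call(self, inputs):
    flat_inputs = tf.reshape(inputs, [-1])
    embeddings = tf.gather(self.table, flat_inputs)
    # Dynamic reshape target: correct for any batch size and sequence length.
    embeddings = tf.reshape(
        embeddings, tf.concat([tf.shape(inputs), [self._width]], axis=0))
    # Merge the statically known dims of `inputs` (plus the width) back onto
    # the output so shape inference keeps seq_length whenever it is known.
    embeddings.set_shape(inputs.shape.as_list() + [self._width])
    return embeddings


# With a symbolic input whose seq_length (10) is static, the output's static
# shape is fully known except for the batch dimension.
ids = tf.keras.layers.Input(shape=(10,), dtype=tf.int32)
print(TinyEmbedding(vocab_size=100, width=16)(ids).shape)  # (None, 10, 16)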