ModelZoo / ResNet50_tensorflow · Commits · 553a4f41

Commit 553a4f41, authored Dec 20, 2019 by Chen Chen, committed by A. Unique TensorFlower on Dec 20, 2019

Internal change

PiperOrigin-RevId: 286634090
Parent: 5f7bdb11

Showing 3 changed files with 32 additions and 23 deletions:
official/nlp/bert/export_tfhub.py       +15 -13
official/nlp/bert/export_tfhub_test.py  +10 -3
official/nlp/bert_models.py             +7 -7
official/nlp/bert/export_tfhub.py

@@ -24,6 +24,7 @@ import tensorflow as tf
 import tensorflow as tf
 from typing import Text

 from official.nlp import bert_modeling
+from official.nlp import bert_models

 FLAGS = flags.FLAGS
@@ -31,8 +32,7 @@ flags.DEFINE_string("bert_config_file", None,
                     "Bert configuration file to define core bert layers.")
 flags.DEFINE_string("model_checkpoint_path", None,
                     "File path to TF model checkpoint.")
-flags.DEFINE_string("export_path", None,
-                    "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
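These four flags feed export_bert_tfhub() directly. A minimal sketch of driving the exporter programmatically rather than via flags; the paths are placeholders, and the argument names match the signature shown in the next hunk:

from official.nlp import bert_modeling
from official.nlp.bert import export_tfhub

# Placeholder paths; in the CLI flow these come from the flags above.
bert_config = bert_modeling.BertConfig.from_json_file("/tmp/bert_config.json")
export_tfhub.export_bert_tfhub(
    bert_config,
    model_checkpoint_path="/tmp/bert_model.ckpt",
    hub_destination="/tmp/bert_hub_module",
    vocab_file="/tmp/vocab.txt")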
@@ -53,21 +53,23 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
       shape=(None,), dtype=tf.int32, name="input_mask")
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
-  return bert_modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name="bert_model",
-      float_type=tf.float32)
+  transformer_encoder = bert_models.get_transformer_encoder(
+      bert_config, sequence_length=None, float_dtype=tf.float32)
+  sequence_output, pooled_output = transformer_encoder(
+      [input_word_ids, input_mask, input_type_ids])
+  # To keep consistent with legacy hub modules, the outputs are
+  # "pooled_output" and "sequence_output".
+  return tf.keras.Model(
+      inputs=[input_word_ids, input_mask, input_type_ids],
+      outputs=[pooled_output, sequence_output]), transformer_encoder


 def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""
-  core_model = create_bert_model(bert_config)
-  checkpoint = tf.train.Checkpoint(model=core_model)
+  core_model, encoder = create_bert_model(bert_config)
+  checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(
@@ -79,8 +81,8 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
   bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
-                    FLAGS.export_path, FLAGS.vocab_file)
+  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
+                    FLAGS.vocab_file)

 if __name__ == "__main__":
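The key behavioral change in this file: create_bert_model() now returns both the hub-compatible tf.keras.Model and the underlying TransformerEncoder, and the checkpoint is restored against the encoder so that assert_consumed() matches checkpoints written by the training code. A minimal sketch of the new two-value return, assuming a test-sized config (the vocab_size and intermediate_size values are made up for illustration):

import tensorflow as tf

from official.nlp import bert_modeling
from official.nlp.bert import export_tfhub

# Tiny config for illustration; vocab_size and intermediate_size are assumptions.
bert_config = bert_modeling.BertConfig(
    vocab_size=100, hidden_size=16, intermediate_size=32,
    max_position_embeddings=128, num_attention_heads=2, num_hidden_layers=1)

# New two-value return: the TF-Hub-facing Keras model plus the raw encoder.
core_model, encoder = export_tfhub.create_bert_model(bert_config)

# Restore against the encoder (not the wrapper model) so assert_consumed()
# passes for checkpoints produced by the training jobs.
checkpoint = tf.train.Checkpoint(model=encoder)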
official/nlp/bert/export_tfhub_test.py

@@ -39,9 +39,9 @@ class ExportTfhubTest(tf.test.TestCase):
         max_position_embeddings=128,
         num_attention_heads=2,
         num_hidden_layers=1)
-    bert_model = export_tfhub.create_bert_model(bert_config)
+    bert_model, encoder = export_tfhub.create_bert_model(bert_config)
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=bert_model)
+    checkpoint = tf.train.Checkpoint(model=encoder)
     checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
     model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
@@ -70,10 +70,17 @@ class ExportTfhubTest(tf.test.TestCase):
     dummy_ids = np.zeros((2, 10), dtype=np.int32)
     hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
     source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+    # The outputs of the hub module are "pooled_output" and "sequence_output",
+    # while the outputs of the encoder are in reversed order, i.e.,
+    # "sequence_output" and "pooled_output".
+    encoder_outputs = reversed(
+        encoder([dummy_ids, dummy_ids, dummy_ids]))
     self.assertEqual(hub_outputs[0].shape, (2, 16))
     self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
-    for source_output, hub_output in zip(source_outputs, hub_outputs):
+    for source_output, hub_output, encoder_output in zip(
+        source_outputs, hub_outputs, encoder_outputs):
       self.assertAllClose(source_output.numpy(), hub_output.numpy())
+      self.assertAllClose(source_output.numpy(), encoder_output.numpy())


 if __name__ == "__main__":
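The reversed() call exists because the hub-facing wrapper and the raw encoder order their outputs differently. A sketch of that ordering, reusing the core_model and encoder pair from create_bert_model above (batch of 2, sequence length 10, hidden size 16, as in this test):

import numpy as np

dummy_ids = np.zeros((2, 10), dtype=np.int32)

# Hub-facing model: [pooled_output, sequence_output].
pooled, sequence = core_model([dummy_ids, dummy_ids, dummy_ids])
assert pooled.shape == (2, 16)        # (batch, hidden_size)
assert sequence.shape == (2, 10, 16)  # (batch, seq_len, hidden_size)

# Raw TransformerEncoder: [sequence_output, pooled_output].
sequence_too, pooled_too = encoder([dummy_ids, dummy_ids, dummy_ids])
np.testing.assert_allclose(pooled_too.numpy(), pooled.numpy())
np.testing.assert_allclose(sequence_too.numpy(), sequence.numpy())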
official/nlp/bert_models.py

@@ -134,9 +134,9 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


-def _get_transformer_encoder(bert_config,
-                             sequence_length,
-                             float_dtype=tf.float32):
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.

   Args:
@@ -206,7 +206,7 @@ def pretrain_model(bert_config,
   next_sentence_labels = tf.keras.layers.Input(
       shape=(1,), name='next_sentence_labels', dtype=tf.int32)

-  transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
+  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
@@ -294,8 +294,8 @@ def squad_model(bert_config,
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)

   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
-                                            float_type)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
+                                           float_type)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder, initializer=initializer), bert_encoder
@@ -359,7 +359,7 @@ def classifier_model(bert_config,
         stddev=bert_config.initializer_range)

   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
     return bert_classifier.BertClassifier(
         bert_encoder,
         num_classes=num_labels,
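With the leading underscore dropped, get_transformer_encoder() becomes part of this module's public surface, which is what export_tfhub.py now calls. A minimal sketch of a direct call, with made-up config values:

import tensorflow as tf

from official.nlp import bert_modeling
from official.nlp import bert_models

# Made-up small config values for illustration.
bert_config = bert_modeling.BertConfig(
    vocab_size=100, hidden_size=16, intermediate_size=32,
    max_position_embeddings=128, num_attention_heads=2, num_hidden_layers=1)

# sequence_length=None builds the encoder over variable-length inputs,
# exactly as export_tfhub.py now does.
encoder = bert_models.get_transformer_encoder(
    bert_config, sequence_length=None, float_dtype=tf.float32)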