Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3635527d
Commit
3635527d
authored
Nov 21, 2019
by
Chen Chen
Committed by
A. Unique TensorFlower
Nov 21, 2019
Browse files
Remove old tf2 BERT for squad
PiperOrigin-RevId: 281714406
parent
a2a1b66f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
39 deletions
+16
-39
official/nlp/bert/run_squad.py
official/nlp/bert/run_squad.py
+8
-10
official/nlp/bert_models.py
official/nlp/bert_models.py
+8
-29
No files found.
official/nlp/bert/run_squad.py
View file @
3635527d
...
@@ -81,9 +81,7 @@ flags.DEFINE_integer(
...
@@ -81,9 +81,7 @@ flags.DEFINE_integer(
'The maximum length of an answer that can be generated. This is needed '
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.'
)
'because the start and end predictions are not conditioned on one another.'
)
flags
.
DEFINE_bool
(
flags
.
DEFINE_bool
(
'use_keras_bert_for_squad'
,
False
,
'Whether to use keras BERT for squad '
'use_keras_bert_for_squad'
,
True
,
'Deprecated and will be removed soon.'
)
'task. Note that when the FLAG "hub_module_url" is specified, '
'"use_keras_bert_for_squad" cannot be True.'
)
common_flags
.
define_common_bert_flags
()
common_flags
.
define_common_bert_flags
()
...
@@ -173,8 +171,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
...
@@ -173,8 +171,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
squad_model
,
_
=
bert_models
.
squad_model
(
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
bert_config
,
input_meta_data
[
'max_seq_length'
],
input_meta_data
[
'max_seq_length'
],
float_type
=
tf
.
float32
,
float_type
=
tf
.
float32
)
use_keras_bert
=
FLAGS
.
use_keras_bert_for_squad
)
checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
logging
.
info
(
'Restoring checkpoints from %s'
,
checkpoint_path
)
logging
.
info
(
'Restoring checkpoints from %s'
,
checkpoint_path
)
...
@@ -242,9 +239,7 @@ def train_squad(strategy,
...
@@ -242,9 +239,7 @@ def train_squad(strategy,
bert_config
,
bert_config
,
max_seq_length
,
max_seq_length
,
float_type
=
tf
.
float16
if
use_float16
else
tf
.
float32
,
float_type
=
tf
.
float16
if
use_float16
else
tf
.
float32
,
hub_module_url
=
FLAGS
.
hub_module_url
,
hub_module_url
=
FLAGS
.
hub_module_url
)
use_keras_bert
=
False
if
FLAGS
.
hub_module_url
else
FLAGS
.
use_keras_bert_for_squad
)
squad_model
.
optimizer
=
optimization
.
create_optimizer
(
squad_model
.
optimizer
=
optimization
.
create_optimizer
(
FLAGS
.
learning_rate
,
steps_per_epoch
*
epochs
,
warmup_steps
)
FLAGS
.
learning_rate
,
steps_per_epoch
*
epochs
,
warmup_steps
)
if
use_float16
:
if
use_float16
:
...
@@ -370,8 +365,7 @@ def export_squad(model_export_path, input_meta_data):
...
@@ -370,8 +365,7 @@ def export_squad(model_export_path, input_meta_data):
squad_model
,
_
=
bert_models
.
squad_model
(
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
bert_config
,
input_meta_data
[
'max_seq_length'
],
input_meta_data
[
'max_seq_length'
],
float_type
=
tf
.
float32
,
float_type
=
tf
.
float32
)
use_keras_bert
=
FLAGS
.
use_keras_bert_for_squad
)
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
model_export_path
,
model
=
squad_model
,
checkpoint_dir
=
FLAGS
.
model_dir
)
model_export_path
,
model
=
squad_model
,
checkpoint_dir
=
FLAGS
.
model_dir
)
...
@@ -380,6 +374,10 @@ def main(_):
...
@@ -380,6 +374,10 @@ def main(_):
# Users should always run this script under TF 2.x
# Users should always run this script under TF 2.x
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
if
not
FLAGS
.
use_keras_bert_for_squad
:
raise
ValueError
(
'Old tf2 BERT is no longer supported. Please use keras BERT.'
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
...
...
official/nlp/bert_models.py
View file @
3635527d
...
@@ -366,8 +366,7 @@ def squad_model(bert_config,
...
@@ -366,8 +366,7 @@ def squad_model(bert_config,
max_seq_length
,
max_seq_length
,
float_type
,
float_type
,
initializer
=
None
,
initializer
=
None
,
hub_module_url
=
None
,
hub_module_url
=
None
):
use_keras_bert
=
False
):
"""Returns BERT Squad model along with core BERT model to import weights.
"""Returns BERT Squad model along with core BERT model to import weights.
Args:
Args:
...
@@ -377,23 +376,15 @@ def squad_model(bert_config,
...
@@ -377,23 +376,15 @@ def squad_model(bert_config,
initializer: Initializer for the final dense layer in the span labeler.
initializer: Initializer for the final dense layer in the span labeler.
Defaulted to TruncatedNormal initializer.
Defaulted to TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module.
hub_module_url: TF-Hub path/url to Bert module.
use_keras_bert: Whether to use keras BERT. Note that when the above
'hub_module_url' is specified, 'use_keras_bert' cannot be True.
Returns:
Returns:
A tuple of (1) keras model that outputs start logits and end logits and
A tuple of (1) keras model that outputs start logits and end logits and
(2) the core BERT transformer encoder.
(2) the core BERT transformer encoder.
Raises:
ValueError: When 'hub_module_url' is specified and 'use_keras_bert' is True.
"""
"""
if
hub_module_url
and
use_keras_bert
:
raise
ValueError
(
'Cannot use hub_module_url and keras BERT at the same time.'
)
if
initializer
is
None
:
if
initializer
is
None
:
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
bert_config
.
initializer_range
)
stddev
=
bert_config
.
initializer_range
)
if
use_keras_bert
:
if
not
hub_module_url
:
bert_encoder
=
_get_transformer_encoder
(
bert_config
,
max_seq_length
,
bert_encoder
=
_get_transformer_encoder
(
bert_config
,
max_seq_length
,
float_type
)
float_type
)
return
bert_span_labeler
.
BertSpanLabeler
(
return
bert_span_labeler
.
BertSpanLabeler
(
...
@@ -405,24 +396,12 @@ def squad_model(bert_config,
...
@@ -405,24 +396,12 @@ def squad_model(bert_config,
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_mask'
)
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_mask'
)
input_type_ids
=
tf
.
keras
.
layers
.
Input
(
input_type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
if
hub_module_url
:
core_model
=
hub
.
KerasLayer
(
hub_module_url
,
trainable
=
True
)
core_model
=
hub
.
KerasLayer
(
hub_module_url
,
trainable
=
True
)
_
,
sequence_output
=
core_model
(
_
,
sequence_output
=
core_model
(
[
input_word_ids
,
input_mask
,
input_type_ids
])
[
input_word_ids
,
input_mask
,
input_type_ids
])
# Sets the shape manually due to a bug in TF shape inference.
# Sets the shape manually due to a bug in TF shape inference.
# TODO(hongkuny): remove this once shape inference is correct.
# TODO(hongkuny): remove this once shape inference is correct.
sequence_output
.
set_shape
((
None
,
max_seq_length
,
bert_config
.
hidden_size
))
sequence_output
.
set_shape
((
None
,
max_seq_length
,
bert_config
.
hidden_size
))
else
:
core_model
=
modeling
.
get_bert_model
(
input_word_ids
,
input_mask
,
input_type_ids
,
config
=
bert_config
,
name
=
'bert_model'
,
float_type
=
float_type
)
# `BertSquadModel` only uses the sequnce_output which
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output
=
core_model
.
outputs
[
1
]
squad_logits_layer
=
BertSquadLogitsLayer
(
squad_logits_layer
=
BertSquadLogitsLayer
(
initializer
=
initializer
,
float_type
=
float_type
,
name
=
'squad_logits'
)
initializer
=
initializer
,
float_type
=
float_type
,
name
=
'squad_logits'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment