Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3635527d
"llm/ext_server/ext_server.cpp" did not exist on "0498f7ce56686bd44a8f92954daebe02352cdf82"
Commit
3635527d
authored
Nov 21, 2019
by
Chen Chen
Committed by
A. Unique TensorFlower
Nov 21, 2019
Browse files
Remove old tf2 BERT for squad
PiperOrigin-RevId: 281714406
parent
a2a1b66f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
39 deletions
+16
-39
official/nlp/bert/run_squad.py
official/nlp/bert/run_squad.py
+8
-10
official/nlp/bert_models.py
official/nlp/bert_models.py
+8
-29
No files found.
official/nlp/bert/run_squad.py
View file @
3635527d
...
...
@@ -81,9 +81,7 @@ flags.DEFINE_integer(
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.'
)
flags
.
DEFINE_bool
(
'use_keras_bert_for_squad'
,
False
,
'Whether to use keras BERT for squad '
'task. Note that when the FLAG "hub_module_url" is specified, '
'"use_keras_bert_for_squad" cannot be True.'
)
'use_keras_bert_for_squad'
,
True
,
'Deprecated and will be removed soon.'
)
common_flags
.
define_common_bert_flags
()
...
...
@@ -173,8 +171,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
input_meta_data
[
'max_seq_length'
],
float_type
=
tf
.
float32
,
use_keras_bert
=
FLAGS
.
use_keras_bert_for_squad
)
float_type
=
tf
.
float32
)
checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
logging
.
info
(
'Restoring checkpoints from %s'
,
checkpoint_path
)
...
...
@@ -242,9 +239,7 @@ def train_squad(strategy,
bert_config
,
max_seq_length
,
float_type
=
tf
.
float16
if
use_float16
else
tf
.
float32
,
hub_module_url
=
FLAGS
.
hub_module_url
,
use_keras_bert
=
False
if
FLAGS
.
hub_module_url
else
FLAGS
.
use_keras_bert_for_squad
)
hub_module_url
=
FLAGS
.
hub_module_url
)
squad_model
.
optimizer
=
optimization
.
create_optimizer
(
FLAGS
.
learning_rate
,
steps_per_epoch
*
epochs
,
warmup_steps
)
if
use_float16
:
...
...
@@ -370,8 +365,7 @@ def export_squad(model_export_path, input_meta_data):
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
input_meta_data
[
'max_seq_length'
],
float_type
=
tf
.
float32
,
use_keras_bert
=
FLAGS
.
use_keras_bert_for_squad
)
float_type
=
tf
.
float32
)
model_saving_utils
.
export_bert_model
(
model_export_path
,
model
=
squad_model
,
checkpoint_dir
=
FLAGS
.
model_dir
)
...
...
@@ -380,6 +374,10 @@ def main(_):
# Users should always run this script under TF 2.x
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
if
not
FLAGS
.
use_keras_bert_for_squad
:
raise
ValueError
(
'Old tf2 BERT is no longer supported. Please use keras BERT.'
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
...
...
official/nlp/bert_models.py
View file @
3635527d
...
...
@@ -366,8 +366,7 @@ def squad_model(bert_config,
max_seq_length
,
float_type
,
initializer
=
None
,
hub_module_url
=
None
,
use_keras_bert
=
False
):
hub_module_url
=
None
):
"""Returns BERT Squad model along with core BERT model to import weights.
Args:
...
...
@@ -377,23 +376,15 @@ def squad_model(bert_config,
initializer: Initializer for the final dense layer in the span labeler.
Defaulted to TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module.
use_keras_bert: Whether to use keras BERT. Note that when the above
'hub_module_url' is specified, 'use_keras_bert' cannot be True.
Returns:
A tuple of (1) keras model that outputs start logits and end logits and
(2) the core BERT transformer encoder.
Raises:
ValueError: When 'hub_module_url' is specified and 'use_keras_bert' is True.
"""
if
hub_module_url
and
use_keras_bert
:
raise
ValueError
(
'Cannot use hub_module_url and keras BERT at the same time.'
)
if
initializer
is
None
:
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
bert_config
.
initializer_range
)
if
use_keras_bert
:
if
not
hub_module_url
:
bert_encoder
=
_get_transformer_encoder
(
bert_config
,
max_seq_length
,
float_type
)
return
bert_span_labeler
.
BertSpanLabeler
(
...
...
@@ -405,24 +396,12 @@ def squad_model(bert_config,
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_mask'
)
input_type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
if
hub_module_url
:
core_model
=
hub
.
KerasLayer
(
hub_module_url
,
trainable
=
True
)
_
,
sequence_output
=
core_model
(
[
input_word_ids
,
input_mask
,
input_type_ids
])
# Sets the shape manually due to a bug in TF shape inference.
# TODO(hongkuny): remove this once shape inference is correct.
sequence_output
.
set_shape
((
None
,
max_seq_length
,
bert_config
.
hidden_size
))
else
:
core_model
=
modeling
.
get_bert_model
(
input_word_ids
,
input_mask
,
input_type_ids
,
config
=
bert_config
,
name
=
'bert_model'
,
float_type
=
float_type
)
# `BertSquadModel` only uses the sequence_output which
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output
=
core_model
.
outputs
[
1
]
squad_logits_layer
=
BertSquadLogitsLayer
(
initializer
=
initializer
,
float_type
=
float_type
,
name
=
'squad_logits'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment