Commit 3635527d authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Remove old tf2 BERT for squad

PiperOrigin-RevId: 281714406
parent a2a1b66f
...@@ -81,9 +81,7 @@ flags.DEFINE_integer( ...@@ -81,9 +81,7 @@ flags.DEFINE_integer(
'The maximum length of an answer that can be generated. This is needed ' 'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.') 'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_bool( flags.DEFINE_bool(
'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad ' 'use_keras_bert_for_squad', True, 'Deprecated and will be removed soon.')
'task. Note that when the FLAG "hub_module_url" is specified, '
'"use_keras_bert_for_squad" cannot be True.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
...@@ -173,8 +171,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config, ...@@ -173,8 +171,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
squad_model, _ = bert_models.squad_model( squad_model, _ = bert_models.squad_model(
bert_config, bert_config,
input_meta_data['max_seq_length'], input_meta_data['max_seq_length'],
float_type=tf.float32, float_type=tf.float32)
use_keras_bert=FLAGS.use_keras_bert_for_squad)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir) checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path) logging.info('Restoring checkpoints from %s', checkpoint_path)
...@@ -242,9 +239,7 @@ def train_squad(strategy, ...@@ -242,9 +239,7 @@ def train_squad(strategy,
bert_config, bert_config,
max_seq_length, max_seq_length,
float_type=tf.float16 if use_float16 else tf.float32, float_type=tf.float16 if use_float16 else tf.float32,
hub_module_url=FLAGS.hub_module_url, hub_module_url=FLAGS.hub_module_url)
use_keras_bert=False
if FLAGS.hub_module_url else FLAGS.use_keras_bert_for_squad)
squad_model.optimizer = optimization.create_optimizer( squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps) FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16: if use_float16:
...@@ -370,8 +365,7 @@ def export_squad(model_export_path, input_meta_data): ...@@ -370,8 +365,7 @@ def export_squad(model_export_path, input_meta_data):
squad_model, _ = bert_models.squad_model( squad_model, _ = bert_models.squad_model(
bert_config, bert_config,
input_meta_data['max_seq_length'], input_meta_data['max_seq_length'],
float_type=tf.float32, float_type=tf.float32)
use_keras_bert=FLAGS.use_keras_bert_for_squad)
model_saving_utils.export_bert_model( model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir) model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
...@@ -380,6 +374,10 @@ def main(_): ...@@ -380,6 +374,10 @@ def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.') assert tf.version.VERSION.startswith('2.')
if not FLAGS.use_keras_bert_for_squad:
raise ValueError(
'Old tf2 BERT is no longer supported. Please use keras BERT.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -366,8 +366,7 @@ def squad_model(bert_config, ...@@ -366,8 +366,7 @@ def squad_model(bert_config,
max_seq_length, max_seq_length,
float_type, float_type,
initializer=None, initializer=None,
hub_module_url=None, hub_module_url=None):
use_keras_bert=False):
"""Returns BERT Squad model along with core BERT model to import weights. """Returns BERT Squad model along with core BERT model to import weights.
Args: Args:
...@@ -377,23 +376,15 @@ def squad_model(bert_config, ...@@ -377,23 +376,15 @@ def squad_model(bert_config,
initializer: Initializer for the final dense layer in the span labeler. initializer: Initializer for the final dense layer in the span labeler.
Defaulted to TruncatedNormal initializer. Defaulted to TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module. hub_module_url: TF-Hub path/url to Bert module.
use_keras_bert: Whether to use keras BERT. Note that when the above
'hub_module_url' is specified, 'use_keras_bert' cannot be True.
Returns: Returns:
A tuple of (1) keras model that outputs start logits and end logits and A tuple of (1) keras model that outputs start logits and end logits and
(2) the core BERT transformer encoder. (2) the core BERT transformer encoder.
Raises:
ValueError: When 'hub_module_url' is specified and 'use_keras_bert' is True.
""" """
if hub_module_url and use_keras_bert:
raise ValueError(
'Cannot use hub_module_url and keras BERT at the same time.')
if initializer is None: if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal( initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range) stddev=bert_config.initializer_range)
if use_keras_bert: if not hub_module_url:
bert_encoder = _get_transformer_encoder(bert_config, max_seq_length, bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
float_type) float_type)
return bert_span_labeler.BertSpanLabeler( return bert_span_labeler.BertSpanLabeler(
...@@ -405,24 +396,12 @@ def squad_model(bert_config, ...@@ -405,24 +396,12 @@ def squad_model(bert_config,
shape=(max_seq_length,), dtype=tf.int32, name='input_mask') shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input( input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids') shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
if hub_module_url: core_model = hub.KerasLayer(hub_module_url, trainable=True)
core_model = hub.KerasLayer(hub_module_url, trainable=True) _, sequence_output = core_model(
_, sequence_output = core_model( [input_word_ids, input_mask, input_type_ids])
[input_word_ids, input_mask, input_type_ids]) # Sets the shape manually due to a bug in TF shape inference.
# Sets the shape manually due to a bug in TF shape inference. # TODO(hongkuny): remove this once shape inference is correct.
# TODO(hongkuny): remove this once shape inference is correct. sequence_output.set_shape((None, max_seq_length, bert_config.hidden_size))
sequence_output.set_shape((None, max_seq_length, bert_config.hidden_size))
else:
core_model = modeling.get_bert_model(
input_word_ids,
input_mask,
input_type_ids,
config=bert_config,
name='bert_model',
float_type=float_type)
    # `BertSquadModel` only uses the sequence_output which
# has dimensionality (batch_size, sequence_length, num_hidden).
sequence_output = core_model.outputs[1]
squad_logits_layer = BertSquadLogitsLayer( squad_logits_layer = BertSquadLogitsLayer(
initializer=initializer, float_type=float_type, name='squad_logits') initializer=initializer, float_type=float_type, name='squad_logits')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment