Commit 2926ba69 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 295757618
parent 9737810f
@@ -85,8 +85,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
   return final_loss
 
 
-def get_transformer_encoder(bert_config,
-                            sequence_length):
+def get_transformer_encoder(bert_config, sequence_length):
   """Gets a 'TransformerEncoder' object.
 
   Args:
@@ -189,7 +188,8 @@ def pretrain_model(bert_config,
 def squad_model(bert_config,
                 max_seq_length,
                 initializer=None,
-                hub_module_url=None):
+                hub_module_url=None,
+                hub_module_trainable=True):
   """Returns BERT Squad model along with core BERT model to import weights.
 
   Args:
@@ -198,6 +198,7 @@ def squad_model(bert_config,
     initializer: Initializer for the final dense layer in the span labeler.
       Defaulted to TruncatedNormal initializer.
     hub_module_url: TF-Hub path/url to Bert module.
+    hub_module_trainable: True to finetune layers in the hub module.
 
   Returns:
     A tuple of (1) keras model that outputs start logits and end logits and
@@ -217,7 +218,7 @@ def squad_model(bert_config,
         shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
     input_type_ids = tf.keras.layers.Input(
         shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-    core_model = hub.KerasLayer(hub_module_url, trainable=True)
+    core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
     pooled_output, sequence_output = core_model(
         [input_word_ids, input_mask, input_type_ids])
     bert_encoder = tf.keras.Model(
@@ -236,20 +237,22 @@ def classifier_model(bert_config,
                      num_labels,
                      max_seq_length,
                      final_layer_initializer=None,
-                     hub_module_url=None):
+                     hub_module_url=None,
+                     hub_module_trainable=True):
   """BERT classifier model in functional API style.
 
   Construct a Keras model for predicting `num_labels` outputs from an input with
   maximum sequence length `max_seq_length`.
 
   Args:
-    bert_config: BertConfig or AlbertConfig, the config defines the core
-      BERT or ALBERT model.
+    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
+      ALBERT model.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaulted
       TruncatedNormal initializer.
     hub_module_url: TF-Hub path/url to Bert module.
+    hub_module_trainable: True to finetune layers in the hub module.
 
   Returns:
     Combined prediction model (words, mask, type) -> (one-hot labels)
@@ -275,15 +278,14 @@ def classifier_model(bert_config,
         shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
     input_type_ids = tf.keras.layers.Input(
         shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-    bert_model = hub.KerasLayer(hub_module_url, trainable=True)
+    bert_model = hub.KerasLayer(
+        hub_module_url, trainable=hub_module_trainable)
     pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
     output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
         pooled_output)
     output = tf.keras.layers.Dense(
-        num_labels,
-        kernel_initializer=initializer,
-        name='output')(
+        num_labels, kernel_initializer=initializer, name='output')(
             output)
     return tf.keras.Model(
         inputs={
...
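The hunks above thread the new `hub_module_trainable` argument through both model builders to `hub.KerasLayer`. A minimal sketch of calling the updated `classifier_model` with a frozen hub encoder follows; the import paths, config path, and hub handle are illustrative assumptions, not part of this commit:

```python
# Sketch only: import paths and the hub handle are assumptions; adjust to
# the repo layout at this revision.
from official.nlp import bert_models
from official.nlp.bert import configs

bert_config = configs.BertConfig.from_json_file('bert_config.json')

# With hub_module_trainable=False the hub weights load but are excluded
# from gradient updates; only the dropout/dense head on top trains.
classifier, encoder = bert_models.classifier_model(
    bert_config,
    num_labels=2,
    max_seq_length=128,
    hub_module_url='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
    hub_module_trainable=False)
```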
@@ -70,6 +70,8 @@ def define_common_bert_flags():
       'model_type', 'bert', ['bert', 'albert'],
       'Specifies the type of the model. '
       'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
+  flags.DEFINE_bool('hub_module_trainable', True,
+                    'True to make keras layers in the hub module trainable.')
 
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
...
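Because `flags.DEFINE_bool` registers both the positive form and a `no`-prefixed negation, the new flag can be switched off from the command line. A small self-contained sketch of that absl behavior (the program name is a placeholder):

```python
# Self-contained sketch of the absl boolean-flag behavior the commit relies on.
from absl import flags

flags.DEFINE_bool('hub_module_trainable', True,
                  'True to make keras layers in the hub module trainable.')
FLAGS = flags.FLAGS

# absl accepts both --hub_module_trainable=false and the --no prefix form.
FLAGS(['prog', '--nohub_module_trainable'])
assert FLAGS.hub_module_trainable is False
```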
@@ -124,7 +124,8 @@ def run_bert_classifier(strategy,
           bert_config,
           num_classes,
           max_seq_length,
-          hub_module_url=FLAGS.hub_module_url))
+          hub_module_url=FLAGS.hub_module_url,
+          hub_module_trainable=FLAGS.hub_module_trainable))
   classifier_model.optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps)
   if FLAGS.fp16_implementation == 'graph_rewrite':
...
@@ -253,7 +253,8 @@ def train_squad(strategy,
   squad_model, core_model = bert_models.squad_model(
       bert_config,
       max_seq_length,
-      hub_module_url=FLAGS.hub_module_url)
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=FLAGS.hub_module_trainable)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
...
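The run-script hunks only forward the flag, so a SQuAD run started with `--nohub_module_trainable` builds its model around a frozen encoder. The direct equivalent of what `train_squad` constructs in that case, reusing the hedged setup from the sketch above (`max_seq_length=384` is a typical SQuAD setting, not a value from this commit):

```python
# Direct equivalent of train_squad's model construction when the flag is
# disabled; bert_config and the hub handle are the illustrative values from
# the earlier sketch.
squad_model, core_model = bert_models.squad_model(
    bert_config,
    max_seq_length=384,
    hub_module_url='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
    hub_module_trainable=False)  # only the span-labeling head remains trainable
```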