Commit 2926ba69 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 295757618
parent 9737810f
@@ -85,8 +85,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss
-def get_transformer_encoder(bert_config,
-                            sequence_length):
+def get_transformer_encoder(bert_config, sequence_length):
   """Gets a 'TransformerEncoder' object.
   Args:
@@ -189,7 +188,8 @@ def pretrain_model(bert_config,
 def squad_model(bert_config,
                 max_seq_length,
                 initializer=None,
-                hub_module_url=None):
+                hub_module_url=None,
+                hub_module_trainable=True):
   """Returns BERT Squad model along with core BERT model to import weights.
   Args:
@@ -198,6 +198,7 @@ def squad_model(bert_config,
     initializer: Initializer for the final dense layer in the span labeler.
       Defaulted to TruncatedNormal initializer.
     hub_module_url: TF-Hub path/url to Bert module.
+    hub_module_trainable: True to finetune layers in the hub module.
   Returns:
     A tuple of (1) keras model that outputs start logits and end logits and
@@ -217,7 +218,7 @@ def squad_model(bert_config,
         shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
     input_type_ids = tf.keras.layers.Input(
         shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-    core_model = hub.KerasLayer(hub_module_url, trainable=True)
+    core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
     pooled_output, sequence_output = core_model(
         [input_word_ids, input_mask, input_type_ids])
     bert_encoder = tf.keras.Model(
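
The core of the change is visible in the hunk above: the hard-coded trainable=True on the hub.KerasLayer wrapping the hub module is replaced by the new hub_module_trainable argument. A minimal sketch of what that flag controls, assuming a TF2-style BERT module on TF-Hub (the URL below is only an example; substitute the module you actually use):

import tensorflow_hub as hub

# Example TF-Hub BERT module URL (an assumption for illustration).
HUB_URL = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'

# trainable=False freezes the hub module's weights: during training, gradients
# only update layers added on top of it (e.g. the task-specific head).
frozen_encoder = hub.KerasLayer(HUB_URL, trainable=False)

# trainable=True marks the module's weights as trainable, so the encoder is
# fine-tuned together with the head; this was the behavior hard-coded
# before this change.
finetuned_encoder = hub.KerasLayer(HUB_URL, trainable=True)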
@@ -236,20 +237,22 @@ def classifier_model(bert_config,
                      num_labels,
                      max_seq_length,
                      final_layer_initializer=None,
-                     hub_module_url=None):
+                     hub_module_url=None,
+                     hub_module_trainable=True):
   """BERT classifier model in functional API style.
   Construct a Keras model for predicting `num_labels` outputs from an input with
   maximum sequence length `max_seq_length`.
   Args:
-    bert_config: BertConfig or AlbertConfig, the config defines the core
-      BERT or ALBERT model.
+    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
+      ALBERT model.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
     final_layer_initializer: Initializer for final dense layer. Defaulted
       to TruncatedNormal initializer.
     hub_module_url: TF-Hub path/url to Bert module.
+    hub_module_trainable: True to finetune layers in the hub module.
   Returns:
     Combined prediction model (words, mask, type) -> (one-hot labels)
@@ -275,15 +278,14 @@ def classifier_model(bert_config,
         shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
     input_type_ids = tf.keras.layers.Input(
         shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-    bert_model = hub.KerasLayer(hub_module_url, trainable=True)
+    bert_model = hub.KerasLayer(
+        hub_module_url, trainable=hub_module_trainable)
     pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
     output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
         pooled_output)
     output = tf.keras.layers.Dense(
-        num_labels,
-        kernel_initializer=initializer,
-        name='output')(
+        num_labels, kernel_initializer=initializer, name='output')(
             output)
     return tf.keras.Model(
         inputs={
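
On the caller side, the new keyword simply threads through to that hub.KerasLayer. A hedged usage sketch for freezing the hub encoder in the classifier; the import paths and the BertConfig constructor below are assumptions about this repo's layout at this revision, not confirmed by the diff:

# Sketch only. Import paths and BertConfig location are assumptions;
# adjust to your checkout of the official models repo.
from official.nlp import bert_modeling
from official.nlp.bert import bert_models

bert_config = bert_modeling.BertConfig(vocab_size=30522)  # assumed constructor
classifier, hub_encoder = bert_models.classifier_model(
    bert_config,
    num_labels=2,
    max_seq_length=128,
    hub_module_url='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
    hub_module_trainable=False)  # freeze the hub encoder; only the head trains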
@@ -70,6 +70,8 @@ def define_common_bert_flags():
       'model_type', 'bert', ['bert', 'albert'],
       'Specifies the type of the model. '
       'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
+  flags.DEFINE_bool('hub_module_trainable', True,
+                    'True to make keras layers in the hub module trainable.')
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
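
The flag added here is a standard abseil boolean flag. A self-contained sketch of how such a flag behaves (a toy script for illustration, not the repo's actual entry point):

from absl import app, flags

flags.DEFINE_bool('hub_module_trainable', True,
                  'True to make keras layers in the hub module trainable.')
FLAGS = flags.FLAGS


def main(_):
  # abseil accepts --hub_module_trainable=false or the boolean shorthand
  # --nohub_module_trainable to flip the default.
  print('hub_module_trainable =', FLAGS.hub_module_trainable)


if __name__ == '__main__':
  app.run(main)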
@@ -124,7 +124,8 @@ def run_bert_classifier(strategy,
             bert_config,
             num_classes,
             max_seq_length,
-            hub_module_url=FLAGS.hub_module_url))
+            hub_module_url=FLAGS.hub_module_url,
+            hub_module_trainable=FLAGS.hub_module_trainable))
   classifier_model.optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps)
   if FLAGS.fp16_implementation == 'graph_rewrite':
@@ -253,7 +253,8 @@ def train_squad(strategy,
   squad_model, core_model = bert_models.squad_model(
       bert_config,
       max_seq_length,
-      hub_module_url=FLAGS.hub_module_url)
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=FLAGS.hub_module_trainable)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
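
Since both the flag and the new keyword arguments default to True, existing invocations of run_classifier and run_squad keep the previous behavior, where the hub module was always fine-tuned. Passing --nohub_module_trainable (or --hub_module_trainable=false) now freezes the hub encoder during fine-tuning, so only the task heads are trained.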