Commit f7fd59b8 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 284792715
parent 558bab5d
@@ -65,6 +65,10 @@ def define_common_bert_flags():
   flags.DEFINE_string(
       'hub_module_url', None, 'TF-Hub path/url to Bert module. '
       'If specified, init_checkpoint flag should not be used.')
+  flags.DEFINE_enum(
+      'model_type', 'bert', ['bert', 'albert'],
+      'Specifies the type of the model. '
+      'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
 
   # Adds flags for mixed precision training.
   flags_core.define_performance(
...
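A minimal standalone sketch of the flag added above, assuming only absl-py; the demo script and its main() below are hypothetical and not part of this change. It shows how the enum flag restricts the accepted values to 'bert' or 'albert':

# Hypothetical demo, not part of this commit: exercises the new
# 'model_type' enum flag in isolation using absl-py.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    'model_type', 'bert', ['bert', 'albert'],
    'Specifies the type of the model. '
    'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')


def main(_):
  # absl rejects any value outside ['bert', 'albert'] before main() runs.
  print('model_type:', FLAGS.model_type)


if __name__ == '__main__':
  app.run(main)  # e.g. python flag_demo.py --model_type=albert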
@@ -287,7 +287,11 @@ def run_bert(strategy,
              train_input_fn=None,
              eval_input_fn=None):
   """Run BERT training."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  if FLAGS.model_type == 'bert':
+    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  else:
+    assert FLAGS.model_type == 'albert'
+    bert_config = modeling.AlbertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
...
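The same selection logic can be read as a small helper; this is only a hedged sketch, the helper name _load_bert_config is hypothetical, and `modeling` is assumed to be the `official.nlp.bert_modeling` module, which provides both config classes with a from_json_file() classmethod as used in the hunk above.

# Hypothetical refactor of the branch above; not part of this commit.
from official.nlp import bert_modeling as modeling


def _load_bert_config(model_type, config_file):
  """Returns the config object matching the --model_type flag value."""
  if model_type == 'bert':
    return modeling.BertConfig.from_json_file(config_file)
  assert model_type == 'albert'
  return modeling.AlbertConfig.from_json_file(config_file)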
@@ -22,6 +22,7 @@ import tensorflow as tf
 import tensorflow_hub as hub
 
 from official.modeling import tf_utils
+from official.nlp import bert_modeling
 from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
@@ -139,14 +140,14 @@ def _get_transformer_encoder(bert_config,
   """Gets a 'TransformerEncoder' object.
 
   Args:
-    bert_config: A 'modeling.BertConfig' object.
+    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
     float_dtype: tf.dtype, tf.float32 or tf.float16.
 
   Returns:
     A networks.TransformerEncoder object.
   """
-  return networks.TransformerEncoder(
+  kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
       num_layers=bert_config.num_hidden_layers,
@@ -161,6 +162,12 @@ def _get_transformer_encoder(bert_config,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
       float_dtype=float_dtype.name)
+  if isinstance(bert_config, bert_modeling.AlbertConfig):
+    kwargs['embedding_width'] = bert_config.embedding_size
+    return networks.AlbertTransformerEncoder(**kwargs)
+  else:
+    assert isinstance(bert_config, bert_modeling.BertConfig)
+    return networks.TransformerEncoder(**kwargs)
 
 
 def pretrain_model(bert_config,
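A self-contained sketch of the kwargs-dict dispatch used above, with stand-in types; DummyBertConfig, DummyAlbertConfig, and build_encoder are illustrative names, not the real official.nlp classes.

# Illustrative only: mirrors the isinstance dispatch above with dummy types.
import dataclasses


@dataclasses.dataclass
class DummyBertConfig:
  vocab_size: int = 30522
  hidden_size: int = 768


@dataclasses.dataclass
class DummyAlbertConfig(DummyBertConfig):
  # ALBERT factorizes the embedding, hence the extra width parameter.
  embedding_size: int = 128


def build_encoder(config):
  kwargs = dict(vocab_size=config.vocab_size, hidden_size=config.hidden_size)
  if isinstance(config, DummyAlbertConfig):
    # Check the more specific type first; this keeps the dispatch correct
    # if one config class subclasses the other.
    kwargs['embedding_width'] = config.embedding_size
    return 'AlbertTransformerEncoder', kwargs
  assert isinstance(config, DummyBertConfig)
  return 'TransformerEncoder', kwargs


print(build_encoder(DummyAlbertConfig()))  # -> ('AlbertTransformerEncoder', ...)
print(build_encoder(DummyBertConfig()))    # -> ('TransformerEncoder', ...)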
@@ -332,7 +339,8 @@ def classifier_model(bert_config,
     maximum sequence length `max_seq_length`.
 
   Args:
-    bert_config: BertConfig, the config defines the core BERT model.
+    bert_config: BertConfig or AlbertConfig, the config defines the core
+      BERT or ALBERT model.
     float_type: dtype, tf.float32 or tf.bfloat16.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
...