Commit a76237da authored by Rajagopal Ananthanarayanan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 280019807
parent 95dc9045
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
 import json
 import math
 import os
@@ -31,6 +32,7 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.nlp import bert_modeling as modeling
+from official.nlp.bert import input_pipeline
 from official.nlp.bert import run_classifier
 from official.utils.misc import distribution_utils
@@ -76,6 +78,19 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     steps_per_loop = 1
+    max_seq_length = input_meta_data['max_seq_length']
+    train_input_fn = functools.partial(
+        input_pipeline.create_classifier_dataset,
+        FLAGS.train_data_path,
+        seq_length=max_seq_length,
+        batch_size=FLAGS.train_batch_size)
+    eval_input_fn = functools.partial(
+        input_pipeline.create_classifier_dataset,
+        FLAGS.eval_data_path,
+        seq_length=max_seq_length,
+        batch_size=FLAGS.eval_batch_size,
+        is_training=False,
+        drop_remainder=False)
     run_classifier.run_bert_classifier(
         strategy,
         bert_config,
@@ -88,6 +103,8 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
         warmup_steps,
         FLAGS.learning_rate,
         FLAGS.init_checkpoint,
+        train_input_fn,
+        eval_input_fn,
         custom_callbacks=callbacks)
...
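
The benchmark hunk above pre-binds all dataset arguments with functools.partial, so that run_bert_classifier only ever sees zero-argument callables. A minimal, self-contained sketch of that injection pattern follows; make_dataset and train are hypothetical stand-ins for input_pipeline.create_classifier_dataset and run_bert_classifier, and the paths and sizes are illustrative only:

```python
import functools

def make_dataset(data_path, seq_length, batch_size, is_training=True):
  """Hypothetical stand-in for input_pipeline.create_classifier_dataset."""
  print('building dataset from %s (seq_length=%d, batch_size=%d, training=%s)'
        % (data_path, seq_length, batch_size, is_training))
  return [(data_path, i) for i in range(batch_size)]  # placeholder records

def train(train_input_fn, eval_input_fn):
  """Hypothetical stand-in for run_bert_classifier: it only invokes the
  zero-argument callables and never touches FLAGS or file paths."""
  train_batch = train_input_fn()
  eval_batch = eval_input_fn()
  return len(train_batch), len(eval_batch)

# Pre-bind the dataset arguments, exactly as the hunk above does with FLAGS.
train_input_fn = functools.partial(
    make_dataset, '/tmp/train.tf_record', seq_length=128, batch_size=32)
eval_input_fn = functools.partial(
    make_dataset, '/tmp/eval.tf_record', seq_length=128, batch_size=32,
    is_training=False)

print(train(train_input_fn, eval_input_fn))  # -> (32, 32)
```

The same two partials are built again in main() at the bottom of this diff; moving their construction out to the callers is what lets the benchmark substitute its own data paths and batch sizes.
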
@@ -91,6 +91,8 @@ def run_bert_classifier(strategy,
                         warmup_steps,
                         initial_lr,
                         init_checkpoint,
+                        train_input_fn,
+                        eval_input_fn,
                         custom_callbacks=None,
                         run_eagerly=False,
                         use_keras_compile_fit=False):
@@ -98,19 +100,6 @@ def run_bert_classifier(strategy,
   max_seq_length = input_meta_data['max_seq_length']
   num_classes = input_meta_data['num_labels']

-  train_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
-      FLAGS.train_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.train_batch_size)
-  eval_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
-      FLAGS.eval_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.eval_batch_size,
-      is_training=False,
-      drop_remainder=False)
   def _get_classifier_model():
     """Gets a classifier model."""
     classifier_model, core_model = (
@@ -153,7 +142,7 @@ def run_bert_classifier(strategy,
   if use_keras_compile_fit:
     # Start training using Keras compile/fit API.
     logging.info('Training using TF 2.0 Keras compile/fit API with '
-                 'distrubuted strategy.')
+                 'distribution strategy.')
     return run_keras_compile_fit(
         model_dir,
         strategy,
@@ -170,7 +159,7 @@ def run_bert_classifier(strategy,
   # Use user-defined loop to start training.
   logging.info('Training using customized training loop TF 2.0 with '
-               'distrubuted strategy.')
+               'distribution strategy.')
   return model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_classifier_model,
@@ -237,7 +226,8 @@ def run_keras_compile_fit(model_dir,
 def export_classifier(model_export_path, input_meta_data,
-                      restore_model_using_load_weights):
+                      restore_model_using_load_weights,
+                      bert_config, model_dir):
   """Exports a trained model as a `SavedModel` for inference.

   Args:
@@ -249,15 +239,19 @@ def export_classifier(model_export_path, input_meta_data,
       tf.train.Checkpoint and another is using Keras model.save_weights().
       Custom training loop implementation uses tf.train.Checkpoint API
       and Keras ModelCheckpoint callback internally uses model.save_weights()
-      API. Since these two API's cannot be used toghether, model loading logic
+      API. Since these two APIs cannot be used together, model loading logic
       must take into account how the model checkpoint was saved.
+    bert_config: Bert configuration file to define core bert layers.
+    model_dir: The directory where the model weights and training/evaluation
+      summaries are stored.

   Raises:
     Export path is not specified, got an empty string or None.
   """
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  if not model_dir:
+    raise ValueError('Model directory is not specified: %s' % model_dir)

   classifier_model = bert_models.classifier_model(
       bert_config, tf.float32, input_meta_data['num_labels'],
@@ -266,18 +260,20 @@ def export_classifier(model_export_path, input_meta_data,
   model_saving_utils.export_bert_model(
       model_export_path,
       model=classifier_model,
-      checkpoint_dir=FLAGS.model_dir,
+      checkpoint_dir=model_dir,
       restore_model_using_load_weights=restore_model_using_load_weights)
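
The docstring above is the heart of the restore_model_using_load_weights flag: weights saved by the Keras ModelCheckpoint callback (model.save_weights) must be read back with model.load_weights, while a checkpoint written by the custom training loop must be read back through tf.train.Checkpoint. The actual branching lives in model_saving_utils.export_bert_model, which this diff does not show; the sketch below is an assumption about its shape, using only standard TensorFlow APIs:

```python
import tensorflow as tf

def restore_for_export(model, checkpoint_dir, restore_model_using_load_weights):
  """Sketch: pick the restore path that matches how the checkpoint was saved."""
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  if restore_model_using_load_weights:
    # Checkpoint was written by Keras ModelCheckpoint via model.save_weights().
    model.load_weights(checkpoint_path)
  else:
    # Checkpoint was written by the custom loop via tf.train.Checkpoint.
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
  return model
```
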
-def run_bert(strategy, input_meta_data):
+def run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn):
   """Run BERT training."""
+  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
     # use model.load_weights() when Keras compile/fit() is used.
     export_classifier(FLAGS.model_export_path, input_meta_data,
-                      FLAGS.use_keras_compile_fit)
+                      FLAGS.use_keras_compile_fit,
+                      bert_config, FLAGS.model_dir)
     return

   if FLAGS.mode != 'train_and_eval':
@@ -285,7 +281,6 @@ def run_bert(strategy, input_meta_data):
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_config_v2(FLAGS.enable_xla)

-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
   train_data_size = input_meta_data['train_data_size']
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
@@ -308,6 +303,8 @@ def run_bert(strategy, input_meta_data):
       warmup_steps,
       FLAGS.learning_rate,
       FLAGS.init_checkpoint,
+      train_input_fn,
+      eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
       use_keras_compile_fit=FLAGS.use_keras_compile_fit)
@@ -341,7 +338,21 @@ def main(_):
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
-  run_bert(strategy, input_meta_data)
+
+  max_seq_length = input_meta_data['max_seq_length']
+  train_input_fn = functools.partial(
+      input_pipeline.create_classifier_dataset,
+      FLAGS.train_data_path,
+      seq_length=max_seq_length,
+      batch_size=FLAGS.train_batch_size)
+  eval_input_fn = functools.partial(
+      input_pipeline.create_classifier_dataset,
+      FLAGS.eval_data_path,
+      seq_length=max_seq_length,
+      batch_size=FLAGS.eval_batch_size,
+      is_training=False,
+      drop_remainder=False)
+  run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)

 if __name__ == '__main__':
...
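
A payoff of threading train_input_fn and eval_input_fn through run_bert: any dataset-returning callable can now be injected without touching the training code. As an illustration only (not part of this change), a test or benchmark could pass a synthetic input_fn; the feature keys below follow the usual BERT classifier convention, while the shapes and values are made up:

```python
import tensorflow as tf

def synthetic_input_fn(batch_size=32, seq_length=128, num_batches=4):
  """Zero-argument-callable returning a fake classifier dataset."""
  features = {
      'input_word_ids': tf.zeros([batch_size, seq_length], tf.int32),
      'input_mask': tf.ones([batch_size, seq_length], tf.int32),
      'input_type_ids': tf.zeros([batch_size, seq_length], tf.int32),
  }
  labels = tf.zeros([batch_size], tf.int32)
  return tf.data.Dataset.from_tensors((features, labels)).repeat(num_batches)

# Hypothetical usage, mirroring the new call in main():
# run_bert(strategy, input_meta_data, synthetic_input_fn, synthetic_input_fn)
```
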