Commit 81d031d0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 281117886
parent c1ac2bfc
@@ -41,16 +41,13 @@ def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
 def _get_input_iterator(input_fn, strategy):
   """Returns distributed dataset iterator."""
   # When training with TPU pods, datasets needs to be cloned across
   # workers. Since Dataset instance cannot be cloned in eager mode, we instead
   # pass callable that returns a dataset.
-  input_data = input_fn()
-  if callable(input_data):
-    iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(input_data))
-  else:
-    iterator = iter(strategy.experimental_distribute_dataset(input_data))
+  if not callable(input_fn):
+    raise ValueError('`input_fn` should be a closure that returns a dataset.')
+  iterator = iter(
+      strategy.experimental_distribute_datasets_from_function(input_fn))
   return iterator
...
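Note: after this change, `_get_input_iterator` only accepts a callable that returns a dataset and always distributes it through `strategy.experimental_distribute_datasets_from_function`. Below is a minimal sketch of a compatible closure; the toy in-memory dataset and the `MirroredStrategy` are illustrative and not part of the commit.

```python
import tensorflow as tf


def make_dataset_fn(global_batch_size):
  """Builds a closure matching the new `_get_input_iterator` contract."""

  def dataset_fn(input_context):
    # The strategy passes an InputContext; batch with the per-replica size.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.range(1024).map(lambda x: tf.cast(x, tf.float32))
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.repeat().batch(batch_size, drop_remainder=True)

  return dataset_fn


strategy = tf.distribute.MirroredStrategy()
iterator = iter(
    strategy.experimental_distribute_datasets_from_function(
        make_dataset_fn(global_batch_size=32)))
batch = next(iterator)  # One per-replica batch per replica in sync.
```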
@@ -66,12 +66,15 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
     An input function that is usable in the executor.
   """

-  def _input_fn():
+  def _dataset_fn(input_context=None):
     """An input function for generating fake data."""
+    local_batch_size = input_context.get_per_replica_batch_size(batch_size)
     features = np.random.rand(64, *features_shape)
     labels = np.random.randint(2, size=[64, num_classes])

     # Convert the inputs to a Dataset.
     dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)

     def _assign_dtype(features, labels):
       features = tf.cast(features, tf.float32)
@@ -81,11 +84,11 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
     # Shuffle, repeat, and batch the examples.
     dataset = dataset.map(_assign_dtype)
     dataset = dataset.shuffle(64).repeat()
-    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dataset = dataset.batch(local_batch_size, drop_remainder=True)
     dataset = dataset.prefetch(buffer_size=64)
     return dataset

-  return _input_fn
+  return _dataset_fn


 def create_model_fn(input_shape, num_classes, use_float16=False):
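Note: the reworked `_dataset_fn` above expects a `tf.distribute.InputContext`, so the fake data is sharded per input pipeline and batched with the per-replica batch size. A hedged sketch of exercising the closure directly by constructing an `InputContext` by hand; the two-host values below are illustrative, not from the commit.

```python
import tensorflow as tf

# Hypothetical direct call to the closure returned by create_fake_data_input_fn,
# with a hand-built InputContext standing in for what a strategy would pass
# (two input pipelines, two replicas in sync, this worker being host 0).
dataset_fn = create_fake_data_input_fn(
    batch_size=8, features_shape=[128], num_classes=3)
ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=2)
dataset = dataset_fn(ctx)

features, labels = next(iter(dataset))
# Per-replica batch of 8 // 2 = 4 examples on this host.
assert features.shape[0] == ctx.get_per_replica_batch_size(8)
```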
@@ -134,21 +137,21 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super(ModelTrainingUtilsTest, self).setUp()
-    self._input_fn = create_fake_data_input_fn(
-        batch_size=8, features_shape=[128], num_classes=3)
     self._model_fn = create_model_fn(input_shape=[128], num_classes=3)

-  def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
+  def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
+    input_fn = create_fake_data_input_fn(
+        batch_size=8, features_shape=[128], num_classes=3)
     model_training_utils.run_customized_training_loop(
-        strategy=distribution,
+        strategy=strategy,
         model_fn=self._model_fn,
         loss_fn=tf.keras.losses.categorical_crossentropy,
         model_dir=model_dir,
         steps_per_epoch=20,
         steps_per_loop=steps_per_loop,
         epochs=2,
-        train_input_fn=self._input_fn,
-        eval_input_fn=self._input_fn,
+        train_input_fn=input_fn,
+        eval_input_fn=input_fn,
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
...
@@ -36,27 +36,22 @@ def decode_record(record, name_to_features):
   return example


-def file_based_input_fn_builder(input_file, name_to_features):
-  """Creates an `input_fn` closure to be passed for BERT custom training."""
-
-  def input_fn():
-    """Returns dataset for training/evaluation."""
-    # For training, we want a lot of parallel reading and shuffling.
-    # For eval, we want no shuffling and parallel reading doesn't matter.
-    d = tf.data.TFRecordDataset(input_file)
-    d = d.map(lambda record: decode_record(record, name_to_features))
-
-    # When `input_file` is a path to a single file or a list
-    # containing a single path, disable auto sharding so that
-    # same input file is sent to all workers.
-    if isinstance(input_file, str) or len(input_file) == 1:
-      options = tf.data.Options()
-      options.experimental_distribute.auto_shard_policy = (
-          tf.data.experimental.AutoShardPolicy.OFF)
-      d = d.with_options(options)
-    return d
-
-  return input_fn
+def single_file_dataset(input_file, name_to_features):
+  """Creates a single-file dataset to be passed for BERT custom training."""
+  # For training, we want a lot of parallel reading and shuffling.
+  # For eval, we want no shuffling and parallel reading doesn't matter.
+  d = tf.data.TFRecordDataset(input_file)
+  d = d.map(lambda record: decode_record(record, name_to_features))
+
+  # When `input_file` is a path to a single file or a list
+  # containing a single path, disable auto sharding so that
+  # same input file is sent to all workers.
+  if isinstance(input_file, str) or len(input_file) == 1:
+    options = tf.data.Options()
+    options.experimental_distribute.auto_shard_policy = (
+        tf.data.experimental.AutoShardPolicy.OFF)
+    d = d.with_options(options)
+  return d


 def create_pretrain_dataset(input_patterns,
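Note: a single TFRecord file cannot be split by tf.data auto-sharding, so `single_file_dataset` switches it off and leaves sharding to the callers shown below. A standalone sketch of the same pattern, assuming a placeholder file path and a two-host setup (neither is from the commit):

```python
import tensorflow as tf

# Placeholder path; in the BERT pipeline this is a single train/eval TFRecord.
d = tf.data.TFRecordDataset('/tmp/train.tf_record')

# A single file cannot be auto-sharded, so switch auto-sharding off and let
# every worker open the same file.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
d = d.with_options(options)

# Sharding then happens explicitly per host, as create_classifier_dataset and
# create_squad_dataset do below when given an input_pipeline_context
# (two hosts assumed here, this worker being host 0).
d = d.shard(2, 0)
```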
@@ -142,7 +137,7 @@ def create_classifier_dataset(file_path,
                               seq_length,
                               batch_size,
                               is_training=True,
-                              drop_remainder=True):
+                              input_pipeline_context=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -151,8 +146,13 @@ def create_classifier_dataset(file_path,
       'label_ids': tf.io.FixedLenFeature([], tf.int64),
       'is_real_example': tf.io.FixedLenFeature([], tf.int64),
   }
-  input_fn = file_based_input_fn_builder(file_path, name_to_features)
-  dataset = input_fn()
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)

   def _select_data_from_record(record):
     x = {
@@ -169,12 +169,16 @@ def create_classifier_dataset(file_path,
     dataset = dataset.shuffle(100)
     dataset = dataset.repeat()
-  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+  dataset = dataset.batch(batch_size, drop_remainder=is_training)
   dataset = dataset.prefetch(1024)
   return dataset

-def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
+def create_squad_dataset(file_path,
+                         seq_length,
+                         batch_size,
+                         is_training=True,
+                         input_pipeline_context=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -187,8 +191,13 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   else:
     name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)

-  input_fn = file_based_input_fn_builder(file_path, name_to_features)
-  dataset = input_fn()
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)

   def _select_data_from_record(record):
     """Dispatches record to features and labels."""
...
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
 import json
 import math
 import os
@@ -80,6 +79,25 @@ def get_loss_fn(num_classes, loss_factor=1.0):
   return classification_loss_fn


+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+                   is_training):
+  """Gets a closure to create a dataset."""
+
+  def _dataset_fn(ctx=None):
+    """Returns tf.data.Dataset for distributed BERT pretraining."""
+    batch_size = ctx.get_per_replica_batch_size(
+        global_batch_size) if ctx else global_batch_size
+    dataset = input_pipeline.create_classifier_dataset(
+        input_file_pattern,
+        max_seq_length,
+        batch_size,
+        is_training=is_training,
+        input_pipeline_context=ctx)
+    return dataset
+
+  return _dataset_fn
+
+
 def run_bert_classifier(strategy,
                         bert_config,
                         input_meta_data,
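Note: the closure returned by `get_dataset_fn` works both under a distribution strategy, which passes a `tf.distribute.InputContext` as `ctx`, and standalone, where `ctx` is `None` and the global batch size is used. A hedged usage sketch; the path and sizes are placeholders and a real classifier TFRecord would be needed to actually run it.

```python
# Placeholder path and sizes, not from the commit.
train_input_fn = get_dataset_fn(
    '/tmp/train.tf_record',
    max_seq_length=128,
    global_batch_size=32,
    is_training=True)

# Under a strategy, the training loop hands the closure to
# strategy.experimental_distribute_datasets_from_function(train_input_fn),
# which calls it once per host with an InputContext.

# Standalone, calling it with ctx=None falls back to the global batch size
# and skips host-level sharding.
dataset = train_input_fn(ctx=None)
for features, labels in dataset.take(1):
  print(features['input_ids'].shape)  # e.g. (32, 128) for this configuration.
```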
@@ -264,7 +282,10 @@ def export_classifier(model_export_path, input_meta_data,
       restore_model_using_load_weights=restore_model_using_load_weights)


-def run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn):
+def run_bert(strategy,
+             input_meta_data,
+             train_input_fn=None,
+             eval_input_fn=None):
   """Run BERT training."""
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
@@ -340,18 +361,17 @@ def main(_):
                      FLAGS.strategy_type)
   max_seq_length = input_meta_data['max_seq_length']
-  train_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
-      FLAGS.train_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.train_batch_size)
+  train_input_fn = get_dataset_fn(
+      FLAGS.train_data_path,
+      max_seq_length,
+      FLAGS.train_batch_size,
+      is_training=True)
-  eval_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
-      FLAGS.eval_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.eval_batch_size,
-      is_training=False,
-      drop_remainder=False)
+  eval_input_fn = get_dataset_fn(
+      FLAGS.eval_data_path,
+      max_seq_length,
+      FLAGS.eval_batch_size,
+      is_training=False)

   run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
...
@@ -13,13 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
 from absl import app
 from absl import flags
 from absl import logging
@@ -56,31 +53,17 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS


-def get_pretrain_input_data(input_file_pattern, seq_length,
-                            max_predictions_per_seq, batch_size, strategy):
+def get_pretrain_dataset_fn(input_file_pattern, seq_length,
+                            max_predictions_per_seq, global_batch_size):
   """Returns input dataset from input file string."""
-  # When using TPU pods, we need to clone dataset across
-  # workers and need to pass in function that returns the dataset rather
-  # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
-  if use_dataset_fn:
-    if batch_size % strategy.num_replicas_in_sync != 0:
-      raise ValueError(
-          'Batch size must be divisible by number of replicas : {}'.format(
-              strategy.num_replicas_in_sync))
-    # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
-    # required when cloning dataset to multiple workers in eager mode,
-    # we use per-replica batch size.
-    batch_size = int(batch_size / strategy.num_replicas_in_sync)

   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
-    input_patterns = input_file_pattern.split(',')
+    input_files = []
+    for input_pattern in input_file_pattern.split(','):
+      input_files.extend(tf.io.gfile.glob(input_pattern))
+    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
     train_dataset = input_pipeline.create_pretrain_dataset(
-        input_patterns,
+        input_files,
         seq_length,
         max_predictions_per_seq,
         batch_size,
@@ -88,7 +71,7 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
         input_pipeline_context=ctx)
     return train_dataset

-  return _dataset_fn if use_dataset_fn else _dataset_fn()
+  return _dataset_fn


 def get_loss_fn(loss_factor=1.0):
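Note: the TPU-specific branch that divided the batch size by hand is gone; `_dataset_fn` now glob-expands the comma-separated file patterns and asks the `InputContext` for the per-replica batch size, which itself rejects global batch sizes that are not divisible by the number of replicas. A small sketch of these two steps with illustrative paths and sizes (none of the values are from the commit):

```python
import tensorflow as tf

# Illustrative comma-separated patterns, as passed through the pretraining flag.
input_file_pattern = ('/data/wiki/tf_examples.tfrecord*,'
                      '/data/books/tf_examples.tfrecord*')

input_files = []
for input_pattern in input_file_pattern.split(','):
  # tf.io.gfile.glob expands each pattern against local or GCS storage.
  input_files.extend(tf.io.gfile.glob(input_pattern))

# With a global batch of 256 and 8 replicas in sync, each replica's dataset
# is batched at 256 // 8 = 32; get_per_replica_batch_size does this division
# and raises ValueError for non-divisible combinations.
ctx = tf.distribute.InputContext(num_replicas_in_sync=8)
per_replica_batch_size = ctx.get_per_replica_batch_size(256)  # 32
```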
@@ -114,9 +97,9 @@ def run_customized_training(strategy,
                             train_batch_size):
   """Run BERT pretrain model training using low-level API."""

-  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
-                                     max_seq_length, max_predictions_per_seq,
-                                     train_batch_size, strategy)
+  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
+                                           max_predictions_per_seq,
+                                           train_batch_size)

   def _get_pretrain_model():
     """Gets a pretraining model."""
...
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
 import json
 import os
@@ -136,22 +135,44 @@ def get_raw_results(predictions):
         end_logits=values[2].tolist())


+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+                   is_training):
+  """Gets a closure to create a dataset.."""
+
+  def _dataset_fn(ctx=None):
+    """Returns tf.data.Dataset for distributed BERT pretraining."""
+    batch_size = ctx.get_per_replica_batch_size(
+        global_batch_size) if ctx else global_batch_size
+    dataset = input_pipeline.create_squad_dataset(
+        input_file_pattern,
+        max_seq_length,
+        batch_size,
+        is_training=is_training,
+        input_pipeline_context=ctx)
+    return dataset
+
+  return _dataset_fn
+
+
 def predict_squad_customized(strategy, input_meta_data, bert_config,
                              predict_tfrecord_path, num_steps):
   """Make predictions using a Bert-based squad model."""
-  predict_dataset = input_pipeline.create_squad_dataset(
+  predict_dataset_fn = get_dataset_fn(
       predict_tfrecord_path,
       input_meta_data['max_seq_length'],
       FLAGS.predict_batch_size,
       is_training=False)
   predict_iterator = iter(
-      strategy.experimental_distribute_dataset(predict_dataset))
+      strategy.experimental_distribute_datasets_from_function(
+          predict_dataset_fn))

   with strategy.scope():
     # Prediction always uses float32, even if training uses mixed precision.
     tf.keras.mixed_precision.experimental.set_policy('float32')
     squad_model, _ = bert_models.squad_model(
-        bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+        bert_config,
+        input_meta_data['max_seq_length'],
+        float_type=tf.float32,
         use_keras_bert=FLAGS.use_keras_bert_for_squad)
     checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
@@ -208,8 +229,7 @@ def train_squad(strategy,
   max_seq_length = input_meta_data['max_seq_length']
   steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
   warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
-  train_input_fn = functools.partial(
-      input_pipeline.create_squad_dataset,
+  train_input_fn = get_dataset_fn(
       FLAGS.train_data_path,
       max_seq_length,
       FLAGS.train_batch_size,
@@ -347,7 +367,9 @@ def export_squad(model_export_path, input_meta_data):
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+      bert_config,
+      input_meta_data['max_seq_length'],
+      float_type=tf.float32,
       use_keras_bert=FLAGS.use_keras_bert_for_squad)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
...