Commit 81d031d0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 281117886
parent c1ac2bfc
@@ -41,16 +41,13 @@ def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
def _get_input_iterator(input_fn, strategy):
"""Returns distributed dataset iterator."""
# When training with TPU pods, datasets need to be cloned across
# workers. Since a Dataset instance cannot be cloned in eager mode, we
# instead pass a callable that returns a dataset.
input_data = input_fn()
if callable(input_data):
iterator = iter(
strategy.experimental_distribute_datasets_from_function(input_data))
else:
iterator = iter(strategy.experimental_distribute_dataset(input_data))
if not callable(input_fn):
raise ValueError('`input_fn` should be a closure that returns a dataset.')
iterator = iter(
strategy.experimental_distribute_datasets_from_function(input_fn))
return iterator
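A minimal usage sketch of the revised `_get_input_iterator`, assuming eager TF 2.x and a MirroredStrategy; `_toy_dataset_fn` and the sizes are hypothetical, only `_get_input_iterator` itself comes from this change.

import tensorflow as tf

def _toy_dataset_fn(input_context):
  """Hypothetical closure of the shape `_get_input_iterator` now expects."""
  per_replica_batch = input_context.get_per_replica_batch_size(32)
  ds = tf.data.Dataset.from_tensor_slices(tf.range(1024))
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.repeat().batch(per_replica_batch, drop_remainder=True)

strategy = tf.distribute.MirroredStrategy()
iterator = _get_input_iterator(_toy_dataset_fn, strategy)
batch = next(iterator)  # one per-replica batch per step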
@@ -66,12 +66,15 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
An input function that is usable in the executor.
"""
def _input_fn():
def _dataset_fn(input_context=None):
"""An input function for generating fake data."""
local_batch_size = input_context.get_per_replica_batch_size(batch_size)
features = np.random.rand(64, *features_shape)
labels = np.random.randint(2, size=[64, num_classes])
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
def _assign_dtype(features, labels):
features = tf.cast(features, tf.float32)
@@ -81,11 +84,11 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
# Shuffle, repeat, and batch the examples.
dataset = dataset.map(_assign_dtype)
dataset = dataset.shuffle(64).repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.batch(local_batch_size, drop_remainder=True)
dataset = dataset.prefetch(buffer_size=64)
return dataset
return _input_fn
return _dataset_fn
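To see what the new `input_context` parameter buys here, the closure can be called directly with a hand-built `tf.distribute.InputContext`; a sketch with illustrative values, not part of this change.

ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=8)
dataset_fn = create_fake_data_input_fn(
    batch_size=8, features_shape=[128], num_classes=3)
dataset = dataset_fn(ctx)
# get_per_replica_batch_size(8) with 8 replicas in sync -> batches of 1, and
# shard(2, 0) keeps every second of the 64 fake examples on this pipeline.
features, labels = next(iter(dataset))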
def create_model_fn(input_shape, num_classes, use_float16=False):
@@ -134,21 +137,21 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ModelTrainingUtilsTest, self).setUp()
self._input_fn = create_fake_data_input_fn(
batch_size=8, features_shape=[128], num_classes=3)
self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
input_fn = create_fake_data_input_fn(
batch_size=8, features_shape=[128], num_classes=3)
model_training_utils.run_customized_training_loop(
strategy=distribution,
strategy=strategy,
model_fn=self._model_fn,
loss_fn=tf.keras.losses.categorical_crossentropy,
model_dir=model_dir,
steps_per_epoch=20,
steps_per_loop=steps_per_loop,
epochs=2,
train_input_fn=self._input_fn,
eval_input_fn=self._input_fn,
train_input_fn=input_fn,
eval_input_fn=input_fn,
eval_steps=10,
init_checkpoint=None,
metric_fn=metric_fn,
@@ -36,27 +36,22 @@ def decode_record(record, name_to_features):
return example
def file_based_input_fn_builder(input_file, name_to_features):
"""Creates an `input_fn` closure to be passed for BERT custom training."""
def input_fn():
"""Returns dataset for training/evaluation."""
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: decode_record(record, name_to_features))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# the same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
d = d.with_options(options)
return d
return input_fn
def single_file_dataset(input_file, name_to_features):
"""Creates a single-file dataset to be passed for BERT custom training."""
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: decode_record(record, name_to_features))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# the same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
d = d.with_options(options)
return d
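`decode_record` is defined just above this hunk and is not shown in the diff; presumably it is a thin wrapper around `tf.io.parse_single_example`, roughly along the lines of the sketch below (the int64-to-int32 cast is an assumption, common in BERT input pipelines).

def decode_record_sketch(record, name_to_features):
  """Sketch only: parses a serialized tf.Example with the given schema."""
  example = tf.io.parse_single_example(record, name_to_features)
  # tf.Example only stores int64; downcast to int32 for TPU-friendly inputs.
  for name in list(example.keys()):
    if example[name].dtype == tf.int64:
      example[name] = tf.cast(example[name], tf.int32)
  return example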
def create_pretrain_dataset(input_patterns,
@@ -142,7 +137,7 @@ def create_classifier_dataset(file_path,
seq_length,
batch_size,
is_training=True,
drop_remainder=True):
input_pipeline_context=None):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -151,8 +146,13 @@ def create_classifier_dataset(file_path,
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
dataset = single_file_dataset(file_path, name_to_features)
# The dataset is always sharded by number of hosts.
# num_input_pipelines is the number of hosts rather than number of cores.
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
def _select_data_from_record(record):
x = {
@@ -169,12 +169,16 @@ def create_classifier_dataset(file_path,
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024)
return dataset
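A worked sizing note on how the host-level `shard` above composes with per-replica batching; the numbers are illustrative, not from this change.

global_batch_size = 32
num_hosts = 2                      # input_pipeline_context.num_input_pipelines
num_replicas_in_sync = 8
# Each host's pipeline reads every second record via shard(num_hosts, host_id);
# each replica then batches global_batch_size // num_replicas_in_sync examples,
# so one step still consumes the full global batch across all replicas.
per_replica_batch_size = global_batch_size // num_replicas_in_sync  # 4
examples_per_step = per_replica_batch_size * num_replicas_in_sync   # 32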
def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
def create_squad_dataset(file_path,
seq_length,
batch_size,
is_training=True,
input_pipeline_context=None):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -187,8 +191,13 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
input_fn = file_based_input_fn_builder(file_path, name_to_features)
dataset = input_fn()
dataset = single_file_dataset(file_path, name_to_features)
# The dataset is always sharded by number of hosts.
# num_input_pipelines is the number of hosts rather than number of cores.
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
def _select_data_from_record(record):
"""Dispatches record to features and labels."""
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
@@ -80,6 +79,25 @@ def get_loss_fn(num_classes, loss_factor=1.0):
return classification_loss_fn
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_classifier_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
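A sketch of how this closure is meant to be consumed (the file path and sizes are hypothetical): directly in eager code, where `ctx` stays `None` and the global batch size is used, or through a strategy, which supplies the `InputContext`.

train_fn = get_dataset_fn(
    '/tmp/train.tf_record', max_seq_length=128, global_batch_size=32,
    is_training=True)

# Single-host debugging: call it directly; ctx defaults to None.
local_dataset = train_fn()

# Distributed training: the strategy calls it with an InputContext.
strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_datasets_from_function(train_fn)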
def run_bert_classifier(strategy,
bert_config,
input_meta_data,
@@ -264,7 +282,10 @@ def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights=restore_model_using_load_weights)
def run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn):
def run_bert(strategy,
input_meta_data,
train_input_fn=None,
eval_input_fn=None):
"""Run BERT training."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.mode == 'export_only':
@@ -340,18 +361,17 @@ def main(_):
FLAGS.strategy_type)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.train_batch_size)
eval_input_fn = functools.partial(
input_pipeline.create_classifier_dataset,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
eval_input_fn = get_dataset_fn(
FLAGS.eval_data_path,
seq_length=max_seq_length,
batch_size=FLAGS.eval_batch_size,
is_training=False,
drop_remainder=False)
max_seq_length,
FLAGS.eval_batch_size,
is_training=False)
run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
@@ -13,13 +13,10 @@
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import app
from absl import flags
from absl import logging
@@ -56,31 +53,17 @@ common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def get_pretrain_input_data(input_file_pattern, seq_length,
max_predictions_per_seq, batch_size, strategy):
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, global_batch_size):
"""Returns input dataset from input file string."""
# When using TPU pods, we need to clone the dataset across
# workers and pass in a function that returns the dataset rather
# than passing the dataset instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
'Batch size must be divisible by number of replicas : {}'.format(
strategy.num_replicas_in_sync))
# As auto rebatching is not supported in
# `experimental_distribute_datasets_from_function()` API, which is
# required when cloning dataset to multiple workers in eager mode,
# we use per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
input_patterns = input_file_pattern.split(',')
input_files = []
for input_pattern in input_file_pattern.split(','):
input_files.extend(tf.io.gfile.glob(input_pattern))
batch_size = ctx.get_per_replica_batch_size(global_batch_size)
train_dataset = input_pipeline.create_pretrain_dataset(
input_patterns,
input_files,
seq_length,
max_predictions_per_seq,
batch_size,
@@ -88,7 +71,7 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
input_pipeline_context=ctx)
return train_dataset
return _dataset_fn if use_dataset_fn else _dataset_fn()
return _dataset_fn
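Two behaviours this refactor leans on, sketched with illustrative values (neither snippet is part of the change): the comma-separated pattern expands to concrete files via `tf.io.gfile.glob`, and the explicit divisibility check removed above is, as far as I can tell, subsumed by `InputContext.get_per_replica_batch_size`, which raises if the global batch size does not divide evenly across replicas.

patterns = 'gs://bucket/pretrain/part-*,gs://bucket/more/part-*'.split(',')
input_files = []
for pattern in patterns:
  input_files.extend(tf.io.gfile.glob(pattern))  # expands each shard pattern

ctx = tf.distribute.InputContext(
    num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=8)
per_replica = ctx.get_per_replica_batch_size(128)  # 16; raises if not divisible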
def get_loss_fn(loss_factor=1.0):
@@ -114,9 +97,9 @@ def run_customized_training(strategy,
train_batch_size):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = functools.partial(get_pretrain_input_data, input_files,
max_seq_length, max_predictions_per_seq,
train_batch_size, strategy)
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
max_predictions_per_seq,
train_batch_size)
def _get_pretrain_model():
"""Gets a pretraining model."""
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
@@ -136,22 +135,44 @@ def get_raw_results(predictions):
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset = input_pipeline.create_squad_dataset(
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_dataset(predict_dataset))
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
bert_config,
input_meta_data['max_seq_length'],
float_type=tf.float32,
use_keras_bert=FLAGS.use_keras_bert_for_squad)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
@@ -208,8 +229,7 @@ def train_squad(strategy,
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = functools.partial(
input_pipeline.create_squad_dataset,
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
@@ -347,7 +367,9 @@ def export_squad(model_export_path, input_meta_data):
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
squad_model, _ = bert_models.squad_model(
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
bert_config,
input_meta_data['max_seq_length'],
float_type=tf.float32,
use_keras_bert=FLAGS.use_keras_bert_for_squad)
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)