Unverified commit 1e2ceffd authored by Ayushman Kumar, committed by GitHub

Merge pull request #4 from tensorflow/master

Updating 
parents 51e60bab c7adbbe4
@@ -44,6 +44,11 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
     self.batch_start_times[batch] = time.time()
 
   def on_batch_end(self, batch, logs=None):
+    # If there are multiple steps_per_loop, the end batch index will not be the
+    # same as the starting index. Use the last starting index instead.
+    if batch not in self.batch_start_times:
+      batch = max(self.batch_start_times.keys())
+
     self.batch_stop_times[batch] = time.time()
 
   def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
......
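The hunk above makes the benchmark timer tolerate `steps_per_loop > 1`, where Keras reports a different batch index at `on_batch_end` than it did at `on_batch_begin`. A minimal standalone sketch of the same pattern (class and method names below are illustrative, not the benchmark's actual code):

import time

import tensorflow as tf


class LoopTimer(tf.keras.callbacks.Callback):
  """Records wall-clock time per logged batch, tolerating steps_per_loop > 1."""

  def __init__(self):
    super(LoopTimer, self).__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    # With multiple steps per loop, the end index may not match any recorded
    # start index; fall back to the most recent start, as in the diff above.
    if batch not in self.batch_start_times:
      batch = max(self.batch_start_times.keys())
    self.batch_stop_times[batch] = time.time()

  def total_elapsed(self):
    return sum(self.batch_stop_times[b] - self.batch_start_times[b]
               for b in self.batch_stop_times)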
@@ -1305,7 +1305,6 @@ class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase):
         'model': 'resnet50_v1.5',
         'optimizer': 'mobilenet_default',
         'initial_learning_rate_per_sample': 0.0000039,
-        'use_tf_keras_layers': True,
         'pretrained_filepath': tf.train.latest_checkpoint(
             os.path.join(root_data_dir, 'resnet50')),
         'pruning_begin_step': 0,
@@ -1369,7 +1368,6 @@ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
     default_flags = {
         'model': 'resnet50_v1.5',
         'optimizer': 'mobilenet_default',
-        'use_tf_keras_layers': True,
     }
     super(Resnet50KerasPruningBenchmarkReal, self).__init__(
         default_flags=default_flags, **kwargs)
......
@@ -18,9 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-from absl import app as absl_app
+from absl import app
 from absl import flags
+import numpy as np
 import tensorflow as tf
 
 from official.benchmark.models import resnet_cifar_model
 from official.utils.flags import core as flags_core
@@ -174,7 +174,6 @@ def run(flags_obj):
       is_training=True,
       data_dir=flags_obj.data_dir,
       batch_size=flags_obj.batch_size,
-      num_epochs=flags_obj.train_epochs,
       parse_record_fn=cifar_preprocessing.parse_record,
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
       dtype=dtype,
@@ -189,7 +188,6 @@ def run(flags_obj):
       is_training=False,
       data_dir=flags_obj.data_dir,
       batch_size=flags_obj.batch_size,
-      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record)
 
   steps_per_epoch = (
@@ -284,4 +282,4 @@ def main(_):
 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_cifar_flags()
-  absl_app.run(main)
+  app.run(main)
@@ -281,6 +281,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
     FLAGS.batch_size = 128
     FLAGS.use_tf_function = False
+    FLAGS.use_tf_while_loop = False
     FLAGS.single_l2_loss_op = True
     self._run_and_report_benchmark()
@@ -294,6 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.batch_size = 250
     FLAGS.dtype = 'fp16'
     FLAGS.use_tf_function = False
+    FLAGS.use_tf_while_loop = False
     FLAGS.single_l2_loss_op = True
     self._run_and_report_benchmark()
@@ -324,6 +326,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.use_tf_function = False
+    FLAGS.use_tf_while_loop = False
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager')
     FLAGS.batch_size = 128
@@ -336,6 +339,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.dtype = 'fp16'
     FLAGS.use_tf_function = False
+    FLAGS.use_tf_while_loop = False
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16')
     FLAGS.batch_size = 128
@@ -392,8 +396,7 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
   def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
     def_flags = {}
     def_flags['skip_eval'] = True
-    def_flags['data_dir'] = ('/readahead/200M/placer/prod/home/distbelief/'
-                             'imagenet-tensorflow/imagenet-2012-tfrecord')
+    def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
     def_flags['train_steps'] = 110
     def_flags['steps_per_loop'] = 20
     def_flags['log_steps'] = 10
......
@@ -24,7 +24,6 @@ import copy
 import functools
 from typing import Any, List, Mapping, Optional, Type
 
-from absl import logging
 import dataclasses
 import tensorflow as tf
 import yaml
@@ -74,8 +73,8 @@ class Config(params_dict.ParamsDict):
     """Returns v with dicts converted to Configs, recursively."""
     if not issubclass(subconfig_type, params_dict.ParamsDict):
       raise TypeError(
-          'Subconfig_type should be subclass of ParamsDict, found %r',
-          subconfig_type)
+          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
+              subconfig_type))
     if isinstance(v, cls.IMMUTABLE_TYPES):
       return v
     elif isinstance(v, cls.SEQUENCE_TYPES):
@@ -95,7 +94,7 @@ class Config(params_dict.ParamsDict):
     elif isinstance(v, dict):
       return subconfig_type(v)
     else:
-      raise TypeError('Unknown type: %r' % type(v))
+      raise TypeError('Unknown type: {!r}'.format(type(v)))
 
   @classmethod
   def _export_config(cls, v):
@@ -162,7 +161,9 @@ class Config(params_dict.ParamsDict):
     """
     subconfig_type = self._get_subconfig_type(k)
     if isinstance(v, dict):
-      if k not in self.__dict__:
+      if k not in self.__dict__ or not self.__dict__[k]:
+        # If the key does not exist or the value is None, a new Config-family
+        # object should be created for the key.
         self.__dict__[k] = subconfig_type(v)
       else:
         self.__dict__[k].override(v)
@@ -193,15 +194,16 @@
                        'Can not be overridden.'.format(k))
       if k not in self.__dict__:
         if is_strict:
-          raise KeyError('The key {!r} does not exist. '
+          raise KeyError('The key {!r} does not exist in {!r}. '
                          'To extend the existing keys, use '
-                         '`override` with `is_strict` = False.'.format(k))
+                         '`override` with `is_strict` = False.'.format(
+                             k, type(self)))
         else:
           self._set(k, v)
       else:
-        if isinstance(v, dict):
+        if isinstance(v, dict) and self.__dict__[k]:
           self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
-        elif isinstance(v, params_dict.ParamsDict):
+        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
           self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
         else:
          self._set(k, v)
@@ -268,6 +270,8 @@ class RuntimeConfig(Config):
       multi-worker models with DistributionStrategy.
     task_index: If multi-worker training, the task index of this worker.
     all_reduce_alg: Defines the algorithm for performing all-reduce.
+    num_packs: Sets `num_packs` in the cross device ops used in
+      MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
   """
   distribution_strategy: str = 'mirrored'
   enable_eager: bool = False
@@ -281,6 +285,7 @@ class RuntimeConfig(Config):
   worker_hosts: Optional[str] = None
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
+  num_packs: int = 1
 
 @dataclasses.dataclass
@@ -311,4 +316,3 @@ class CallbacksConfig(Config):
   """
   enable_checkpoint_and_export: bool = True
   enable_tensorboard: bool = True
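The new `num_packs` field documents how gradient all-reduce packing is configured for MirroredStrategy. A hedged sketch of how such a value is typically plumbed into a strategy through `tf.distribute.NcclAllReduce` (the wiring below is an assumption for illustration, not code from this diff):

import tensorflow as tf

num_packs = 1  # stand-in for RuntimeConfig.num_packs
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(num_packs=num_packs))

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])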
@@ -20,6 +20,7 @@ from __future__ import print_function
 import json
 import os
+import tempfile
 
 from absl import logging
 import tensorflow as tf
@@ -30,12 +31,29 @@ _SUMMARY_TXT = 'training_summary.txt'
 _MIN_SUMMARY_STEPS = 10
 
-def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+def _should_export_checkpoint(strategy):
+  return (not strategy) or strategy.extended.should_checkpoint
+
+
+def _should_export_summary(strategy):
+  return (not strategy) or strategy.extended.should_save_summary
+
+
+def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
   """Saves model to with provided checkpoint prefix."""
+  if _should_export_checkpoint(strategy):
     checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
     saved_path = checkpoint.save(checkpoint_path)
     logging.info('Saving model as TF checkpoint: %s', saved_path)
+  else:
+    # In multi-worker training we need every worker to save a checkpoint,
+    # because variables can trigger synchronization on read and
+    # synchronization needs all workers to participate. To avoid workers
+    # overwriting each other we save to a temporary directory on non-chief
+    # workers.
+    tmp_dir = tempfile.mkdtemp()
+    checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
+    tf.io.gfile.rmtree(tmp_dir)
   return
@@ -242,7 +260,13 @@ def run_customized_training_loop(
     ]
 
     # Create summary writers
+    if _should_export_summary(strategy):
       summary_dir = os.path.join(model_dir, 'summaries')
+    else:
+      # In multi-worker training we need every worker to write summaries,
+      # because variables can trigger synchronization on read and
+      # synchronization needs all workers to participate.
+      summary_dir = tempfile.mkdtemp()
     eval_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'eval'))
     if steps_per_loop >= _MIN_SUMMARY_STEPS:
@@ -395,8 +419,8 @@ def run_customized_training_loop(
        train_steps(train_iterator,
                    tf.convert_to_tensor(steps, dtype=tf.int32))
        train_loss = _float_metric_value(train_loss_metric)
-       _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
        current_step += steps
+       _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
 
       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
@@ -418,11 +442,11 @@
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
-         _save_checkpoint(checkpoint, model_dir,
+         _save_checkpoint(strategy, checkpoint, model_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
-               sub_model_checkpoint, model_dir,
+               strategy, sub_model_checkpoint, model_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
        if eval_input_fn:
          logging.info('Running evaluation after step: %s.', current_step)
@@ -432,10 +456,10 @@
          for metric in eval_metrics + model.metrics:
            metric.reset_states()
 
-   _save_checkpoint(checkpoint, model_dir,
+   _save_checkpoint(strategy, checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
-     _save_checkpoint(sub_model_checkpoint, model_dir,
+     _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                       '%s.ckpt' % sub_model_export_name)
 
    if eval_input_fn:
@@ -455,4 +479,7 @@ def run_customized_training_loop(
    write_txt_summary(training_summary, summary_dir)
 
+   if not _should_export_summary(strategy):
+     tf.io.gfile.rmtree(summary_dir)
+
    return model
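The multi-worker comments above capture the key constraint: every worker must still call `checkpoint.save()` so collective variable reads stay synchronized, but only the chief's files should survive. A condensed, standalone sketch of that pattern (the helper name is illustrative; the `should_checkpoint` property is the one the diff relies on):

import os
import tempfile

import tensorflow as tf


def save_everywhere_keep_chief(strategy, checkpoint, model_dir, prefix='ckpt'):
  """All workers save; non-chief copies go to a throwaway temp directory."""
  is_chief = (not strategy) or strategy.extended.should_checkpoint
  if is_chief:
    return checkpoint.save(os.path.join(model_dir, prefix))
  tmp_dir = tempfile.mkdtemp()
  checkpoint.save(os.path.join(tmp_dir, prefix))
  tf.io.gfile.rmtree(tmp_dir)  # discard the non-chief copy
  return None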
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
 
 from absl.testing import parameterized
+from absl.testing.absltest import mock
 import numpy as np
 import tensorflow as tf
@@ -208,6 +209,27 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         check_eventfile_for_keyword('mean_input',
                                     os.path.join(model_dir, 'summaries/eval')))
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          mode='eager',
+      ))
+  def test_train_check_artifacts_non_chief(self, distribution):
+    # We shouldn't export artifacts on non-chief workers. Since there's no easy
+    # way to test with a real MultiWorkerMirroredStrategy, we patch the
+    # strategy so it behaves like MultiWorkerMirroredStrategy on a non-chief
+    # worker.
+    extended = distribution.extended
+    with mock.patch.object(extended.__class__, 'should_checkpoint',
+                           new_callable=mock.PropertyMock, return_value=False), \
+         mock.patch.object(extended.__class__, 'should_save_summary',
+                           new_callable=mock.PropertyMock, return_value=False):
+      model_dir = self.get_temp_dir()
+      self.run_training(
+          distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+      self.assertEmpty(tf.io.gfile.listdir(model_dir))
 
 if __name__ == '__main__':
   assert tf.version.VERSION.startswith('2.')
......
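The test above relies on patching read-only properties of the strategy's extended object. A minimal standalone sketch of that mocking trick, using `mock.PropertyMock` to override a property for the duration of a `with` block (the class and property names here are illustrative, not from the diff):

from unittest import mock


class Worker:

  @property
  def is_chief(self):
    return True


w = Worker()
# Patch the property on the class, not the instance, and restore it on exit.
with mock.patch.object(Worker, 'is_chief',
                       new_callable=mock.PropertyMock, return_value=False):
  assert w.is_chief is False
assert w.is_chief is True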
@@ -78,7 +78,6 @@ def export_albert_tfhub(albert_config: configs.AlbertConfig,
 def main(_):
-  assert tf.version.VERSION.startswith('2.')
   albert_config = configs.AlbertConfig.from_json_file(
       FLAGS.albert_config_file)
   export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
......
@@ -86,5 +86,4 @@ class ExportAlbertTfhubTest(tf.test.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -33,7 +33,6 @@ FLAGS = flags.FLAGS
 def main(_):
   # Users should always run this script under TF 2.x
-  assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
......
@@ -80,7 +80,6 @@ def export_squad(model_export_path, input_meta_data):
 def main(_):
   # Users should always run this script under TF 2.x
-  assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
......
@@ -122,7 +122,6 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
 def main(_):
-  assert tf.version.VERSION.startswith('2.')
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
......
@@ -77,6 +77,8 @@ def define_common_bert_flags():
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
 
+  flags_core.define_log_steps()
+
   # Adds flags for mixed precision and multi-worker training.
   flags_core.define_performance(
       num_parallel_calls=False,
......
@@ -77,7 +77,6 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
 def main(_):
-  assert tf.version.VERSION.startswith('2.')
   bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
                     FLAGS.vocab_file)
......
@@ -84,5 +84,4 @@ class ExportTfhubTest(tf.test.TestCase):
 if __name__ == "__main__":
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
@@ -169,7 +169,7 @@ def run_bert_classifier(strategy,
         epochs,
         steps_per_epoch,
         eval_steps,
-        custom_callbacks=None)
+        custom_callbacks=custom_callbacks)
 
   # Use user-defined loop to start training.
   logging.info('Training using customized training loop TF 2.0 with '
@@ -363,6 +363,15 @@ def run_bert(strategy,
   if not strategy:
     raise ValueError('Distribution strategy has not been specified.')
 
+  if FLAGS.log_steps:
+    custom_callbacks = [keras_utils.TimeHistory(
+        batch_size=FLAGS.train_batch_size,
+        log_steps=FLAGS.log_steps,
+        logdir=FLAGS.model_dir,
+    )]
+  else:
+    custom_callbacks = None
+
   trained_model = run_bert_classifier(
       strategy,
       model_config,
@@ -378,7 +387,8 @@ def run_bert(strategy,
       train_input_fn,
       eval_input_fn,
       run_eagerly=FLAGS.run_eagerly,
-      use_keras_compile_fit=FLAGS.use_keras_compile_fit)
+      use_keras_compile_fit=FLAGS.use_keras_compile_fit,
+      custom_callbacks=custom_callbacks)
 
   if FLAGS.model_export_path:
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
@@ -393,7 +403,6 @@ def run_bert(strategy,
 def main(_):
   # Users should always run this script under TF 2.x
-  assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
......
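The run_classifier changes above wire an optional `--log_steps` flag into a `keras_utils.TimeHistory` callback and, with the fix in the first hunk, actually pass it through to training instead of dropping it. A toy end-to-end sketch of that pattern (assumes the tf-models `official` package is importable; the constructor arguments mirror what the diff passes and are otherwise unverified here):

import numpy as np
import tensorflow as tf
from official.utils.misc import keras_utils

log_steps = 10    # stand-in for FLAGS.log_steps
batch_size = 32   # stand-in for FLAGS.train_batch_size
custom_callbacks = None
if log_steps:
  custom_callbacks = [keras_utils.TimeHistory(
      batch_size=batch_size, log_steps=log_steps, logdir=None)]

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(256, 4).astype('float32')
y = np.random.rand(256, 1).astype('float32')
# Before this commit, run_bert_classifier hard-coded custom_callbacks=None here.
model.fit(x, y, batch_size=batch_size, epochs=1, callbacks=custom_callbacks)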
@@ -159,7 +159,6 @@ def run_bert_pretrain(strategy):
 def main(_):
   # Users should always run this script under TF 2.x
-  assert tf.version.VERSION.startswith('2.')
   gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
......
@@ -29,6 +29,7 @@ from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib as squad_lib_wp
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 
 flags.DEFINE_string('vocab_file', None,
@@ -75,7 +76,6 @@ def export_squad(model_export_path, input_meta_data):
 def main(_):
   # Users should always run this script under TF 2.x
-  assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
@@ -94,7 +94,21 @@ def main(_):
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
   if FLAGS.mode in ('train', 'train_and_predict'):
-    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
+    if FLAGS.log_steps:
+      custom_callbacks = [keras_utils.TimeHistory(
+          batch_size=FLAGS.train_batch_size,
+          log_steps=FLAGS.log_steps,
+          logdir=FLAGS.model_dir,
+      )]
+    else:
+      custom_callbacks = None
+
+    train_squad(
+        strategy,
+        input_meta_data,
+        custom_callbacks=custom_callbacks,
+        run_eagerly=FLAGS.run_eagerly,
+    )
   if FLAGS.mode in ('predict', 'train_and_predict'):
     predict_squad(strategy, input_meta_data)
......
@@ -98,7 +98,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
 def main(_):
-  assert tf.version.VERSION.startswith('2.')
+  tf.enable_v2_behavior()
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
......
@@ -6,7 +6,10 @@ assemble new layers, networks, or models.
     logic required to generate the einsum expression for the given initialization
     parameters.
 
-*   [Attention](attention.py) implements an optionally masked attention between two tensors, from_tensor and to_tensor, as described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If `from_tensor` and `to_tensor` are the same, then this is self-attention.
+*   [MultiHeadAttention](attention.py) implements an optionally masked attention
+    between two tensors, from_tensor and to_tensor, as described in
+    ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
+    If `from_tensor` and `to_tensor` are the same, then this is self-attention.
 
 *   [CachedAttention](attention.py) implements an attention layer with cache used
     for auto-regressive decoding.
......