Unverified Commit 7ebfc3dd authored by Ayushman Kumar, committed by GitHub

Merge pull request #1 from tensorflow/master

updated
parents 867f0c47 6f0e3a0b
@@ -193,6 +193,7 @@ class Transformer(tf.keras.layers.Layer):
base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@tf.function(experimental_compile=True)
def call(self, inputs):
if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
input_tensor, attention_mask = inputs
@@ -204,19 +205,21 @@ class Transformer(tf.keras.layers.Layer):
if attention_mask is not None:
attention_inputs.append(attention_mask)
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. cast layer_output to fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
with tf.name_scope(self.name):
attention_output = self._attention_layer(attention_inputs)
attention_output = self._attention_output_dense(attention_output)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
# During mixed precision training, attention_output is from layer norm and
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
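
For context, experimental_compile=True asks TensorFlow to compile the decorated function with XLA. A minimal, self-contained sketch of the same decorator on a toy function (the function body is illustrative only, not part of this diff):

import tensorflow as tf

# XLA-compiles the body into a fused kernel, the same mechanism the diff
# enables for Transformer.call.
@tf.function(experimental_compile=True)
def fused_scale_add(x, y):
  return x * 2.0 + y

print(fused_scale_add(tf.ones([4]), tf.ones([4])))  # => [3. 3. 3. 3.]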
@@ -298,13 +298,16 @@ class EpochHelper(object):
self._epoch_steps = epoch_steps
self._global_step = global_step
self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False
def epoch_begin(self):
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
self._current_epoch = self._global_step.numpy() / self._epoch_steps
current_step = self._global_step.numpy()
self._epoch_start_step = current_step
self._current_epoch = current_step // self._epoch_steps
self._in_epoch = True
return True
@@ -313,13 +316,18 @@ class EpochHelper(object):
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch")
current_step = self._global_step.numpy()
epoch = current_step / self._epoch_steps
epoch = current_step // self._epoch_steps
if epoch > self._current_epoch:
self._in_epoch = False
return True
return False
@property
def batch_index(self):
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step
@property
def current_epoch(self):
return self._current_epoch
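
For context, a minimal sketch of how a custom training loop might drive the revised EpochHelper, assuming the class above is in scope and using a tf.Variable global step (the loop is illustrative, not part of this diff):

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
helper = EpochHelper(epoch_steps=100, global_step=global_step)

for _ in range(250):
  if helper.epoch_begin():
    print('starting epoch', helper.current_epoch)
  global_step.assign_add(1)  # stands in for one training step
  if helper.epoch_end():
    print('finished epoch %d after batch_index %d' %
          (helper.current_epoch, helper.batch_index))

With integer division (//), current_epoch stays constant for a whole epoch instead of drifting fractionally as the old float division did.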
@@ -26,7 +26,7 @@ from absl import logging
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.eager import profiler
class BatchTimestamp(object):
@@ -44,17 +44,28 @@ class BatchTimestamp(object):
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
def __init__(self, batch_size, log_steps):
def __init__(self, batch_size, log_steps, logdir=None):
"""Callback for logging performance.
Args:
batch_size: Total batch size.
log_steps: Interval of steps between logging of batch level stats.
logdir: Optional directory to write TensorBoard summaries.
"""
# TODO(wcromar): remove this parameter and rely on `logs` parameter of
# on_train_batch_end()
self.batch_size = batch_size
super(TimeHistory, self).__init__()
self.log_steps = log_steps
self.global_steps = 0
self.last_log_step = 0
self.steps_before_epoch = 0
self.steps_in_epoch = 0
self.start_time = None
if logdir:
self.summary_writer = tf.summary.create_file_writer(logdir)
else:
self.summary_writer = None
# Logs start of step 1 then end of each step based on log_steps interval.
self.timestamp_log = []
@@ -62,38 +73,70 @@ class TimeHistory(tf.keras.callbacks.Callback):
# Records the time each epoch takes to run from start to finish of epoch.
self.epoch_runtime_log = []
@property
def global_steps(self):
"""The current 1-indexed global step."""
return self.steps_before_epoch + self.steps_in_epoch
@property
def average_steps_per_second(self):
"""The average training steps per second across all epochs."""
return self.global_steps / sum(self.epoch_runtime_log)
@property
def average_examples_per_second(self):
"""The average number of training examples per second across all epochs."""
return self.average_steps_per_second * self.batch_size
def on_train_end(self, logs=None):
self.train_finish_time = time.time()
if self.summary_writer:
self.summary_writer.flush()
def on_epoch_begin(self, epoch, logs=None):
self.epoch_start = time.time()
def on_batch_begin(self, batch, logs=None):
self.global_steps += 1
if self.global_steps == 1:
if not self.start_time:
self.start_time = time.time()
# Record the timestamp of the first global step
if not self.timestamp_log:
self.timestamp_log.append(BatchTimestamp(self.global_steps,
self.start_time))
def on_batch_end(self, batch, logs=None):
"""Records elapse time of the batch and calculates examples per second."""
if self.global_steps % self.log_steps == 0:
timestamp = time.time()
elapsed_time = timestamp - self.start_time
examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
self.steps_in_epoch = batch + 1
steps_since_last_log = self.global_steps - self.last_log_step
if steps_since_last_log >= self.log_steps:
now = time.time()
elapsed_time = now - self.start_time
steps_per_second = steps_since_last_log / elapsed_time
examples_per_second = steps_per_second * self.batch_size
self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
logging.info(
"BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
"'examples_per_second': %f}",
self.global_steps, elapsed_time, examples_per_second)
self.start_time = timestamp
"TimeHistory: %.2f examples/second between steps %d and %d",
examples_per_second, self.last_log_step, self.global_steps)
if self.summary_writer:
with self.summary_writer.as_default():
tf.summary.scalar('global_step/sec', steps_per_second,
self.global_steps)
tf.summary.scalar('examples/sec', examples_per_second,
self.global_steps)
self.last_log_step = self.global_steps
self.start_time = None
def on_epoch_end(self, epoch, logs=None):
epoch_run_time = time.time() - self.epoch_start
self.epoch_runtime_log.append(epoch_run_time)
logging.info(
"BenchmarkMetric: {'epoch':%d, 'time_taken': %f}",
epoch, epoch_run_time)
self.steps_before_epoch += self.steps_in_epoch
self.steps_in_epoch = 0
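
For context, a hedged sketch of wiring the extended callback into Keras training. The model and data are placeholders, the import path follows how this PR references keras_utils.TimeHistory elsewhere, and logdir enables the new TensorBoard scalars:

import numpy as np
import tensorflow as tf
from official.utils.misc import keras_utils  # assumed module path

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

time_callback = keras_utils.TimeHistory(batch_size=32, log_steps=10,
                                        logdir='/tmp/tb')
model.fit(np.ones((320, 4)), np.ones((320, 1)), batch_size=32, epochs=2,
          callbacks=[time_callback])
print('avg examples/sec:', time_callback.average_examples_per_second)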
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
@@ -145,15 +188,17 @@ class ProfilerCallback(tf.keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None):
if batch == self.start_step_in_epoch and self.should_start:
self.should_start = False
profiler.start(self.log_dir)
profiler.start()
logging.info('Profiler started at Step %s', self.start_step)
def on_batch_end(self, batch, logs=None):
if batch == self.stop_step_in_epoch and self.should_stop:
self.should_stop = False
profiler.stop()
logging.info('Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)
results = profiler.stop()
profiler.save(self.log_dir, results)
logging.info(
'Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)
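
For reference, the eager profiler imported above follows a start/stop/save pattern, unlike profiler_v2's start(logdir)/stop pair that this diff replaces. A minimal sketch (the path is a placeholder):

from tensorflow.python.eager import profiler

profiler.start()                       # begin tracing
# ... run some training steps ...
result = profiler.stop()               # returns the collected trace
profiler.save('/tmp/profile', result)  # write under logdir for TensorBoard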
def set_session_config(enable_eager=False,
@@ -14,8 +14,9 @@
# ==============================================================================
"""Config template to train Mask R-CNN."""
from official.vision.detection.configs import base_config
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
# pylint: disable=line-too-long
MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
@@ -23,6 +24,7 @@ MASKRCNN_CFG.override({
'type': 'mask_rcnn',
'eval': {
'type': 'box_and_mask',
'num_images_to_visualize': 0,
},
'architecture': {
'parser': 'maskrcnn_parser',
@@ -23,9 +23,8 @@
# need to be fine-tuned for the detection task.
# Note that we need the trailing `/` to avoid an incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
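
For illustration, a quick check of what the widened prefix regex freezes; the variable names below are hypothetical but follow standard TF scoping:

import re

for name in ['resnet50/conv2d/kernel:0',
             'resnet50/batch_normalization_3/gamma:0',
             'resnet50/conv2d_11/kernel:0']:
  print(name, bool(re.match(RESNET_FROZEN_VAR_PREFIX, name)))
# True, True, False: conv2d through conv2d_10 and the matching
# batch_normalization layers are frozen; conv2d_11 stays trainable.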
# pylint: disable=line-too-long
RETINANET_CFG = {
@@ -54,10 +53,11 @@ RETINANET_CFG = {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET50_FROZEN_VAR_PREFIX,
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'input_sharding': False,
},
@@ -18,11 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import re
import six
from absl import logging
import tensorflow.compat.v2 as tf
@@ -53,7 +51,7 @@ class OptimizerFactory(object):
self._optimizer = tf.keras.optimizers.Adagrad
elif params.type == 'rmsprop':
self._optimizer = functools.partial(
tf.keras.optimizers.RMSProp, momentum=params.momentum)
tf.keras.optimizers.RMSprop, momentum=params.momentum)
else:
raise ValueError('Unsupported optimizer type %s.' % params.type)
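
A minimal sketch of the functools.partial pattern used above: partial pre-binds momentum so the factory can later construct the optimizer with only a learning rate (values here are hypothetical):

import functools
import tensorflow as tf

optimizer_cls = functools.partial(tf.keras.optimizers.RMSprop, momentum=0.9)
optimizer = optimizer_cls(learning_rate=0.1)  # momentum already bound
print(type(optimizer).__name__)  # RMSprop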
@@ -104,6 +102,7 @@ class Model(object):
params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._regularization_var_regex = params.train.regularization_variable_regex
self._l2_weight_decay = params.train.l2_weight_decay
# Checkpoint restoration.
@@ -146,12 +145,17 @@ class Model(object):
"""
return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
def weight_decay_loss(self, l2_weight_decay, trainable_variables):
return l2_weight_decay * tf.add_n([
tf.nn.l2_loss(v)
for v in trainable_variables
if 'batch_normalization' not in v.name and 'bias' not in v.name
])
def weight_decay_loss(self, trainable_variables):
reg_variables = [
v for v in trainable_variables
if self._regularization_var_regex is None
or re.match(self._regularization_var_regex, v.name)
]
logging.info('Regularization Variables: %s',
[v.name for v in reg_variables])
return self._l2_weight_decay * tf.add_n(
[tf.nn.l2_loss(v) for v in reg_variables])
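
For intuition, a self-contained sketch of the regex-driven selection the rewritten weight_decay_loss performs; the variables and decay value are hypothetical:

import re
import tensorflow as tf

l2_weight_decay = 1e-4
regex = r'.*(kernel|weight):0$'
trainable_variables = [
    tf.Variable(tf.ones([3, 3]), name='conv2d/kernel'),
    tf.Variable(tf.zeros([3]), name='conv2d/bias'),
]
reg_variables = [v for v in trainable_variables
                 if re.match(regex, v.name)]  # keeps kernels, drops biases
loss = l2_weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in reg_variables])
print([v.name for v in reg_variables], float(loss))  # ['conv2d/kernel:0'] 0.00045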
def make_restore_checkpoint_fn(self):
"""Returns scaffold function to restore parameters from v1 checkpoint."""
@@ -106,8 +106,7 @@ class RetinanetModel(base_model.Model):
labels['box_targets'],
labels['num_positives'])
model_loss = cls_loss + self._box_loss_weight * box_loss
l2_regularization_loss = self.weight_decay_loss(self._l2_weight_decay,
trainable_variables)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
@@ -188,7 +188,10 @@ def get_callbacks(
enable_checkpoint_and_export=False,
model_dir=None):
"""Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks = [time_callback]
if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks):
timestamp_log = callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
if callback.epoch_runtime_log:
stats['avg_exp_per_second'] = callback.average_examples_per_second
return stats
@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback):
timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
avg_exp_per_second = tf.reduce_mean(
runnable.examples_per_second_history).numpy(),
stats['avg_exp_per_second'] = avg_exp_per_second
if time_callback.epoch_runtime_log:
stats['avg_exp_per_second'] = time_callback.average_examples_per_second
return stats
@@ -154,8 +147,10 @@ def run(flags_obj):
'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
train_epochs * per_epoch_steps, eval_steps)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
flags_obj.log_steps)
time_callback = keras_utils.TimeHistory(
flags_obj.batch_size,
flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps)
@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Handling epochs.
self.epoch_steps = epoch_steps
self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
self.examples_per_second_history = []
def build_train_dataset(self):
"""See base class."""
@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.train_loss.reset_states()
self.train_accuracy.reset_states()
self.time_callback.on_batch_begin(self.global_step)
self._epoch_begin()
self.time_callback.on_batch_begin(self.epoch_helper.batch_index)
def train_step(self, iterator):
"""See base class."""
@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def train_loop_end(self):
"""See base class."""
self.time_callback.on_batch_end(self.global_step)
self._epoch_end()
return {
metrics = {
'train_loss': self.train_loss.result(),
'train_accuracy': self.train_accuracy.result(),
}
self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
self._epoch_end()
return metrics
def eval_begin(self):
"""See base class."""
@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def _epoch_end(self):
if self.epoch_helper.epoch_end():
self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
epoch_time = self.time_callback.epoch_runtime_log[-1]
steps_per_second = self.epoch_steps / epoch_time
examples_per_second = steps_per_second * self.flags_obj.batch_size
self.examples_per_second_history.append(examples_per_second)
tf.summary.scalar('global_step/sec', steps_per_second)
tf.summary.scalar('examples/sec', examples_per_second)
@@ -53,6 +53,34 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample rate = 16000\n"
]
}
],
"source": [
"# Read in the audio.\n",
"# You can get this example waveform via:\n",
"# curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav\n",
"\n",
"wav_file_name = 'speech_whistling2.wav'\n",
"\n",
"wav_data, sr = sf.read(wav_file_name, dtype=np.int16)\n",
"waveform = wav_data / 32768.0\n",
"# The graph is designed for a sampling rate of 16 kHz, but higher rates \n",
"# should work too.\n",
"params.SAMPLE_RATE = sr\n",
"print(\"Sample rate =\", params.SAMPLE_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
@@ -74,21 +102,6 @@
" yamnet.load_weights('yamnet.h5')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Read in the audio.\n",
"# You can get this example waveform via:\n",
"# curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav\n",
"wav_data, sr = sf.read('speech_whistling2.wav', dtype=np.int16)\n",
"waveform = wav_data / 32768.0\n",
"# Sampling rate should be 16000 Hz.\n",
"assert sr == 16000"
]
},
{
"cell_type": "code",
"execution_count": 4,