Commit 7359586f authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents c594cecf a78b05b9
@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
    return os.path.join(self.output_dir, folder_name)

-class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
-  """Resnet50 (classifier_trainer) benchmarks."""
+class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
+  """Classifier Trainer benchmarks."""

-  def __init__(self, output_dir=None, default_flags=None,
+  def __init__(self, model, output_dir=None, default_flags=None,
               tpu=None, dataset_builder='records', train_epochs=1,
               train_steps=110, data_dir=None):
    flag_methods = [classifier_trainer.define_classifier_flags]

+    self.model = model
    self.dataset_builder = dataset_builder
    self.train_epochs = train_epochs
    self.train_steps = train_steps
    self.data_dir = data_dir
-    super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
+    super(KerasClassifierBenchmarkBase, self).__init__(
        output_dir=output_dir,
        flag_methods=flag_methods,
        default_flags=default_flags,
@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
                                dataset_num_private_threads: Optional[int] = None,
                                loss_scale: Optional[str] = None):
    """Runs and reports the benchmark given the provided configuration."""
-    FLAGS.model_type = 'resnet'
+    FLAGS.model_type = self.model
    FLAGS.dataset = 'imagenet'
    FLAGS.mode = 'train_and_eval'
    FLAGS.data_dir = self.data_dir
@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
    # input skip_steps.
    warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
-    super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
+    super(KerasClassifierBenchmarkBase, self)._report_benchmark(
        stats,
        wall_time_sec,
        total_batch_size=total_batch_size,
@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
        distribution_strategy='mirrored',
        per_replica_batch_size=256,
        gpu_thread_mode='gpu_private',
-        dataset_num_private_threads=48,
-        steps=310)
+        dataset_num_private_threads=48)

  def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
    """Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
        distribution_strategy='tpu',
        per_replica_batch_size=128)

+  def benchmark_2x2_tpu_bf16_mlir(self):
+    """Test Keras model with 2x2 TPU, bf16 and the MLIR bridge enabled."""
+    self._setup()
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(
+        experiment_name='benchmark_2x2_tpu_bf16_mlir',
+        dtype='bfloat16',
+        num_tpus=8,
+        distribution_strategy='tpu',
+        per_replica_batch_size=128)
+
+  def benchmark_4x4_tpu_bf16_mlir(self):
+    """Test Keras model with 4x4 TPU, bf16 and the MLIR bridge enabled."""
+    self._setup()
+    tf.config.experimental.enable_mlir_bridge()
+    self._run_and_report_benchmark(
+        experiment_name='benchmark_4x4_tpu_bf16_mlir',
+        dtype='bfloat16',
+        num_tpus=32,
+        distribution_strategy='tpu',
+        per_replica_batch_size=128)
+
  def benchmark_8x8_tpu_bf16(self):
    """Test Keras model with 8x8 TPU, bf16."""
    self._setup()
@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
        per_replica_batch_size=64)

  def fill_report_object(self, stats):
-    super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
+    super(KerasClassifierBenchmarkBase, self).fill_report_object(
        stats,
        total_batch_size=FLAGS.batch_size,
        log_steps=FLAGS.log_steps)
@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        log_steps=FLAGS.log_steps)

-class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
+class Resnet50KerasBenchmarkSynth(KerasClassifierBenchmarkBase):
  """Resnet50 synthetic benchmark tests."""

  def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
    def_flags['log_steps'] = 10

    super(Resnet50KerasBenchmarkSynth, self).__init__(
-        output_dir=output_dir, default_flags=def_flags, tpu=tpu,
+        output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet',
        dataset_builder='synthetic', train_epochs=1, train_steps=110)
-class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
+class Resnet50KerasBenchmarkReal(KerasClassifierBenchmarkBase):
  """Resnet50 real data benchmark tests."""

  def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
    def_flags['log_steps'] = 10

    super(Resnet50KerasBenchmarkReal, self).__init__(
-        output_dir=output_dir, default_flags=def_flags, tpu=tpu,
+        output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet',
        dataset_builder='records', train_epochs=1, train_steps=110,
        data_dir=data_dir)
+class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
+  """EfficientNet real data benchmark tests."""
+
+  def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
+    data_dir = os.path.join(root_data_dir, 'imagenet')
+    def_flags = {}
+    def_flags['log_steps'] = 10
+
+    super(EfficientNetKerasBenchmarkReal, self).__init__(
+        model='efficientnet', output_dir=output_dir, default_flags=def_flags,
+        tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
+        data_dir=data_dir)
+
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
  """Resnet50 real data (stored in remote storage) benchmark tests."""
...
@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long

-class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
+class BenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
  """Base class to hold methods common to test classes."""

  def __init__(self, **kwargs):
-    super(DetectionBenchmarkBase, self).__init__(**kwargs)
+    super(BenchmarkBase, self).__init__(**kwargs)
    self.timer_callback = None

  def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
        extras={'flags': flags_str})

-class RetinanetBenchmarkBase(DetectionBenchmarkBase):
+class DetectionBenchmarkBase(BenchmarkBase):
  """Base class to hold methods common to test classes in the module."""

  def __init__(self, **kwargs):
@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
    self.eval_data_path = COCO_EVAL_DATA
    self.eval_json_path = COCO_EVAL_JSON
    self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
-    super(RetinanetBenchmarkBase, self).__init__(**kwargs)
+    super(DetectionBenchmarkBase, self).__init__(**kwargs)

  def _run_detection_main(self):
    """Starts detection job."""
@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
    return detection.run()

-class RetinanetAccuracy(RetinanetBenchmarkBase):
+class DetectionAccuracy(DetectionBenchmarkBase):
  """Accuracy test for RetinaNet model.

  Tests RetinaNet detection task model accuracy. The naming
@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

+  def __init__(self, model, **kwargs):
+    self.model = model
+    super(DetectionAccuracy, self).__init__(**kwargs)
+
  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                params,
@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
                                max_ap=0.35,
                                do_eval=True,
                                warmup=1):
-    """Starts RetinaNet accuracy benchmark test."""
+    """Starts Detection accuracy benchmark test."""
    FLAGS.params_override = json.dumps(params)
    # Need timer callback to measure performance
    self.timer_callback = keras_utils.TimeHistory(
@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
                           max_ap, warmup)

  def _setup(self):
-    super(RetinanetAccuracy, self)._setup()
-    FLAGS.model = 'retinanet'
+    super(DetectionAccuracy, self)._setup()
+    FLAGS.model = self.model

  def _params(self):
    return {
@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
    self._run_and_report_benchmark(params)

-class RetinanetBenchmarkReal(RetinanetAccuracy):
-  """Short benchmark performance tests for RetinaNet model.
+class DetectionBenchmarkReal(DetectionAccuracy):
+  """Short benchmark performance tests for a detection model.

-  Tests RetinaNet performance in different GPU configurations.
+  Tests detection performance in different accelerator configurations.
  The naming convention of below test cases follow
  `benchmark_(number of gpus)_gpu` format.
  """

  def _setup(self):
-    super(RetinanetBenchmarkReal, self)._setup()
+    super(DetectionBenchmarkReal, self)._setup()
    # Use negative value to avoid saving checkpoints.
    FLAGS.save_checkpoint_freq = -1

  @flagsaver.flagsaver
  def benchmark_8_gpu_coco(self):
-    """Run RetinaNet model accuracy test with 8 GPUs."""
+    """Run detection model accuracy test with 8 GPUs."""
    self._setup()
    params = self._params()
    params['architecture']['use_bfloat16'] = False
@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_1_gpu_coco(self):
-    """Run RetinaNet model accuracy test with 1 GPU."""
+    """Run detection model accuracy test with 1 GPU."""
    self._setup()
    params = self._params()
    params['architecture']['use_bfloat16'] = False
@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_xla_1_gpu_coco(self):
-    """Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
+    """Run detection model accuracy test with 1 GPU and XLA enabled."""
    self._setup()
    params = self._params()
    params['architecture']['use_bfloat16'] = False
@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_2x2_tpu_coco(self):
-    """Run RetinaNet model accuracy test with 4 TPUs."""
+    """Run detection model accuracy test with 4 TPUs."""
    self._setup()
    params = self._params()
    params['train']['batch_size'] = 64
@@ -273,7 +277,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_4x4_tpu_coco(self):
-    """Run RetinaNet model accuracy test with 4 TPUs."""
+    """Run detection model accuracy test with 16 TPUs."""
    self._setup()
    params = self._params()
    params['train']['batch_size'] = 256
@@ -285,7 +289,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_2x2_tpu_coco_mlir(self):
-    """Run RetinaNet model accuracy test with 4 TPUs."""
+    """Run detection model accuracy test with 4 TPUs and the MLIR bridge."""
    self._setup()
    params = self._params()
    params['train']['batch_size'] = 64
@@ -311,7 +315,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
  @flagsaver.flagsaver
  def benchmark_2x2_tpu_spinenet_coco(self):
-    """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
+    """Run detection model with SpineNet backbone accuracy test with 4 TPUs."""
    self._setup()
    params = self._params()
    params['architecture']['backbone'] = 'spinenet'
@@ -327,5 +331,32 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
    self._run_and_report_benchmark(params, do_eval=False, warmup=0)

+class RetinanetBenchmarkReal(DetectionBenchmarkReal):
+  """Short benchmark performance tests for Retinanet model."""
+
+  def __init__(self, **kwargs):
+    super(RetinanetBenchmarkReal, self).__init__(
+        model='retinanet',
+        **kwargs)
+
+
+class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
+  """Short benchmark performance tests for Mask RCNN model."""
+
+  def __init__(self, **kwargs):
+    super(MaskRCNNBenchmarkReal, self).__init__(
+        model='mask_rcnn',
+        **kwargs)
+
+
+class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
+  """Short benchmark performance tests for ShapeMask model."""
+
+  def __init__(self, **kwargs):
+    super(ShapeMaskBenchmarkReal, self).__init__(
+        model='shapemask',
+        **kwargs)
+
if __name__ == '__main__':
  tf.test.main()
@@ -32,8 +32,9 @@ class InputReader:
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
-               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
-                                                       tf.data.Dataset]] = None,
+               transform_and_batch_fn: Optional[Callable[
+                   [tf.data.Dataset, Optional[tf.distribute.InputContext]],
+                   tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an InputReader instance.
@@ -48,9 +49,12 @@ class InputReader:
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
-      dataset_transform_fn: An optional `callable` that takes a
-        `tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
-        executed after parser_fn.
+      transform_and_batch_fn: An optional `callable` that takes a
+        `tf.data.Dataset` object and an optional `tf.distribute.InputContext`
+        as input, and returns a `tf.data.Dataset` object. It will be
+        executed after `parser_fn` to transform and batch the dataset; if None,
+        after `parser_fn` is executed, the dataset will be batched into
+        per-replica batch size.
      postprocess_fn: An optional `callable` that processes batched tensors. It
        will be executed after batching.
    """
@@ -101,7 +105,7 @@ class InputReader:
    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._parser_fn = parser_fn
-    self._dataset_transform_fn = dataset_transform_fn
+    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn

  def _read_sharded_files(
@@ -214,13 +218,13 @@ class InputReader:
    dataset = maybe_map_fn(dataset, self._decoder_fn)
    dataset = maybe_map_fn(dataset, self._parser_fn)

-    if self._dataset_transform_fn is not None:
-      dataset = self._dataset_transform_fn(dataset)
-
-    per_replica_batch_size = input_context.get_per_replica_batch_size(
-        self._global_batch_size) if input_context else self._global_batch_size
-
-    dataset = dataset.batch(
-        per_replica_batch_size, drop_remainder=self._drop_remainder)
+    if self._transform_and_batch_fn is not None:
+      dataset = self._transform_and_batch_fn(dataset, input_context)
+    else:
+      per_replica_batch_size = input_context.get_per_replica_batch_size(
+          self._global_batch_size) if input_context else self._global_batch_size
+      dataset = dataset.batch(
+          per_replica_batch_size, drop_remainder=self._drop_remainder)

    dataset = maybe_map_fn(dataset, self._postprocess_fn)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
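
To make the new contract concrete, here is a minimal sketch of a callable matching the new `transform_and_batch_fn` signature; `GLOBAL_BATCH_SIZE` and the shuffle step are illustrative assumptions, not part of the change above:

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # illustrative value, not from the change above


def transform_and_batch_fn(dataset, input_context):
  """Sketch: transforms the parsed dataset and is responsible for batching."""
  per_replica_batch_size = (
      input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
      if input_context else GLOBAL_BATCH_SIZE)
  # Any dataset-level transforms go here (shuffling is just an example)...
  dataset = dataset.shuffle(1024)
  # ...but, unlike the old dataset_transform_fn, this callable must now also
  # batch the dataset before returning it.
  return dataset.batch(per_replica_batch_size, drop_remainder=True)
```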
@@ -34,7 +34,7 @@ class MaskedLM(tf.keras.layers.Layer):
  Arguments:
    embedding_table: The embedding table of the targets.
    activation: The activation, if any, for the dense layer.
-    initializer: The intializer for the dense layer. Defaults to a Glorot
+    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
...
@@ -37,8 +37,8 @@ class Classification(tf.keras.Model):
    num_classes: The number of classes that this network should classify to. If
      equal to 1, a regression problem is assumed.
    activation: The activation, if any, for the dense layer in this network.
-    initializer: The intializer for the dense layer in this network. Defaults to
-      a Glorot uniform initializer.
+    initializer: The initializer for the dense layer in this network. Defaults
+      to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """
...
@@ -33,8 +33,8 @@ class SpanLabeling(tf.keras.Model):
  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
    activation: The activation, if any, for the dense layer in this network.
-    initializer: The intializer for the dense layer in this network. Defaults to
-      a Glorot uniform initializer.
+    initializer: The initializer for the dense layer in this network. Defaults
+      to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """
...
@@ -34,8 +34,8 @@ class TokenClassification(tf.keras.Model):
    input_width: The innermost dimension of the input tensor to this network.
    num_classes: The number of classes that this network should classify to.
    activation: The activation, if any, for the dense layer in this network.
-    initializer: The intializer for the dense layer in this network. Defaults to
-      a Glorot uniform initializer.
+    initializer: The initializer for the dense layer in this network. Defaults
+      to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """
...
@@ -274,8 +274,21 @@ class QuestionAnsweringTask(base_task.Task):
    if self.task_config.validation_data.version_2_with_negative:
      eval_metrics = squad_evaluate_v2_0.evaluate(
          pred_dataset, all_predictions, scores_diff)
+      # Filter out useless metrics, such as start_position_accuracy that
+      # we did not actually compute.
+      eval_metrics = {
+          'exact_match': eval_metrics['final_exact'],
+          'exact_match_threshold': eval_metrics['final_exact_thresh'],
+          'final_f1': eval_metrics['final_f1'] / 100.0,  # scale back to [0, 1].
+          'f1_threshold': eval_metrics['final_f1_thresh'],
+          'has_answer_exact_match': eval_metrics['HasAns_exact'],
+          'has_answer_f1': eval_metrics['HasAns_f1']}
    else:
      eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+      # Filter out useless metrics, such as start_position_accuracy that
+      # we did not actually compute.
+      eval_metrics = {'exact_match': eval_metrics['exact_match'],
+                      'final_f1': eval_metrics['final_f1']}
    return eval_metrics

  def initialize(self, model):
...
@@ -55,7 +55,7 @@ def rel_shift(x, klen=-1):

def _get_initializer(flags):
-  """Get variable intializer."""
+  """Get variable initializer."""
  if flags.init_method == 'uniform':
    initializer = tf.keras.initializers.RandomUniform(
        minval=-flags.init_range, maxval=flags.init_range)
...
@@ -488,19 +488,19 @@ def run_ncf_custom_training(params,
        c.on_batch_end(current_step)

    train_loss /= num_train_steps
-    logging.info("Done training epoch %s, epoch loss=%s.", epoch + 1,
+    logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
                 train_loss)

-    eval_input_iterator = iter(
-        strategy.experimental_distribute_dataset(eval_input_dataset))
+    eval_input_iterator = iter(eval_input_dataset)

-    hr_sum = 0
-    hr_count = 0
+    hr_sum = 0.0
+    hr_count = 0.0
    for _ in range(num_eval_steps):
      step_hr_sum, step_hr_count = eval_step(eval_input_iterator)
      hr_sum += step_hr_sum
      hr_count += step_hr_count

-    logging.info("Done eval epoch %s, hit_rate=%s.", epoch + 1,
+    logging.info("Done eval epoch %s, hit_rate=%.3f", epoch + 1,
                 hr_sum / hr_count)

    if eval_summary_writer:
      with eval_summary_writer.as_default():
...
@@ -17,11 +17,33 @@
import abc
from typing import Any, Dict, Optional, Text

+import dataclasses
+
from orbit import runner
from orbit import utils

import tensorflow as tf

+@dataclasses.dataclass(frozen=True)
+class TrainerOverrides:
+  """Advanced overrides for Orbit trainers.
+
+  Attributes:
+    use_tf_while_loop: A boolean indicating whether to wrap the train step with
+      a `tf.while_loop`.
+    use_tf_function: A boolean indicating whether a `tf.function` will be used.
+      If False, training will run in pure eager mode.
+    use_tpu_summary_optimization: A boolean indicating whether to enable the
+      performance optimization for summaries on TPUs. On TPUs, writing
+      summaries with outside compilation inside the train step is slow. If
+      True, this creates two `tf.function`s with two XLA programs: one with
+      summaries and one without, and runs the program with summaries (the slow
+      one) only when necessary.
+  """
+  use_tf_while_loop: bool = True
+  use_tf_function: bool = True
+  use_tpu_summary_optimization: bool = False
+
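
As a minimal sketch of how such a frozen dataclass behaves, assuming the file in this hunk is importable as `orbit.standard_runner` (how the trainer consumes the overrides is not shown in this hunk):

```python
import dataclasses

# Assumed import path for the module shown in this hunk.
from orbit.standard_runner import TrainerOverrides

# Construct overrides; unspecified fields keep the defaults defined above.
overrides = TrainerOverrides(use_tf_while_loop=False)
assert overrides.use_tf_function  # default is True

# frozen=True makes instances immutable, so a shared benchmark config
# cannot be mutated accidentally by one of its users.
try:
  overrides.use_tf_while_loop = True
except dataclasses.FrozenInstanceError:
  print('TrainerOverrides is immutable')
```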
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
  """Implements the standard functionality of AbstractTrainer APIs."""
@@ -102,6 +124,12 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
    context" for generality, to allow e.g. multiple iterator dequeues and calls
    to `strategy.run`.

+    Note that if `use_tf_function=True`, all the code inside `train_step`
+    should be tf.function compatible, as it will be traced with tf.function.
+    This means you cannot put arbitrary python code in this function. If users
+    have any numpy operations, they should be put in the `train_loop_begin` or
+    `train_loop_end` functions.
+
    Args:
      iterator: A tf.nest-compatible structure of tf.data Iterator or
        DistributedIterator.
@@ -139,6 +167,17 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
      self._train_iter = None

+@dataclasses.dataclass(frozen=True)
+class EvaluatorOverrides:
+  """Advanced overrides for Orbit evaluators.
+
+  Attributes:
+    use_tf_function: A boolean indicating whether a `tf.function` will be used.
+      If False, evaluation will run in pure eager mode.
+  """
+  use_tf_function: bool = True
+
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
  """Implements the standard functionality of AbstractEvaluator APIs."""

@@ -195,6 +234,12 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
    context" for generality, to allow e.g. multiple iterator dequeues and calls
    to `strategy.run`.
+    Note that if `use_tf_function=True`, all the code inside `eval_step` should
+    be tf.function compatible, as it will be traced with tf.function. This
+    means you cannot put arbitrary python code in this function. If users have
+    any numpy operations, they should be put in the `eval_begin`, `eval_end` or
+    `eval_reduce` functions.
+
    Args:
      iterator: A tf.nest-compatible structure of tf.data Iterator or
        DistributedIterator.
...
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to convert a quantized deeplab model to tflite."""
from absl import app
from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf
flags.DEFINE_string('quantized_graph_def_path', None,
                    'Path to quantized graphdef.')
flags.DEFINE_string('output_tflite_path', None, 'Output TFlite model path.')
flags.DEFINE_string(
    'input_tensor_name', None,
    'Input tensor to TFlite model. This usually should be the input tensor to '
    'model backbone.'
)
flags.DEFINE_string(
    'output_tensor_name', 'ArgMax:0',
    'Output tensor name of TFlite model. By default we output the raw semantic '
    'label predictions.'
)
flags.DEFINE_string(
    'test_image_path', None,
    'Path to an image to test the consistency between input graphdef / '
    'converted tflite model.'
)

FLAGS = flags.FLAGS
def convert_to_tflite(quantized_graphdef,
                      backbone_input_tensor,
                      output_tensor):
  """Helper method to convert quantized deeplab model to TFlite."""
  with tf.Graph().as_default() as graph:
    tf.graph_util.import_graph_def(quantized_graphdef, name='')
    sess = tf.compat.v1.Session()

    tflite_input = graph.get_tensor_by_name(backbone_input_tensor)
    tflite_output = graph.get_tensor_by_name(output_tensor)
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [tflite_input], [tflite_output])
    converter.inference_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
    input_arrays = converter.get_input_arrays()
    # quantized_input_stats maps each input array to (mean, std_dev), such
    # that real_value = (quantized_uint8_value - mean) / std_dev.
    converter.quantized_input_stats = {input_arrays[0]: (127.5, 127.5)}
    return converter.convert()
def check_tflite_consistency(graph_def, tflite_model, image_path):
  """Runs tflite and frozen graph on same input, check their outputs match."""
  # Load tflite model and check input size.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  height, width = input_details[0]['shape'][1:3]

  # Prepare input image data.
  with tf.io.gfile.GFile(image_path, 'rb') as f:
    image = Image.open(f)
  image = np.asarray(image.convert('RGB').resize((width, height)))
  image = np.expand_dims(image, 0)

  # Output from tflite model.
  interpreter.set_tensor(input_details[0]['index'], image)
  interpreter.invoke()
  output_tflite = interpreter.get_tensor(output_details[0]['index'])

  with tf.Graph().as_default():
    tf.graph_util.import_graph_def(graph_def, name='')
    with tf.compat.v1.Session() as sess:
      # Note here the graph will include preprocessing part of the graph
      # (e.g. resize, pad, normalize). Given the input image size is at the
      # crop size (backbone input size), resize / pad should be an identity op.
      output_graph = sess.run(
          FLAGS.output_tensor_name, feed_dict={'ImageTensor:0': image})

  print('%.2f%% pixels have matched semantic labels.' % (
      100 * np.mean(output_graph == output_tflite)))
def main(unused_argv):
  with tf.io.gfile.GFile(FLAGS.quantized_graph_def_path, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef.FromString(f.read())
  tflite_model = convert_to_tflite(
      graph_def, FLAGS.input_tensor_name, FLAGS.output_tensor_name)

  if FLAGS.output_tflite_path:
    with tf.io.gfile.GFile(FLAGS.output_tflite_path, 'wb') as f:
      f.write(tflite_model)

  if FLAGS.test_image_path:
    check_tflite_consistency(graph_def, tflite_model, FLAGS.test_image_path)


if __name__ == '__main__':
  app.run(main)
@@ -42,7 +42,6 @@ python deeplab/train.py \
  --train_batch_size=8 \
  --base_learning_rate=3e-5 \
  --dataset="pascal_voc_seg" \
-  --initialize_last_layer \
  --quantize_delay_step=0 \
  --tf_initial_checkpoint=${PATH_TO_TRAINED_FLOAT_MODEL} \
  --train_logdir=${PATH_TO_TRAIN_DIR} \
@@ -65,18 +64,12 @@ python deeplab/export_model.py \

The command line below shows how to convert the exported graphdef to a TFlite model.

```
-tflite_convert \
-  --graph_def_file=${OUTPUT_DIR}/frozen_inference_graph.pb \
-  --output_file=${OUTPUT_DIR}/frozen_inference_graph.tflite \
-  --output_format=TFLITE \
-  --input_shape=1,513,513,3 \
-  --input_arrays="MobilenetV2/MobilenetV2/input" \
-  --inference_type=QUANTIZED_UINT8 \
-  --inference_input_type=QUANTIZED_UINT8 \
-  --std_dev_values=128 \
-  --mean_values=128 \
-  --change_concat_input_ranges=true \
-  --output_arrays="ArgMax"
+# From tensorflow/models/research/
+python deeplab/convert_to_tflite.py \
+  --quantized_graph_def_path=${OUTPUT_DIR}/frozen_inference_graph.pb \
+  --input_tensor_name=MobilenetV2/MobilenetV2/input:0 \
+  --output_tflite_path=${OUTPUT_DIR}/frozen_inference_graph.tflite \
+  --test_image_path=${PATH_TO_TEST_IMAGE}
```

**[Important]** Note that converted model expects 513x513 RGB input and doesn't
...
@@ -1671,6 +1671,8 @@ class CenterNetMaskTargetAssigner(object):
      # Shape: [h, w, num_classes].
      segmentations_for_image = tf.reduce_max(
          gt_masks * gt_classes_reshaped, axis=0)
+      # Avoid the case where max of an empty array is -inf.
+      segmentations_for_image = tf.maximum(segmentations_for_image, 0.0)
      segmentation_targets_list.append(segmentations_for_image)

    segmentation_target = tf.stack(segmentation_targets_list, axis=0)
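
The clamp matters because `tf.reduce_max` over an empty axis returns the identity of max, negative infinity. A minimal standalone sketch of the failure mode this change guards against:

```python
import tensorflow as tf

# With no ground-truth instances, the stacked masks have shape [0, h, w],
# and reducing over the empty instance axis yields -inf everywhere.
empty_masks = tf.zeros((0, 2, 2))
print(tf.reduce_max(empty_masks, axis=0).numpy())
# [[-inf -inf]
#  [-inf -inf]]

# Clamping at zero recovers an all-background segmentation target.
print(tf.maximum(tf.reduce_max(empty_masks, axis=0), 0.0).numpy())
# [[0. 0.]
#  [0. 0.]]
```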
...
@@ -1905,6 +1905,22 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
    np.testing.assert_array_almost_equal(
        expected_seg_target, segmentation_target)

+  def test_assign_segmentation_targets_no_objects(self):
+    def graph_fn():
+      gt_masks_list = [tf.zeros((0, 5, 5))]
+      gt_classes_list = [tf.zeros((0, 10))]
+      cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=1)
+      segmentation_target = cn_assigner.assign_segmentation_targets(
+          gt_masks_list=gt_masks_list,
+          gt_classes_list=gt_classes_list,
+          mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR)
+      return segmentation_target
+
+    segmentation_target = self.execute(graph_fn, [])
+    expected_seg_target = np.zeros((1, 5, 5, 10))
+    np.testing.assert_array_almost_equal(
+        expected_seg_target, segmentation_target)
+
class CenterNetDensePoseTargetAssignerTest(test_case.TestCase):
...
@@ -23,9 +23,9 @@ model configs in this [directory](../configs/tf2) (also in the linked

Model name | Speed (ms) | COCO mAP | Outputs
---------- | :--------: | :------: | :-----:
-[CenterNet HourGlass104 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_512x512_coco17_tpu-8.tar.gz) | 70 | 41.6 | Boxes
+[CenterNet HourGlass104 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200713/centernet_hg104_512x512_coco17_tpu-8.tar.gz) | 70 | 41.9 | Boxes
[CenterNet HourGlass104 Keypoints 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_512x512_kpts_coco17_tpu-32.tar.gz) | 76 | 40.0/61.4 | Boxes/Keypoints
-[CenterNet HourGlass104 1024x1024](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_1024x1024_coco17_tpu-32.tar.gz) | 197 | 43.5 | Boxes
+[CenterNet HourGlass104 1024x1024](http://download.tensorflow.org/models/object_detection/tf2/20200713/centernet_hg104_1024x1024_coco17_tpu-32.tar.gz) | 197 | 44.5 | Boxes
[CenterNet HourGlass104 Keypoints 1024x1024](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_1024x1024_kpts_coco17_tpu-32.tar.gz) | 211 | 42.8/64.5 | Boxes/Keypoints
[CenterNet Resnet50 V1 FPN 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet50_v1_fpn_512x512_coco17_tpu-8.tar.gz) | 27 | 31.2 | Boxes
[CenterNet Resnet50 V1 FPN Keypoints 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet50_v1_fpn_512x512_kpts_coco17_tpu-8.tar.gz) | 30 | 29.3/50.7 | Boxes/Keypoints
...
@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
    """The number of feature outputs returned by the feature extractor."""
    pass

+  @abc.abstractmethod
+  def get_sub_model(self, sub_model_type):
+    """Returns the underlying keras model for the given sub_model_type.
+
+    This function is useful when we only want to get a subset of weights to
+    be restored from a checkpoint.
+
+    Args:
+      sub_model_type: string, the type of sub model. Currently, CenterNet
+        feature extractors support 'detection' and 'classification'.
+    """
+    pass
+
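
A hypothetical implementation in a concrete feature extractor might look like the sketch below. The attribute name `self._base_model` is an assumption for illustration, but the error message mirrors the one asserted in the test further down this diff:

```python
# Hypothetical sketch of a concrete feature extractor that only supports
# restoring classification checkpoints (e.g. a ResNet backbone).
def get_sub_model(self, sub_model_type):
  if sub_model_type == 'classification':
    return self._base_model  # assumed attribute holding the backbone model
  raise ValueError(
      "Sub model {} is not defined for ResNet."
      "Supported types are ['classification'].".format(sub_model_type))
```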
def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
                        bias_fill=None):
@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
      A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
    """

-    if fine_tune_checkpoint_type == 'classification':
-      return {'feature_extractor': self._feature_extractor.get_base_model()}
-    elif fine_tune_checkpoint_type == 'detection':
-      return {'feature_extractor': self._feature_extractor.get_model()}
-    elif fine_tune_checkpoint_type == 'fine_tune':
-      feature_extractor_model = tf.train.Checkpoint(
-          _feature_extractor=self._feature_extractor)
-      return {'model': feature_extractor_model}
-    else:
-      raise ValueError('Not supported fine tune checkpoint type - {}'.format(
-          fine_tune_checkpoint_type))
+    sub_model = self._feature_extractor.get_sub_model(fine_tune_checkpoint_type)
+    return {'feature_extractor': sub_model}

  def updates(self):
    raise RuntimeError('This model is intended to be used with model_lib_v2 '
...
@@ -17,7 +17,9 @@
from __future__ import division

import functools
+import re
import unittest

+from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf

@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
    self.assertIsInstance(restore_from_objects_map['feature_extractor'],
                          tf.keras.Model)

+  def test_restore_map_error(self):
+    """Test that restoring an unsupported checkpoint type raises an error."""
+    model = build_center_net_meta_arch(build_resnet=True)
+    msg = ("Sub model detection is not defined for ResNet."
+           "Supported types are ['classification'].")
+    with self.assertRaisesRegex(ValueError, re.escape(msg)):
+      model.restore_from_objects('detection')
+
class DummyFeatureExtractor(cnma.CenterNetFeatureExtractor):
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library functions for Context R-CNN."""
import tensorflow as tf
from object_detection.core import freezable_batch_norm
# The negative value used in padding the invalid weights.
_NEGATIVE_PADDING_VALUE = -100000
class ContextProjection(tf.keras.layers.Layer):
  """Custom layer to do batch normalization and projection."""

  def __init__(self, projection_dimension, **kwargs):
    self.batch_norm = freezable_batch_norm.FreezableBatchNorm(
        epsilon=0.001,
        center=True,
        scale=True,
        momentum=0.97,
        trainable=True)
    self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                            activation=tf.nn.relu6,
                                            use_bias=True)
    super(ContextProjection, self).__init__(**kwargs)

  def build(self, input_shape):
    self.batch_norm.build(input_shape)
    self.projection.build(input_shape)

  def call(self, input_features, is_training=False):
    return self.projection(self.batch_norm(input_features, is_training))
class AttentionBlock(tf.keras.layers.Layer):
  """Custom layer to perform all attention."""

  def __init__(self, bottleneck_dimension, attention_temperature,
               output_dimension=None, is_training=False,
               name='AttentionBlock', **kwargs):
    """Constructs an attention block.

    Args:
      bottleneck_dimension: An int32 Tensor representing the bottleneck
        dimension for intermediate projections.
      attention_temperature: A float Tensor. It controls the temperature of the
        softmax for weights calculation. The formula for calculation is as
        follows:
          weights = exp(weights / temperature) / sum(exp(weights / temperature))
      output_dimension: An int32 Tensor representing the last dimension of the
        output feature.
      is_training: A boolean Tensor (affecting batch normalization).
      name: A string describing what to name the variables in this block.
      **kwargs: Additional keyword arguments.
    """
    self._key_proj = ContextProjection(bottleneck_dimension)
    self._val_proj = ContextProjection(bottleneck_dimension)
    self._query_proj = ContextProjection(bottleneck_dimension)
    self._feature_proj = None
    self._attention_temperature = attention_temperature
    self._bottleneck_dimension = bottleneck_dimension
    self._is_training = is_training
    self._output_dimension = output_dimension
    if self._output_dimension:
      self._feature_proj = ContextProjection(self._output_dimension)
    super(AttentionBlock, self).__init__(name=name, **kwargs)

  def build(self, input_shapes):
    """Finishes building the attention block.

    Args:
      input_shapes: the shape of the primary input box features.
    """
    if not self._feature_proj:
      self._output_dimension = input_shapes[-1]
      self._feature_proj = ContextProjection(self._output_dimension)

  def call(self, box_features, context_features, valid_context_size):
    """Handles a call by performing attention.

    Args:
      box_features: A float Tensor of shape [batch_size, input_size,
        num_input_features].
      context_features: A float Tensor of shape [batch_size, context_size,
        num_context_features].
      valid_context_size: An int32 Tensor of shape [batch_size].

    Returns:
      A float Tensor with shape [batch_size, input_size, num_input_features]
      containing output features after attention with context features.
    """
    _, context_size, _ = context_features.shape
    valid_mask = compute_valid_mask(valid_context_size, context_size)

    # Average pools over height and width dimension so that the shape of
    # box_features becomes [batch_size, max_num_proposals, channels].
    box_features = tf.reduce_mean(box_features, [2, 3])

    queries = project_features(
        box_features, self._bottleneck_dimension, self._is_training,
        self._query_proj, normalize=True)
    keys = project_features(
        context_features, self._bottleneck_dimension, self._is_training,
        self._key_proj, normalize=True)
    values = project_features(
        context_features, self._bottleneck_dimension, self._is_training,
        self._val_proj, normalize=True)

    weights = tf.matmul(queries, keys, transpose_b=True)
    weights, values = filter_weight_value(weights, values, valid_mask)
    weights = tf.nn.softmax(weights / self._attention_temperature)

    features = tf.matmul(weights, values)
    output_features = project_features(
        features, self._output_dimension, self._is_training,
        self._feature_proj, normalize=False)

    output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]

    return output_features
def filter_weight_value(weights, values, valid_mask):
  """Filters weights and values based on valid_mask.

  _NEGATIVE_PADDING_VALUE will be added to invalid elements in the weights to
  avoid their contribution in softmax. 0 will be set for the invalid elements
  in the values.

  Args:
    weights: A float Tensor of shape [batch_size, input_size, context_size].
    values: A float Tensor of shape [batch_size, context_size,
      projected_dimension].
    valid_mask: A boolean Tensor of shape [batch_size, context_size]. True
      means valid and False means invalid.

  Returns:
    weights: A float Tensor of shape [batch_size, input_size, context_size].
    values: A float Tensor of shape [batch_size, context_size,
      projected_dimension].

  Raises:
    ValueError: If the shapes of `weights`, `values` and `valid_mask` are not
      compatible.
  """
  w_batch_size, _, w_context_size = weights.shape
  v_batch_size, v_context_size, _ = values.shape
  m_batch_size, m_context_size = valid_mask.shape
  if w_batch_size != v_batch_size or v_batch_size != m_batch_size:
    raise ValueError('Please make sure the first dimension of the input'
                     ' tensors are the same.')
  if w_context_size != v_context_size:
    raise ValueError('Please make sure the third dimension of weights matches'
                     ' the second dimension of values.')
  if w_context_size != m_context_size:
    raise ValueError('Please make sure the third dimension of the weights'
                     ' matches the second dimension of the valid_mask.')

  valid_mask = valid_mask[..., tf.newaxis]

  # Force the invalid weights to be very negative so they won't contribute to
  # the softmax.
  weights += tf.transpose(
      tf.cast(tf.math.logical_not(valid_mask), weights.dtype) *
      _NEGATIVE_PADDING_VALUE,
      perm=[0, 2, 1])

  # Force the invalid values to be 0.
  values *= tf.cast(valid_mask, values.dtype)

  return weights, values
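
To see the masking trick numerically: once `_NEGATIVE_PADDING_VALUE` has been added to a padded position's logit, softmax assigns that position essentially zero weight. A minimal sketch (the constant is written out inline here):

```python
import tensorflow as tf

# One query attending over one valid and one padded context position; the
# padded logit has _NEGATIVE_PADDING_VALUE (-100000) added to it.
masked_logits = tf.constant([[0.0, -100000.0]])
print(tf.nn.softmax(masked_logits).numpy())
# [[1. 0.]] -- exp(-100000) underflows, so padding contributes nothing.
```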
def project_features(features, bottleneck_dimension, is_training,
                     layer, normalize=True):
  """Projects features to another feature space.

  Args:
    features: A float Tensor of shape [batch_size, features_size,
      num_features].
    bottleneck_dimension: An int32 Tensor.
    is_training: A boolean Tensor (affecting batch normalization).
    layer: Contains a custom layer specific to the particular operation
      being performed (key, value, query, features).
    normalize: A boolean Tensor. If true, the output features will be l2
      normalized on the last dimension.

  Returns:
    A float Tensor of shape [batch, features_size, projection_dimension].
  """
  shape_arr = features.shape
  batch_size, _, num_features = shape_arr
  features = tf.reshape(features, [-1, num_features])
  projected_features = layer(features, is_training)

  projected_features = tf.reshape(projected_features,
                                  [batch_size, -1, bottleneck_dimension])

  if normalize:
    projected_features = tf.keras.backend.l2_normalize(projected_features,
                                                       axis=-1)

  return projected_features
def compute_valid_mask(num_valid_elements, num_elements):
  """Computes mask of valid entries within padded context feature.

  Args:
    num_valid_elements: An int32 Tensor of shape [batch_size].
    num_elements: An int32 Tensor.

  Returns:
    A boolean Tensor of the shape [batch_size, num_elements]. True means
    valid and False means invalid.
  """
  batch_size = num_valid_elements.shape[0]
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
  num_valid_elements = num_valid_elements[..., tf.newaxis]
  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
  return valid_mask
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for context_rcnn_lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
import tensorflow.compat.v1 as tf
from object_detection.meta_architectures import context_rcnn_lib_tf2 as context_rcnn_lib
from object_detection.utils import test_case
from object_detection.utils import tf_version
_NEGATIVE_PADDING_VALUE = -100000
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
  """Tests for the functions in context_rcnn_lib."""

  def test_compute_valid_mask(self):
    num_elements = tf.constant(3, tf.int32)
    num_valid_elements = tf.constant((1, 2), tf.int32)
    valid_mask = context_rcnn_lib.compute_valid_mask(num_valid_elements,
                                                     num_elements)
    expected_valid_mask = tf.constant([[1, 0, 0], [1, 1, 0]], tf.float32)
    self.assertAllEqual(valid_mask, expected_valid_mask)
  def test_filter_weight_value(self):
    weights = tf.ones((2, 3, 2), tf.float32) * 4
    values = tf.ones((2, 2, 4), tf.float32)
    valid_mask = tf.constant([[True, True], [True, False]], tf.bool)
    filtered_weights, filtered_values = context_rcnn_lib.filter_weight_value(
        weights, values, valid_mask)
    expected_weights = tf.constant([[[4, 4], [4, 4], [4, 4]],
                                    [[4, _NEGATIVE_PADDING_VALUE + 4],
                                     [4, _NEGATIVE_PADDING_VALUE + 4],
                                     [4, _NEGATIVE_PADDING_VALUE + 4]]])
    expected_values = tf.constant([[[1, 1, 1, 1], [1, 1, 1, 1]],
                                   [[1, 1, 1, 1], [0, 0, 0, 0]]])
    self.assertAllEqual(filtered_weights, expected_weights)
    self.assertAllEqual(filtered_values, expected_values)

    # Changes the valid_mask so the results will be different.
    valid_mask = tf.constant([[True, True], [False, False]], tf.bool)
    filtered_weights, filtered_values = context_rcnn_lib.filter_weight_value(
        weights, values, valid_mask)
    expected_weights = tf.constant(
        [[[4, 4], [4, 4], [4, 4]],
         [[_NEGATIVE_PADDING_VALUE + 4, _NEGATIVE_PADDING_VALUE + 4],
          [_NEGATIVE_PADDING_VALUE + 4, _NEGATIVE_PADDING_VALUE + 4],
          [_NEGATIVE_PADDING_VALUE + 4, _NEGATIVE_PADDING_VALUE + 4]]])
    expected_values = tf.constant([[[1, 1, 1, 1], [1, 1, 1, 1]],
                                   [[0, 0, 0, 0], [0, 0, 0, 0]]])
    self.assertAllEqual(filtered_weights, expected_weights)
    self.assertAllEqual(filtered_values, expected_values)
  @parameterized.parameters((2, True, True), (2, False, True),
                            (10, True, False), (10, False, False))
  def test_project_features(self, projection_dimension, is_training, normalize):
    features = tf.ones([2, 3, 4], tf.float32)
    projected_features = context_rcnn_lib.project_features(
        features,
        projection_dimension,
        is_training,
        context_rcnn_lib.ContextProjection(projection_dimension),
        normalize=normalize)
    # Makes sure the shape is correct.
    self.assertAllEqual(projected_features.shape, [2, 3, projection_dimension])

  @parameterized.parameters(
      (2, 10, 1),
      (3, 10, 2),
      (4, None, 3),
      (5, 20, 4),
      (7, None, 5),
  )
  def test_attention_block(self, bottleneck_dimension, output_dimension,
                           attention_temperature):
    input_features = tf.ones([2, 8, 3, 3, 3], tf.float32)
    context_features = tf.ones([2, 20, 10], tf.float32)
    attention_block = context_rcnn_lib.AttentionBlock(
        bottleneck_dimension,
        attention_temperature,
        output_dimension=output_dimension,
        is_training=False)
    valid_context_size = tf.random_uniform((2,),
                                           minval=0,
                                           maxval=10,
                                           dtype=tf.int32)
    output_features = attention_block(input_features, context_features,
                                      valid_context_size)
    # Makes sure the shape is correct.
    self.assertAllEqual(output_features.shape,
                        [2, 8, 1, 1, (output_dimension or 3)])


if __name__ == '__main__':
  tf.test.main()