Commit 9725a407 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Introducing SpineNet backbone to TF2.

PiperOrigin-RevId: 318945656
parent 91e2171b
...@@ -19,9 +19,10 @@ In the near future, we will add: ...@@ -19,9 +19,10 @@ In the near future, we will add:
* State-of-the-art language understanding models: * State-of-the-art language understanding models:
More members in Transformer family More members in Transformer family
* Start-of-the-art image classification models: * State-of-the-art image classification models:
EfficientNet, MnasNet, and variants EfficientNet, MnasNet, and variants
* A set of excellent objection detection models. * State-of-the-art objection detection and instance segmentation models:
RetinaNet, Mask R-CNN, SpineNet, and variants
## Table of Contents ## Table of Contents
...@@ -52,6 +53,7 @@ In the near future, we will add: ...@@ -52,6 +53,7 @@ In the near future, we will add:
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) | | [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | | [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) | | [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing ### Natural Language Processing
......
...@@ -271,6 +271,23 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -271,6 +271,23 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu' FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
params['architecture']['multilevel_features'] = 'identity'
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
params['train']['checkpoint']['path'] = ''
FLAGS.model_dir = self._get_model_dir(
'real_benchmark_2x2_tpu_spinenet_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -48,6 +48,22 @@ so the checkpoints are not compatible. ...@@ -48,6 +48,22 @@ so the checkpoints are not compatible.
We will unify the implementation soon. We will unify the implementation soon.
### Train a SpineNet-49 based RetinaNet.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--params_override="{ type: retinanet, architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
```
### Train a custom RetinaNet using the config file. ### Train a custom RetinaNet using the config file.
...@@ -163,6 +179,24 @@ so the checkpoints are not compatible. ...@@ -163,6 +179,24 @@ so the checkpoints are not compatible.
We will unify the implementation soon. We will unify the implementation soon.
### Train a SpineNet-49 based Mask R-CNN.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--model=mask_rcnn \
--params_override="{architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
```
### Train a custom Mask R-CNN using the config file. ### Train a custom Mask R-CNN using the config file.
First, create a YAML config file, e.g. *my_maskrcnn.yaml*. First, create a YAML config file, e.g. *my_maskrcnn.yaml*.
......
...@@ -17,10 +17,12 @@ ...@@ -17,10 +17,12 @@
BACKBONES = [ BACKBONES = [
'resnet', 'resnet',
'spinenet',
] ]
MULTILEVEL_FEATURES = [ MULTILEVEL_FEATURES = [
'fpn', 'fpn',
'identity',
] ]
# pylint: disable=line-too-long # pylint: disable=line-too-long
...@@ -118,6 +120,9 @@ BASE_CFG = { ...@@ -118,6 +120,9 @@ BASE_CFG = {
'resnet': { 'resnet': {
'resnet_depth': 50, 'resnet_depth': 50,
}, },
'spinenet': {
'model_id': '49',
},
'fpn': { 'fpn': {
'fpn_feat_dims': 256, 'fpn_feat_dims': 256,
'use_separable_conv': False, 'use_separable_conv': False,
......
...@@ -23,6 +23,7 @@ from official.vision.detection.modeling.architecture import heads ...@@ -23,6 +23,7 @@ from official.vision.detection.modeling.architecture import heads
from official.vision.detection.modeling.architecture import identity from official.vision.detection.modeling.architecture import identity
from official.vision.detection.modeling.architecture import nn_ops from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.modeling.architecture import resnet from official.vision.detection.modeling.architecture import resnet
from official.vision.detection.modeling.architecture import spinenet
def norm_activation_generator(params): def norm_activation_generator(params):
...@@ -42,6 +43,9 @@ def backbone_generator(params): ...@@ -42,6 +43,9 @@ def backbone_generator(params):
activation=params.norm_activation.activation, activation=params.norm_activation.activation,
norm_activation=norm_activation_generator( norm_activation=norm_activation_generator(
params.norm_activation)) params.norm_activation))
elif params.architecture.backbone == 'spinenet':
spinenet_params = params.spinenet
backbone_fn = spinenet.SpineNetBuilder(model_id=spinenet_params.model_id)
else: else:
raise ValueError('Backbone model `{}` is not supported.' raise ValueError('Backbone model `{}` is not supported.'
.format(params.architecture.backbone)) .format(params.architecture.backbone))
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResidualBlock(tf.keras.layers.Layer):
"""A residual block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A residual block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ResidualBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(ResidualBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
return self._activation_fn(x + shortcut)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock(tf.keras.layers.Layer):
"""A standard bottleneck block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A standard bottleneck block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(BottleneckBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv3 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(BottleneckBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._conv3(x)
x = self._norm3(x)
return self._activation_fn(x + shortcut)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of SpineNet model.
X. Du, T-Y. Lin, P. Jin, G. Ghiasi, M. Tan, Y. Cui, Q. V. Le, X. Song
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
https://arxiv.org/abs/1912.05027
"""
import math
from absl import logging
import tensorflow as tf
from tensorflow.python.keras import backend
from official.modeling import tf_utils
from official.vision.detection.modeling.architecture import nn_blocks
layers = tf.keras.layers
FILTER_SIZE_MAP = {
1: 32,
2: 64,
3: 128,
4: 256,
5: 256,
6: 256,
7: 256,
}
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS = [
(2, 'bottleneck', (0, 1), False),
(4, 'residual', (0, 1), False),
(3, 'bottleneck', (2, 3), False),
(4, 'bottleneck', (2, 4), False),
(6, 'residual', (3, 5), False),
(4, 'bottleneck', (3, 5), False),
(5, 'residual', (6, 7), False),
(7, 'residual', (6, 8), False),
(5, 'bottleneck', (8, 9), False),
(5, 'bottleneck', (8, 10), False),
(4, 'bottleneck', (5, 10), True),
(3, 'bottleneck', (4, 10), True),
(5, 'bottleneck', (7, 12), True),
(7, 'bottleneck', (5, 14), True),
(6, 'bottleneck', (12, 14), True),
]
SCALING_MAP = {
'49S': {
'endpoints_num_filters': 128,
'filter_size_scale': 0.65,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'49': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'96': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 2,
},
'143': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 1.0,
'block_repeats': 3,
},
'190': {
'endpoints_num_filters': 512,
'filter_size_scale': 1.3,
'resample_alpha': 1.0,
'block_repeats': 4,
},
}
class BlockSpec(object):
"""A container class that specifies the block configuration for SpineNet."""
def __init__(self, level, block_fn, input_offsets, is_output):
self.level = level
self.block_fn = block_fn
self.input_offsets = input_offsets
self.is_output = is_output
def build_block_specs(block_specs=None):
"""Builds the list of BlockSpec objects for SpineNet."""
if not block_specs:
block_specs = SPINENET_BLOCK_SPECS
logging.info('Building SpineNet block specs: %s', block_specs)
return [BlockSpec(*b) for b in block_specs]
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpineNet(tf.keras.Model):
"""Class to build SpineNet models."""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
min_level=3,
max_level=7,
block_specs=build_block_specs(),
endpoints_num_filters=256,
resample_alpha=0.5,
block_repeats=1,
filter_size_scale=1.0,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""SpineNet model."""
self._min_level = min_level
self._max_level = max_level
self._block_specs = block_specs
self._endpoints_num_filters = endpoints_num_filters
self._resample_alpha = resample_alpha
self._block_repeats = block_repeats
self._filter_size_scale = filter_size_scale
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if activation == 'relu':
self._activation = tf.nn.relu
elif activation == 'swish':
self._activation = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._init_block_fn = 'bottleneck'
self._num_init_blocks = 2
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
# Build SpineNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network(
net=net, input_width=input_specs.shape[1])
net = self._build_endpoints(net=net)
super(SpineNet, self).__init__(inputs=inputs, outputs=net)
def _block_group(self,
inputs,
filters,
strides,
block_fn_cand,
block_repeats=1,
name='block_group'):
"""Creates one group of blocks for the SpineNet model."""
block_fn_candidates = {
'bottleneck': nn_blocks.BottleneckBlock,
'residual': nn_blocks.ResidualBlock,
}
block_fn = block_fn_candidates[block_fn_cand]
_, _, _, num_filters = inputs.get_shape().as_list()
if block_fn_cand == 'bottleneck':
use_projection = not (num_filters == (filters * 4) and strides == 1)
else:
use_projection = not (num_filters == filters and strides == 1)
x = block_fn(
filters=filters,
strides=strides,
use_projection=use_projection,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = block_fn(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
def _build_stem(self, inputs):
"""Build SpineNet stem."""
x = layers.Conv2D(
filters=64,
kernel_size=7,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
net = []
# Build the initial level 2 blocks.
for i in range(self._num_init_blocks):
x = self._block_group(
inputs=x,
filters=int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
strides=1,
block_fn_cand=self._init_block_fn,
block_repeats=self._block_repeats,
name='stem_block_{}'.format(i + 1))
net.append(x)
return net
def _build_scale_permuted_network(self,
net,
input_width,
weighted_fusion=False):
"""Build scale-permuted network."""
net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
net_block_fns = [self._init_block_fn] * len(net)
num_outgoing_connections = [0] * len(net)
endpoints = {}
for i, block_spec in enumerate(self._block_specs):
# Find out specs for the target block.
target_width = int(math.ceil(input_width / 2**block_spec.level))
target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
self._filter_size_scale)
target_block_fn = block_spec.block_fn
# Resample then merge input0 and input1.
parents = []
input0 = block_spec.input_offsets[0]
input1 = block_spec.input_offsets[1]
x0 = self._resample_with_alpha(
inputs=net[input0],
input_width=net_sizes[input0],
input_block_fn=net_block_fns[input0],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x0)
num_outgoing_connections[input0] += 1
x1 = self._resample_with_alpha(
inputs=net[input1],
input_width=net_sizes[input1],
input_block_fn=net_block_fns[input1],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x1)
num_outgoing_connections[input1] += 1
# Merge 0 outdegree blocks to the output block.
if block_spec.is_output:
for j, (j_feat,
j_connections) in enumerate(zip(net, num_outgoing_connections)):
if j_connections == 0 and (j_feat.shape[2] == target_width and
j_feat.shape[3] == x0.shape[3]):
parents.append(j_feat)
num_outgoing_connections[j] += 1
# pylint: disable=g-direct-tensorflow-import
if weighted_fusion:
dtype = parents[0].dtype
parent_weights = [
tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
i, j)), dtype=dtype)) for j in range(len(parents))]
weights_sum = tf.add_n(parent_weights)
parents = [
parents[i] * parent_weights[i] / (weights_sum + 0.0001)
for i in range(len(parents))
]
# Fuse all parent nodes then build a new block.
x = tf_utils.get_activation(self._activation)(tf.add_n(parents))
x = self._block_group(
inputs=x,
filters=target_num_filters,
strides=1,
block_fn_cand=target_block_fn,
block_repeats=self._block_repeats,
name='scale_permuted_block_{}'.format(i + 1))
net.append(x)
net_sizes.append(target_width)
net_block_fns.append(target_block_fn)
num_outgoing_connections.append(0)
# Save output feats.
if block_spec.is_output:
if block_spec.level in endpoints:
raise ValueError('Duplicate feats found for output level {}.'.format(
block_spec.level))
if (block_spec.level < self._min_level or
block_spec.level > self._max_level):
raise ValueError('Output level is out of range [{}, {}]'.format(
self._min_level, self._max_level))
endpoints[block_spec.level] = x
return endpoints
def _build_endpoints(self, net):
"""Match filter size for endpoints before sharing conv layers."""
endpoints = {}
for level in range(self._min_level, self._max_level + 1):
x = layers.Conv2D(
filters=self._endpoints_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
net[level])
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
endpoints[level] = x
return endpoints
def _resample_with_alpha(self,
inputs,
input_width,
input_block_fn,
target_width,
target_num_filters,
target_block_fn,
alpha=0.5):
"""Match resolution and feature dimension."""
_, _, _, input_num_filters = inputs.get_shape().as_list()
if input_block_fn == 'bottleneck':
input_num_filters /= 4
new_num_filters = int(input_num_filters * alpha)
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
# Spatial resampling.
if input_width > target_width:
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=3,
strides=2,
padding='SAME',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
input_width /= 2
while input_width > target_width:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')(x)
input_width /= 2
elif input_width < target_width:
scale = target_width // input_width
x = layers.UpSampling2D(size=(scale, scale))(x)
# Last 1x1 conv to match filter size.
if target_block_fn == 'bottleneck':
target_num_filters *= 4
x = layers.Conv2D(
filters=target_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
return x
class SpineNetBuilder(object):
"""SpineNet builder."""
def __init__(self,
model_id,
input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
min_level=3,
max_level=7,
block_specs=build_block_specs(),
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001):
if model_id not in SCALING_MAP:
raise ValueError(
'SpineNet {} is not a valid architecture.'.format(model_id))
scaling_params = SCALING_MAP[model_id]
self._input_specs = input_specs
self._min_level = min_level
self._max_level = max_level
self._block_specs = block_specs
self._endpoints_num_filters = scaling_params['endpoints_num_filters']
self._resample_alpha = scaling_params['resample_alpha']
self._block_repeats = scaling_params['block_repeats']
self._filter_size_scale = scaling_params['filter_size_scale']
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._activation = activation
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
def __call__(self, inputs, is_training=None):
with backend.get_graph().as_default():
model = SpineNet(
input_specs=self._input_specs,
min_level=self._min_level,
max_level=self._max_level,
block_specs=self._block_specs,
endpoints_num_filters=self._endpoints_num_filters,
resample_alpha=self._resample_alpha,
block_repeats=self._block_repeats,
filter_size_scale=self._filter_size_scale,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)
return model(inputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment