Commit ac8d0651 authored by A. Unique TensorFlower, committed by TF Object Detection Team

Add SpaghettiNet Feature Extractor

PiperOrigin-RevId: 402944074
parent fccb57b1
...@@ -73,6 +73,23 @@ documentation of the Object Detection API:
## What's New
### SpaghettiNet for Edge TPU
We have released SpaghettiNet models optimized for the Edge TPU in the [Google Tensor SoC](https://blog.google/products/pixel/google-tensor-debuts-new-pixel-6-fall/).
SpaghettiNet models are derived from a TuNAS search space that incorporates
group-convolution-based [Inverted Bottleneck](https://arxiv.org/abs/1801.04381) blocks.
The backbone and detection head are connected through [MnasFPN](https://arxiv.org/abs/1912.01106)-style
feature map merging and searched jointly.
When compared to MobileDet-EdgeTPU, SpaghettiNet models achieve +2.2% mAP
(absolute) on COCO17 at the same latency. They also consume <70% of the energy
used by MobileDet-EdgeTPU to achieve the same accuracy.
Sample config available [here](configs/tf1/ssd_spaghettinet_edgetpu_320x320_coco17_sync_4x4.config).
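For reference, the `feature_extractor` block of that config selects one of the three released variants by name; the excerpt below mirrors the config added in this commit:

```
feature_extractor {
  type: 'ssd_spaghettinet'
  # One of: spaghettinet_edgetpu_s, spaghettinet_edgetpu_m, spaghettinet_edgetpu_l.
  spaghettinet_arch_name: 'spaghettinet_edgetpu_m'
  use_explicit_padding: false
}
```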
<b>Thanks to contributors</b>: Marie White, Hao Xu, Hanxiao Liu and Suyog Gupta.
### DeepMAC architecture
We have released our new architecture, **DeepMAC**, designed for partially
......
...@@ -93,6 +93,7 @@ if tf_version.is_tf1():
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetDSPFeatureExtractor
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetEdgeTPUFeatureExtractor
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetGPUFeatureExtractor
from object_detection.models.ssd_spaghettinet_feature_extractor import SSDSpaghettinetFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
from object_detection.predictors import rfcn_box_predictor
# pylint: enable=g-import-not-at-top
...@@ -229,6 +230,8 @@ if tf_version.is_tf1():
SSDMobileDetEdgeTPUFeatureExtractor,
'ssd_mobiledet_gpu':
SSDMobileDetGPUFeatureExtractor,
'ssd_spaghettinet':
SSDSpaghettinetFeatureExtractor,
}
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
...@@ -350,6 +353,12 @@ def _build_ssd_feature_extractor(feature_extractor_config,
})
if feature_extractor_config.HasField('spaghettinet_arch_name'):
kwargs.update({
'spaghettinet_arch_name':
feature_extractor_config.spaghettinet_arch_name,
})
if feature_extractor_config.HasField('fpn'):
kwargs.update({
'fpn_min_level':
......
# SpaghettiNet Feature Extractor optimized for EdgeTPU.
# Trained on COCO17 from scratch.
#
# spaghettinet_edgetpu_s
# Achieves 26.2% mAP on COCO17 at 400k steps.
# 1.31ms Edge TPU latency at 1 billion MACs, 3.4 million params.
#
# spaghettinet_edgetpu_m
# Achieves 27.4% mAP on COCO17 at 400k steps.
# 1.55ms Edge TPU latency at 1.25 billion MACs, 4.1 million params.
#
# spaghettinet_edgetpu_l
# Achieves 28.02% mAP on COCO17 at 400k steps.
# 1.75ms Edge TPU latency at 1.57 billion MACs, 5.7 million params.
#
# TPU-compatible.
model {
ssd {
inplace_batchnorm_update: true
freeze_batchnorm: false
num_classes: 90
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
use_matmul_gather: true
}
}
similarity_calculator {
iou_similarity {
}
}
encode_background_as_zeros: true
anchor_generator {
ssd_anchor_generator {
num_layers: 5
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333333
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
use_depthwise: true
box_code_size: 4
apply_sigmoid_to_scores: false
class_prediction_bias_init: -4.6
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
random_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.97,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_spaghettinet'
# 3 architectures are supported and performance for each is listed at the top of this config file.
#spaghettinet_arch_name: 'spaghettinet_edgetpu_s'
spaghettinet_arch_name: 'spaghettinet_edgetpu_m'
#spaghettinet_arch_name: 'spaghettinet_edgetpu_l'
use_explicit_padding: false
}
loss {
classification_loss {
weighted_sigmoid_focal {
alpha: 0.75,
gamma: 2.0
}
}
localization_loss {
weighted_smooth_l1 {
delta: 1.0
}
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
normalize_loc_loss_by_codesize: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
use_static_shapes: true
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 512
sync_replicas: true
startup_delay_steps: 0
replicas_to_aggregate: 32
num_steps: 400000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
optimizer {
momentum_optimizer: {
learning_rate: {
cosine_decay_learning_rate {
learning_rate_base: 0.8
total_steps: 400000
warmup_learning_rate: 0.13333
warmup_steps: 2000
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
max_number_of_boxes: 100
unpad_groundtruth_tensors: false
}
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/train2017-?????-of-00256.tfrecord"
}
}
eval_config: {
metrics_set: "coco_detection_metrics"
use_moving_averages: false
}
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
shuffle: false
num_epochs: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/val2017-?????-of-00032.tfrecord"
}
}
graph_rewriter {
quantization {
delay: 40000
weight_bits: 8
activation_bits: 8
}
}
...@@ -173,10 +173,19 @@ Model name
[faster_rcnn_resnet101_snapshot_serengeti](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_snapshot_serengeti_2020_06_10.tar.gz) | 38 | Boxes
[context_rcnn_resnet101_snapshot_serengeti](http://download.tensorflow.org/models/object_detection/context_rcnn_resnet101_snapshot_serengeti_2020_06_10.tar.gz) | 56 | Boxes
## Pixel 6 Edge TPU models
Model name | Pixel 6 Edge TPU Speed (ms) | Pixel 6 Speed with Post-processing on CPU (ms) | COCO 2017 mAP (uint8) | Outputs
----------------------------------------------------------------------------------------------------------------------------- | :-------------------------: | :--------------------------------------------: | :-------------------: | :-----:
[spaghettinet_edgetpu_s](http://download.tensorflow.org/models/object_detection/tf1/spaghettinet_edgetpu_s_2021_10_13.tar.gz) | 1.3 | 1.8 | 26.3 | Boxes
[spaghettinet_edgetpu_m](http://download.tensorflow.org/models/object_detection/tf1/spaghettinet_edgetpu_m_2021_10_13.tar.gz) | 1.4 | 1.9 | 27.4 | Boxes
[spaghettinet_edgetpu_l](http://download.tensorflow.org/models/object_detection/tf1/spaghettinet_edgetpu_l_2021_10_13.tar.gz) | 1.7 | 2.1 | 28.0 | Boxes
[^1]: See [MSCOCO evaluation protocol](http://cocodataset.org/#detections-eval).
The COCO mAP numbers, with the exception of the Pixel 6 Edge TPU models, are
evaluated on the COCO 14 minival set (note that our split is different from
COCO 17 Val). A full list of image ids used in our split can be found
[here](https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_minival_ids.txt).
[^2]: This is PASCAL mAP with a slightly different way of true positives
computation: see
......
"""SpaghettiNet Feature Extractor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.python.training import moving_averages
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.utils import ops
from object_detection.utils import shape_utils
IbnOp = collections.namedtuple(
'IbnOp', ['kernel_size', 'expansion_rate', 'stride', 'has_residual'])
SepConvOp = collections.namedtuple('SepConvOp',
['kernel_size', 'stride', 'has_residual'])
IbnFusedGrouped = collections.namedtuple(
'IbnFusedGrouped',
['kernel_size', 'expansion_rate', 'stride', 'groups', 'has_residual'])
SpaghettiStemNode = collections.namedtuple('SpaghettiStemNode',
['kernel_size', 'num_filters'])
SpaghettiNode = collections.namedtuple(
'SpaghettiNode', ['layers', 'num_filters', 'edges', 'level'])
SpaghettiResampleEdge = collections.namedtuple('SpaghettiResampleEdge',
['input'])
SpaghettiPassthroughEdge = collections.namedtuple('SpaghettiPassthroughEdge',
['input'])
SpaghettiNodeSpecs = collections.namedtuple('SpaghettiNodeSpecs',
['nodes', 'outputs'])
class SpaghettiNet():
"""SpaghettiNet."""
def __init__(self,
node_specs,
is_training=False,
use_native_resize_op=False,
use_explicit_padding=False,
activation_fn=tf.nn.relu6,
normalization_fn=slim.batch_norm,
name='spaghetti_node'):
self._node_specs = node_specs
self._is_training = is_training
self._use_native_resize_op = use_native_resize_op
self._use_explicit_padding = use_explicit_padding
self._activation_fn = activation_fn
self._normalization_fn = normalization_fn
self._name = name
self._nodes = {}
def _quant_var(self,
name,
initializer_val,
vars_collection=tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
"""Create an var for storing the min/max quantization range."""
return slim.model_variable(
name,
shape=[],
initializer=tf.constant_initializer(initializer_val),
collections=[vars_collection],
trainable=False)
def _quantizable_concat(self,
inputs,
axis,
is_training,
is_quantized=True,
default_min=0,
default_max=6,
ema_decay=0.999,
scope='quantized_concat'):
"""Concat replacement with quantization option.
Allows concat inputs to share the same min max ranges,
from experimental/gazelle/synthetic/model/tpu/utils.py.
Args:
inputs: list of tensors to concatenate.
axis: dimension along which to concatenate.
is_training: true if the graph is a training graph.
is_quantized: flag to enable/disable quantization.
default_min: default min value for fake quant op.
default_max: default max value for fake quant op.
ema_decay: the moving average decay for the quantization variables.
scope: Optional scope for variable_scope.
Returns:
Tensor resulting from concatenation of input tensors
"""
if is_quantized:
with tf.variable_scope(scope):
min_var = self._quant_var('min', default_min)
max_var = self._quant_var('max', default_max)
if not is_training:
# If we are building an eval graph just use the values in the
# variables.
quant_inputs = [
tf.fake_quant_with_min_max_vars(t, min_var, max_var)
for t in inputs
]
else:
concat_tensors = tf.concat(inputs, axis=axis)
tf.logging.info('concat_tensors: {}'.format(concat_tensors))
# TFLite requires that 0.0 is always in the [min; max] range.
range_min = tf.minimum(
tf.reduce_min(concat_tensors), 0.0, name='SafeQuantRangeMin')
range_max = tf.maximum(
tf.reduce_max(concat_tensors), 0.0, name='SafeQuantRangeMax')
# During training, track moving averages of the min and max of the
# input tensor's elements.
min_val = moving_averages.assign_moving_average(
min_var, range_min, ema_decay, name='AssignMinEma')
max_val = moving_averages.assign_moving_average(
max_var, range_max, ema_decay, name='AssignMaxEma')
quant_inputs = [
tf.fake_quant_with_min_max_vars(t, min_val, max_val)
for t in inputs
]
outputs = tf.concat(quant_inputs, axis=axis)
else:
outputs = tf.concat(inputs, axis=axis)
return outputs
def _expanded_conv(self, net, num_filters, expansion_rates, kernel_size,
stride, scope):
"""Expanded convolution."""
expanded_num_filters = num_filters * expansion_rates
add_fixed_padding = self._use_explicit_padding and stride > 1
padding = 'VALID' if add_fixed_padding else 'SAME'
net = slim.conv2d(
net,
expanded_num_filters, [1, 1],
activation_fn=self._activation_fn,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/expansion')
net = slim.separable_conv2d(
ops.fixed_padding(net, kernel_size) if add_fixed_padding else net,
num_outputs=None,
kernel_size=kernel_size,
activation_fn=self._activation_fn,
normalizer_fn=self._normalization_fn,
stride=stride,
padding=padding,
scope=scope + '/depthwise')
net = slim.conv2d(
net,
num_filters, [1, 1],
activation_fn=tf.identity,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/projection')
return net
def _slice_shape_along_axis(self, shape, axis, groups):
"""Returns the shape after slicing into groups along the axis."""
if isinstance(shape, tf.TensorShape):
shape_as_list = shape.as_list()
if shape_as_list[axis] % groups != 0:
raise ValueError('Dimension {} must be divisible by {} groups'.format(
shape_as_list[axis], groups))
shape_as_list[axis] = shape_as_list[axis] // groups
return tf.TensorShape(shape_as_list)
elif isinstance(shape, tf.Tensor) and shape.shape.rank == 1:
shape_as_list = tf.unstack(shape)
shape_as_list[axis] = shape_as_list[axis] // groups
return tf.stack(shape_as_list)
else:
raise ValueError(
'Shape should be a TensorShape or rank-1 Tensor, but got: {}'.format(
shape))
def _ibn_fused_grouped(self, net, num_filters, expansion_rates, kernel_size,
stride, groups, scope):
"""Fused grouped IBN convolution."""
add_fixed_padding = self._use_explicit_padding and stride > 1
padding = 'VALID' if add_fixed_padding else 'SAME'
slice_shape = self._slice_shape_along_axis(net.shape, -1, groups)
slice_begin = [0] * net.shape.rank
slice_outputs = []
output_filters_per_group = net.shape[-1] // groups
expanded_num_filters_per_group = output_filters_per_group * expansion_rates
for idx in range(groups):
slice_input = tf.slice(net, slice_begin, slice_shape)
if isinstance(slice_shape, tf.TensorShape):
slice_begin[-1] += slice_shape.as_list()[-1]
else:
slice_begin[-1] += slice_shape[-1]
slice_outputs.append(
slim.conv2d(
ops.fixed_padding(slice_input, kernel_size)
if add_fixed_padding else slice_input,
expanded_num_filters_per_group,
kernel_size,
activation_fn=self._activation_fn,
normalizer_fn=self._normalization_fn,
stride=stride,
padding=padding,
scope='{}/{}_{}'.format(scope, 'slice', idx)))
# Make inputs to the concat share the same quantization variables.
net = self._quantizable_concat(
slice_outputs,
-1,
self._is_training,
scope='{}/{}'.format(scope, 'concat'))
net = slim.conv2d(
net,
num_filters, [1, 1],
activation_fn=tf.identity,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/projection')
return net
def _sep_conv(self, net, num_filters, kernel_size, stride, scope):
"""Depthwise Separable convolution."""
add_fixed_padding = self._use_explicit_padding and stride > 1
padding = 'VALID' if add_fixed_padding else 'SAME'
net = slim.separable_conv2d(
ops.fixed_padding(net, kernel_size) if add_fixed_padding else net,
num_outputs=None,
kernel_size=kernel_size,
activation_fn=None,
normalizer_fn=None,
stride=stride,
padding=padding,
scope=scope + '/depthwise')
net = slim.conv2d(
net,
num_filters, [1, 1],
activation_fn=self._activation_fn,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/pointwise')
return net
def _upsample(self, net, num_filters, upsample_ratio, scope):
"""Perform 1x1 conv then nearest neighbor upsampling."""
node_pre_up = slim.conv2d(
net,
num_filters, [1, 1],
activation_fn=tf.identity,
normalizer_fn=self._normalization_fn,
padding='SAME',
scope=scope + '/1x1_before_upsample')
if self._use_native_resize_op:
with tf.name_scope(scope + '/nearest_neighbor_upsampling'):
input_shape = shape_utils.combined_static_and_dynamic_shape(node_pre_up)
node_up = tf.image.resize_nearest_neighbor(
node_pre_up,
[input_shape[1] * upsample_ratio, input_shape[2] * upsample_ratio])
else:
node_up = ops.nearest_neighbor_upsampling(
node_pre_up, scale=upsample_ratio)
return node_up
def _downsample(self, net, num_filters, downsample_ratio, scope):
"""Perform maxpool downsampling then 1x1 conv."""
add_fixed_padding = self._use_explicit_padding and downsample_ratio > 1
padding = 'VALID' if add_fixed_padding else 'SAME'
node_down = slim.max_pool2d(
ops.fixed_padding(net, downsample_ratio +
1) if add_fixed_padding else net,
[downsample_ratio + 1, downsample_ratio + 1],
stride=[downsample_ratio, downsample_ratio],
padding=padding,
scope=scope + '/maxpool_downsampling')
node_after_down = slim.conv2d(
node_down,
num_filters, [1, 1],
activation_fn=tf.identity,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/1x1_after_downsampling')
return node_after_down
def _no_resample(self, net, num_filters, scope):
return slim.conv2d(
net,
num_filters, [1, 1],
activation_fn=tf.identity,
normalizer_fn=self._normalization_fn,
padding='SAME',
scope=scope + '/1x1_no_resampling')
def _spaghetti_node(self, node, scope):
"""Spaghetti node."""
node_spec = self._node_specs.nodes[node]
# Make spaghetti edges
edge_outputs = []
edge_min_level = 100  # Sentinel; larger than any level used in the node specs.
edge_output_shape = None
for edge in node_spec.edges:
if isinstance(edge, SpaghettiPassthroughEdge):
assert len(node_spec.edges) == 1, len(node_spec.edges)
edge_outputs.append(self._nodes[edge.input])
elif isinstance(edge, SpaghettiResampleEdge):
edge_outputs.append(
self._spaghetti_edge(node, edge.input,
'edge_{}_{}'.format(edge.input, node)))
if edge_min_level > self._node_specs.nodes[edge.input].level:
edge_min_level = self._node_specs.nodes[edge.input].level
edge_output_shape = tf.shape(edge_outputs[-1])
else:
raise ValueError('Unknown edge type {}'.format(edge))
if len(edge_outputs) == 1:
# A single edge output means this node has a passthrough edge.
net = edge_outputs[-1]
else:
# With multiple edge outputs, crop them to a common spatial shape and sum.
net = edge_outputs[0][:, :edge_output_shape[1], :edge_output_shape[2], :]
for edge_output in edge_outputs[1:]:
net += edge_output[:, :edge_output_shape[1], :edge_output_shape[2], :]
net = self._activation_fn(net)
# Make spaghetti node
for idx, layer_spec in enumerate(node_spec.layers):
if isinstance(layer_spec, IbnOp):
net_exp = self._expanded_conv(net, node_spec.num_filters,
layer_spec.expansion_rate,
layer_spec.kernel_size, layer_spec.stride,
'{}_{}'.format(scope, idx))
elif isinstance(layer_spec, IbnFusedGrouped):
net_exp = self._ibn_fused_grouped(net, node_spec.num_filters,
layer_spec.expansion_rate,
layer_spec.kernel_size,
layer_spec.stride, layer_spec.groups,
'{}_{}'.format(scope, idx))
elif isinstance(layer_spec, SepConvOp):
net_exp = self._sep_conv(net, node_spec.num_filters,
layer_spec.kernel_size, layer_spec.stride,
'{}_{}'.format(scope, idx))
else:
raise ValueError('Unsupported layer_spec: {}'.format(layer_spec))
# Skip connection for all layers other than the first in a node.
net = net_exp + net if layer_spec.has_residual else net_exp
self._nodes[node] = net
def _spaghetti_edge(self, curr_node, prev_node, scope):
"""Create an edge between curr_node and prev_node."""
curr_spec = self._node_specs.nodes[curr_node]
prev_spec = self._node_specs.nodes[prev_node]
if curr_spec.level < prev_spec.level:
# upsample
output = self._upsample(self._nodes[prev_node], curr_spec.num_filters,
2**(prev_spec.level - curr_spec.level), scope)
elif curr_spec.level > prev_spec.level:
# downsample
output = self._downsample(self._nodes[prev_node], curr_spec.num_filters,
2**(curr_spec.level - prev_spec.level), scope)
else:
# 1x1
output = self._no_resample(self._nodes[prev_node], curr_spec.num_filters,
scope)
return output
def _spaghetti_stem_node(self, net, node, scope):
stem_spec = self._node_specs.nodes[node]
kernel_size = stem_spec.kernel_size
padding = 'VALID' if self._use_explicit_padding else 'SAME'
self._nodes[node] = slim.conv2d(
ops.fixed_padding(net, kernel_size)
if self._use_explicit_padding else net,
stem_spec.num_filters, [kernel_size, kernel_size],
stride=2,
activation_fn=self._activation_fn,
normalizer_fn=self._normalization_fn,
padding=padding,
scope=scope + '/stem')
def apply(self, net, scope='spaghetti_net'):
"""Apply the SpaghettiNet to the input and return nodes in outputs."""
for node, node_spec in self._node_specs.nodes.items():
if isinstance(node_spec, SpaghettiStemNode):
self._spaghetti_stem_node(net, node, '{}/stem_node'.format(scope))
elif isinstance(node_spec, SpaghettiNode):
self._spaghetti_node(node, '{}/{}'.format(scope, node))
else:
raise ValueError('Unknown node {}: {}'.format(node, node_spec))
return [self._nodes[x] for x in self._node_specs.outputs]
def _spaghettinet_edgetpu_s():
"""Architecture definition for SpaghettiNet-EdgeTPU-S."""
nodes = collections.OrderedDict()
outputs = ['c0n1', 'c0n2', 'c0n3', 'c0n4', 'c0n5']
nodes['s0'] = SpaghettiStemNode(kernel_size=5, num_filters=24)
nodes['n0'] = SpaghettiNode(
num_filters=48,
level=2,
layers=[
IbnFusedGrouped(3, 8, 2, 3, False),
],
edges=[SpaghettiPassthroughEdge(input='s0')])
nodes['n1'] = SpaghettiNode(
num_filters=64,
level=3,
layers=[
IbnFusedGrouped(3, 4, 2, 4, False),
IbnFusedGrouped(3, 4, 1, 4, True),
IbnFusedGrouped(3, 4, 1, 4, True),
],
edges=[SpaghettiPassthroughEdge(input='n0')])
nodes['n2'] = SpaghettiNode(
num_filters=72,
level=4,
layers=[
IbnOp(3, 8, 2, False),
IbnFusedGrouped(3, 8, 1, 4, True),
IbnOp(3, 8, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n1')])
nodes['n3'] = SpaghettiNode(
num_filters=88,
level=5,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(3, 8, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n2')])
nodes['n4'] = SpaghettiNode(
num_filters=88,
level=6,
layers=[
IbnOp(3, 8, 2, False),
SepConvOp(5, 1, True),
SepConvOp(5, 1, True),
SepConvOp(5, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n3')])
nodes['n5'] = SpaghettiNode(
num_filters=88,
level=7,
layers=[
SepConvOp(5, 2, False),
SepConvOp(3, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n4')])
nodes['c0n0'] = SpaghettiNode(
num_filters=144,
level=5,
layers=[
IbnOp(3, 4, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n3'),
SpaghettiResampleEdge(input='n4')
])
nodes['c0n1'] = SpaghettiNode(
num_filters=120,
level=4,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n2'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n2'] = SpaghettiNode(
num_filters=168,
level=5,
layers=[
IbnOp(3, 4, 1, False),
],
edges=[
SpaghettiResampleEdge(input='c0n1'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n3'] = SpaghettiNode(
num_filters=136,
level=6,
layers=[
IbnOp(3, 4, 1, False),
SepConvOp(3, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n4'] = SpaghettiNode(
num_filters=136,
level=7,
layers=[
IbnOp(3, 4, 1, False),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n5'] = SpaghettiNode(
num_filters=64,
level=8,
layers=[
SepConvOp(3, 1, False),
SepConvOp(3, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='c0n4')])
node_specs = SpaghettiNodeSpecs(nodes=nodes, outputs=outputs)
return node_specs
def _spaghettinet_edgetpu_m():
"""Architecture definition for SpaghettiNet-EdgeTPU-M."""
nodes = collections.OrderedDict()
outputs = ['c0n1', 'c0n2', 'c0n3', 'c0n4', 'c0n5']
nodes['s0'] = SpaghettiStemNode(kernel_size=5, num_filters=24)
nodes['n0'] = SpaghettiNode(
num_filters=48,
level=2,
layers=[
IbnFusedGrouped(3, 8, 2, 3, False),
],
edges=[SpaghettiPassthroughEdge(input='s0')])
nodes['n1'] = SpaghettiNode(
num_filters=64,
level=3,
layers=[
IbnFusedGrouped(3, 8, 2, 4, False),
IbnFusedGrouped(3, 4, 1, 4, True),
IbnFusedGrouped(3, 4, 1, 4, True),
IbnFusedGrouped(3, 4, 1, 4, True),
],
edges=[SpaghettiPassthroughEdge(input='n0')])
nodes['n2'] = SpaghettiNode(
num_filters=72,
level=4,
layers=[
IbnOp(3, 8, 2, False),
IbnFusedGrouped(3, 8, 1, 4, True),
IbnOp(3, 8, 1, True),
IbnOp(3, 8, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n1')])
nodes['n3'] = SpaghettiNode(
num_filters=96,
level=5,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(3, 8, 1, True),
IbnOp(3, 8, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n2')])
nodes['n4'] = SpaghettiNode(
num_filters=104,
level=6,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(3, 4, 1, True),
SepConvOp(5, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n3')])
nodes['n5'] = SpaghettiNode(
num_filters=56,
level=7,
layers=[
SepConvOp(5, 2, False),
SepConvOp(3, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n4')])
nodes['c0n0'] = SpaghettiNode(
num_filters=152,
level=5,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n3'),
SpaghettiResampleEdge(input='n4')
])
nodes['c0n1'] = SpaghettiNode(
num_filters=120,
level=4,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n2'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n2'] = SpaghettiNode(
num_filters=168,
level=5,
layers=[
IbnOp(3, 4, 1, False),
SepConvOp(3, 1, True),
],
edges=[
SpaghettiResampleEdge(input='c0n1'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n3'] = SpaghettiNode(
num_filters=136,
level=6,
layers=[
SepConvOp(3, 1, False),
SepConvOp(3, 1, True),
SepConvOp(3, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n4'] = SpaghettiNode(
num_filters=136,
level=7,
layers=[
IbnOp(3, 4, 1, False),
SepConvOp(5, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n5'] = SpaghettiNode(
num_filters=64,
level=8,
layers=[
SepConvOp(3, 1, False),
SepConvOp(3, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='c0n4')])
node_specs = SpaghettiNodeSpecs(nodes=nodes, outputs=outputs)
return node_specs
def _spaghettinet_edgetpu_l():
"""Architecture definition for SpaghettiNet-EdgeTPU-L."""
nodes = collections.OrderedDict()
outputs = ['c0n1', 'c0n2', 'c0n3', 'c0n4', 'c0n5']
nodes['s0'] = SpaghettiStemNode(kernel_size=5, num_filters=24)
nodes['n0'] = SpaghettiNode(
num_filters=48,
level=2,
layers=[
IbnFusedGrouped(3, 8, 2, 3, False),
],
edges=[SpaghettiPassthroughEdge(input='s0')])
nodes['n1'] = SpaghettiNode(
num_filters=64,
level=3,
layers=[
IbnFusedGrouped(3, 8, 2, 4, False),
IbnFusedGrouped(3, 8, 1, 4, True),
IbnFusedGrouped(3, 8, 1, 4, True),
IbnFusedGrouped(3, 4, 1, 4, True),
],
edges=[SpaghettiPassthroughEdge(input='n0')])
nodes['n2'] = SpaghettiNode(
num_filters=80,
level=4,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(3, 8, 1, True),
IbnOp(3, 8, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n1')])
nodes['n3'] = SpaghettiNode(
num_filters=104,
level=5,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(3, 8, 1, True),
IbnOp(3, 8, 1, True),
IbnOp(3, 8, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n2')])
nodes['n4'] = SpaghettiNode(
num_filters=88,
level=6,
layers=[
IbnOp(3, 8, 2, False),
IbnOp(5, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 8, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n3')])
nodes['n5'] = SpaghettiNode(
num_filters=56,
level=7,
layers=[
IbnOp(5, 4, 2, False),
SepConvOp(5, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='n4')])
nodes['c0n0'] = SpaghettiNode(
num_filters=160,
level=5,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n3'),
SpaghettiResampleEdge(input='n4')
])
nodes['c0n1'] = SpaghettiNode(
num_filters=120,
level=4,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 8, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n2'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n2'] = SpaghettiNode(
num_filters=168,
level=5,
layers=[
IbnOp(3, 4, 1, False),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='c0n1'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n3'] = SpaghettiNode(
num_filters=112,
level=6,
layers=[
IbnOp(3, 8, 1, False),
IbnOp(3, 4, 1, True),
SepConvOp(3, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n4'] = SpaghettiNode(
num_filters=128,
level=7,
layers=[
IbnOp(3, 4, 1, False),
IbnOp(3, 4, 1, True),
],
edges=[
SpaghettiResampleEdge(input='n5'),
SpaghettiResampleEdge(input='c0n0')
])
nodes['c0n5'] = SpaghettiNode(
num_filters=64,
level=8,
layers=[
SepConvOp(5, 1, False),
SepConvOp(5, 1, True),
],
edges=[SpaghettiPassthroughEdge(input='c0n4')])
node_specs = SpaghettiNodeSpecs(nodes=nodes, outputs=outputs)
return node_specs
def lookup_spaghetti_arch(arch):
"""Lookup table for the nodes structure for spaghetti nets."""
if arch == 'spaghettinet_edgetpu_s':
return _spaghettinet_edgetpu_s()
elif arch == 'spaghettinet_edgetpu_m':
return _spaghettinet_edgetpu_m()
elif arch == 'spaghettinet_edgetpu_l':
return _spaghettinet_edgetpu_l()
else:
raise ValueError('Unknown architecture {}'.format(arch))
class SSDSpaghettinetFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
"""SSD Feature Extractor using Custom Architecture."""
def __init__(
self,
is_training,
depth_multiplier,
min_depth,
pad_to_multiple,
conv_hyperparams_fn,
spaghettinet_arch_name='spaghettinet_edgetpu_m',
use_explicit_padding=False,
reuse_weights=False,
use_depthwise=False,
override_base_feature_extractor_hyperparams=False,
):
"""SSD FPN feature extractor based on Mobilenet v2 architecture.
Args:
is_training: whether the network is in training mode.
depth_multiplier: Not used in SpaghettiNet.
min_depth: Not used in SpaghettiNet.
pad_to_multiple: Not used in SpaghettiNet.
conv_hyperparams_fn: Not used in SpaghettiNet.
spaghettinet_arch_name: name of the specific architecture.
use_explicit_padding: Whether to use explicit padding when extracting
features. Default is False.
reuse_weights: Not used in SpaghettiNet.
use_depthwise: Not used in SpaghettiNet.
override_base_feature_extractor_hyperparams: Not used in SpaghettiNet.
"""
super(SSDSpaghettinetFeatureExtractor, self).__init__(
is_training=is_training,
use_explicit_padding=use_explicit_padding,
depth_multiplier=depth_multiplier,
min_depth=min_depth,
pad_to_multiple=pad_to_multiple,
conv_hyperparams_fn=conv_hyperparams_fn,
reuse_weights=reuse_weights,
use_depthwise=use_depthwise,
override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams
)
self._spaghettinet_arch_name = spaghettinet_arch_name
self._use_native_resize_op = not is_training
def preprocess(self, resized_inputs):
"""SSD preprocessing.
Maps pixel values to the range [-1, 1].
Args:
resized_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
"""
return (2.0 / 255.0) * resized_inputs - 1.0
def extract_features(self, preprocessed_inputs):
"""Extract features from preprocessed inputs.
Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i]
"""
preprocessed_inputs = shape_utils.check_min_image_dim(
33, preprocessed_inputs)
nodes_dict = lookup_spaghetti_arch(self._spaghettinet_arch_name)
with tf.variable_scope(
self._spaghettinet_arch_name, reuse=self._reuse_weights):
with slim.arg_scope([slim.conv2d],
weights_initializer=tf.truncated_normal_initializer(
mean=0.0, stddev=0.03),
weights_regularizer=slim.l2_regularizer(1e-5)):
with slim.arg_scope([slim.separable_conv2d],
weights_initializer=tf.truncated_normal_initializer(
mean=0.0, stddev=0.03),
weights_regularizer=slim.l2_regularizer(1e-5)):
with slim.arg_scope([slim.batch_norm],
is_training=self._is_training,
epsilon=0.001,
decay=0.97,
center=True,
scale=True):
spaghetti_net = SpaghettiNet(
node_specs=nodes_dict,
is_training=self._is_training,
use_native_resize_op=self._use_native_resize_op,
use_explicit_padding=self._use_explicit_padding,
name=self._spaghettinet_arch_name)
feature_maps = spaghetti_net.apply(preprocessed_inputs)
return feature_maps
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ssd_spaghettinet_feature_extractor."""
import unittest
import tensorflow.compat.v1 as tf
from object_detection.models import ssd_feature_extractor_test
from object_detection.models import ssd_spaghettinet_feature_extractor
from object_detection.utils import tf_version
try:
from tensorflow.contrib import quantize as contrib_quantize # pylint: disable=g-import-not-at-top
except: # pylint: disable=bare-except
pass
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class SSDSpaghettiNetFeatureExtractorTest(
ssd_feature_extractor_test.SsdFeatureExtractorTestBase):
def _create_feature_extractor(self, arch_name, is_training=True):
return ssd_spaghettinet_feature_extractor.SSDSpaghettinetFeatureExtractor(
is_training=is_training,
spaghettinet_arch_name=arch_name,
depth_multiplier=1.0,
min_depth=4,
pad_to_multiple=1,
conv_hyperparams_fn=self.conv_hyperparams_fn)
def _test_spaghettinet_returns_correct_shapes(self, arch_name,
expected_feature_map_shapes):
image = tf.random.normal((1, 320, 320, 3))
feature_extractor = self._create_feature_extractor(arch_name)
feature_maps = feature_extractor.extract_features(image)
self.assertEqual(len(expected_feature_map_shapes), len(feature_maps))
for expected_shape, x in zip(expected_feature_map_shapes, feature_maps):
self.assertTrue(x.shape.is_compatible_with(expected_shape))
def test_spaghettinet_edgetpu_s(self):
expected_feature_map_shapes = [(1, 20, 20, 120), (1, 10, 10, 168),
(1, 5, 5, 136), (1, 3, 3, 136),
(1, 3, 3, 64)]
self._test_spaghettinet_returns_correct_shapes('spaghettinet_edgetpu_s',
expected_feature_map_shapes)
def test_spaghettinet_edgetpu_m(self):
expected_feature_map_shapes = [(1, 20, 20, 120), (1, 10, 10, 168),
(1, 5, 5, 136), (1, 3, 3, 136),
(1, 3, 3, 64)]
self._test_spaghettinet_returns_correct_shapes('spaghettinet_edgetpu_m',
expected_feature_map_shapes)
def test_spaghettinet_edgetpu_l(self):
expected_feature_map_shapes = [(1, 20, 20, 120), (1, 10, 10, 168),
(1, 5, 5, 112), (1, 3, 3, 128),
(1, 3, 3, 64)]
self._test_spaghettinet_returns_correct_shapes('spaghettinet_edgetpu_l',
expected_feature_map_shapes)
def _check_quantization(self, model_fn):
checkpoint_dir = self.get_temp_dir()
with tf.Graph().as_default() as training_graph:
model_fn(is_training=True)
contrib_quantize.experimental_create_training_graph(training_graph)
with self.session(graph=training_graph) as sess:
sess.run(tf.global_variables_initializer())
tf.train.Saver().save(sess, checkpoint_dir)
with tf.Graph().as_default() as eval_graph:
model_fn(is_training=False)
contrib_quantize.experimental_create_eval_graph(eval_graph)
with self.session(graph=eval_graph) as sess:
tf.train.Saver().restore(sess, checkpoint_dir)
def _test_spaghettinet_quantization(self, arch_name):
def model_fn(is_training):
image = tf.random.normal((1, 320, 320, 3))
feature_extractor = self._create_feature_extractor(
arch_name, is_training=is_training)
feature_extractor.extract_features(image)
self._check_quantization(model_fn)
def test_spaghettinet_edgetpu_s_quantization(self):
self._test_spaghettinet_quantization('spaghettinet_edgetpu_s')
def test_spaghettinet_edgetpu_m_quantization(self):
self._test_spaghettinet_quantization('spaghettinet_edgetpu_m')
def test_spaghettinet_edgetpu_l_quantization(self):
self._test_spaghettinet_quantization('spaghettinet_edgetpu_l')
if __name__ == '__main__':
tf.test.main()
...@@ -5,13 +5,13 @@ package object_detection.protos;
import "object_detection/protos/anchor_generator.proto";
import "object_detection/protos/box_coder.proto";
import "object_detection/protos/box_predictor.proto";
import "object_detection/protos/fpn.proto";
import "object_detection/protos/hyperparams.proto"; import "object_detection/protos/hyperparams.proto";
import "object_detection/protos/image_resizer.proto"; import "object_detection/protos/image_resizer.proto";
import "object_detection/protos/losses.proto"; import "object_detection/protos/losses.proto";
import "object_detection/protos/matcher.proto"; import "object_detection/protos/matcher.proto";
import "object_detection/protos/post_processing.proto"; import "object_detection/protos/post_processing.proto";
import "object_detection/protos/region_similarity_calculator.proto"; import "object_detection/protos/region_similarity_calculator.proto";
import "object_detection/protos/fpn.proto";
// Configuration for Single Shot Detection (SSD) models.
// Next id: 27
...@@ -146,7 +146,7 @@ message Ssd {
optional MaskHead mask_head_config = 25;
}
// Next id: 21.
message SsdFeatureExtractor {
reserved 6;
...@@ -202,5 +202,8 @@ message SsdFeatureExtractor {
// The number of SSD layers.
optional int32 num_layers = 12 [default = 6];
// The SpaghettiNet architecture name.
optional string spaghettinet_arch_name = 20;
}