Unverified commit ca552843 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.projects.basnet.evaluation import metrics
class BASNetMetricTest(parameterized.TestCase, tf.test.TestCase):
def test_mae(self):
input_size = 224
inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
labels = (tf.random.uniform([2, input_size, input_size, 1]),)
mae_obj = metrics.MAE()
mae_obj.reset_states()
mae_obj.update_state(labels, inputs)
output = mae_obj.result()
mae_tf = tf.keras.metrics.MeanAbsoluteError()
mae_tf.reset_state()
mae_tf.update_state(labels[0], inputs[0])
compare = mae_tf.result().numpy()
self.assertAlmostEqual(output, compare, places=4)
def test_max_f(self):
input_size = 224
beta = 0.3
inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
labels = (tf.random.uniform([2, input_size, input_size, 1]),)
max_f_obj = metrics.MaxFscore()
max_f_obj.reset_states()
max_f_obj.update_state(labels, inputs)
output = max_f_obj.result()
pre_tf = tf.keras.metrics.Precision(thresholds=0.78)
rec_tf = tf.keras.metrics.Recall(thresholds=0.78)
pre_tf.reset_state()
rec_tf.reset_state()
pre_tf.update_state(labels[0], inputs[0])
rec_tf.update_state(labels[0], inputs[0])
pre_out_tf = pre_tf.result().numpy()
rec_out_tf = rec_tf.result().numpy()
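    # Note: following the BASNet reference, `beta` here plays the role of
    # beta**2 in the F-measure (beta**2 = 0.3), so the formula below uses
    # (1 + beta) and beta directly.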
compare = (1+beta)*pre_out_tf*rec_out_tf/(beta*pre_out_tf+rec_out_tf+1e-8)
self.assertAlmostEqual(output, compare, places=1)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for BASNet models."""
import tensorflow as tf
EPSILON = 1e-5
class BASNetLoss:
"""BASNet hybrid loss."""
def __init__(self):
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=False)
self._ssim = tf.image.ssim
def __call__(self, sigmoids, labels):
levels = sorted(sigmoids.keys())
labels_bce = tf.squeeze(labels, axis=-1)
labels = tf.cast(labels, tf.float32)
bce_losses = []
ssim_losses = []
iou_losses = []
for level in levels:
bce_losses.append(
self._binary_crossentropy(labels_bce, sigmoids[level]))
ssim_losses.append(
1 - self._ssim(sigmoids[level], labels, max_val=1.0))
iou_losses.append(
self._iou_loss(sigmoids[level], labels))
total_bce_loss = tf.math.add_n(bce_losses)
total_ssim_loss = tf.math.add_n(ssim_losses)
total_iou_loss = tf.math.add_n(iou_losses)
total_loss = total_bce_loss + total_ssim_loss + total_iou_loss
total_loss = total_loss / len(levels)
return total_loss
  def _iou_loss(self, sigmoids, labels):
    intersection = tf.reduce_sum(sigmoids * labels)
    union = tf.reduce_sum(sigmoids) + tf.reduce_sum(labels) - intersection
    return 1.0 - intersection / union
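# A minimal usage sketch (shapes assumed to match the tests in this change:
# per-level sigmoid outputs and labels of shape [batch, height, width, 1]):
#
#   loss_fn = BASNetLoss()
#   sigmoids = {'0': tf.random.uniform([2, 224, 224, 1]),
#               '1': tf.random.uniform([2, 224, 224, 1])}
#   labels = tf.cast(tf.random.uniform([2, 224, 224, 1]) > 0.5, tf.float32)
#   total_loss = loss_fn(sigmoids, labels)  # scalar, averaged over levels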
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build BASNet models."""
from typing import Mapping
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.basnet.modeling import nn_blocks
from official.vision.beta.modeling.backbones import factory
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
(64, 1, 3, 0), # ResNet-34,
(128, 2, 4, 0), # ResNet-34,
(256, 2, 6, 0), # ResNet-34,
(512, 2, 3, 1), # ResNet-34,
(512, 1, 3, 1), # BASNet,
(512, 1, 3, 0), # BASNet,
]
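# For example, the spec (128, 2, 4, 0) denotes a group of 4 residual blocks
# with 128 filters, stride 2 in the first block, and no max-pool afterwards.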
# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
BASNET_BRIDGE_SPECS = [
(512, 2, 512, 2, 512, 2, 32), # Sup0, Bridge
]
BASNET_DECODER_SPECS = [
(512, 1, 512, 2, 512, 2, 32), # Sup1, stage6d
(512, 1, 512, 1, 512, 1, 16), # Sup2, stage5d
(512, 1, 512, 1, 256, 1, 8), # Sup3, stage4d
(256, 1, 256, 1, 128, 1, 4), # Sup4, stage3d
(128, 1, 128, 1, 64, 1, 2), # Sup5, stage2d
(64, 1, 64, 1, 64, 1, 1) # Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
"""A BASNet model.
  Boundary-Aware Network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
  Input images are first passed through the backbone. The decoder network is
  then applied, and finally the refinement module is applied to the output of
  the decoder network.
"""
def __init__(self,
backbone,
decoder,
refinement=None,
**kwargs):
"""BASNet initialization function.
Args:
backbone: a backbone network. basnet_encoder.
decoder: a decoder network. basnet_decoder.
refinement: a module for salient map refinement.
**kwargs: keyword arguments to be passed.
"""
super(BASNetModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'decoder': decoder,
'refinement': refinement,
}
self.backbone = backbone
self.decoder = decoder
self.refinement = refinement
def call(self, inputs, training=None):
features = self.backbone(inputs)
if self.decoder:
features = self.decoder(features)
levels = sorted(features.keys())
new_key = str(len(levels))
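    # E.g. with decoder outputs keyed '0'..'6', the refined map is added
    # under the new key '7'.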
if self.refinement:
features[new_key] = self.refinement(features[levels[-1]])
return features
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(backbone=self.backbone)
if self.decoder is not None:
items.update(decoder=self.decoder)
if self.refinement is not None:
items.update(refinement=self.refinement)
return items
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetEncoder(tf.keras.Model):
"""BASNet encoder."""
def __init__(
self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet encoder initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
# Build BASNet Encoder.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
endpoints = {}
for i, spec in enumerate(BASNET_ENCODER_SPECS):
x = self._block_group(
inputs=x,
filters=spec[0],
strides=spec[1],
block_repeats=spec[2],
name='block_group_l{}'.format(i + 2))
endpoints[str(i)] = x
if spec[3]:
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(BASNetEncoder, self).__init__(
inputs=inputs, outputs=endpoints, **kwargs)
def _block_group(self,
inputs,
filters,
strides,
block_repeats=1,
name='block_group'):
"""Creates one group of residual blocks for the BASNet encoder model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_repeats: `int` number of blocks contained in the layer.
      name: `str` name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x = nn_blocks.ResBlock(
filters=filters,
strides=strides,
use_projection=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = nn_blocks.ResBlock(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds BASNet Encoder backbone from a config."""
backbone_type = model_config.backbone.type
norm_activation_config = model_config.norm_activation
assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
f'{backbone_type}')
return BASNetEncoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=norm_activation_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetDecoder(tf.keras.layers.Layer):
"""BASNet decoder."""
def __init__(self,
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet decoder initialization function.
Args:
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super(BASNetDecoder, self).__init__(**kwargs)
self._config_dict = {
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
self._activation = tf_utils.get_activation(activation)
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._out_convs = []
self._out_usmps = []
# Bridge layers.
self._bdg_convs = []
for spec in BASNET_BRIDGE_SPECS:
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._bdg_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
# Decoder layers.
self._dec_convs = []
for spec in BASNET_DECODER_SPECS:
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._dec_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
def call(self, backbone_output: Mapping[str, tf.Tensor]):
"""Forward pass of the BASNet decoder.
Args:
backbone_output: A `dict` of tensors
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
Returns:
sup: A `dict` of tensors
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
"""
levels = sorted(backbone_output.keys(), reverse=True)
sup = {}
x = backbone_output[levels[0]]
for blocks in self._bdg_convs:
for block in blocks:
x = block(x)
sup['0'] = x
for i, blocks in enumerate(self._dec_convs):
x = self._concat([x, backbone_output[levels[i]]])
for block in blocks:
x = block(x)
sup[str(i+1)] = x
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
self._output_specs = {
str(order): sup[str(order)].get_shape()
for order in range(0, len(BASNET_DECODER_SPECS))
}
return sup
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {order: TensorShape} pairs for the model output."""
return self._output_specs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(256),
(512),
)
def test_basnet_network_creation(
self, input_size):
"""Test for creation of a segmentation network."""
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = basnet_model.BASNetEncoder()
decoder = basnet_model.BASNetDecoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
backbone=backbone,
decoder=decoder,
refinement=refinement
)
sigmoids = model(inputs)
levels = sorted(sigmoids.keys())
self.assertAllEqual(
[2, input_size, input_size, 1],
sigmoids[levels[-1]].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized."""
backbone = basnet_model.BASNetEncoder()
decoder = basnet_model.BASNetDecoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
backbone=backbone,
decoder=decoder,
refinement=refinement
)
config = model.get_config()
new_model = basnet_model.BASNetModel.from_config(config)
# Validate that the config can be forced to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for BasNet model."""
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf.keras.layers.Layer):
"""A (Conv+BN+Activation) block."""
def __init__(self,
filters,
strides,
dilation_rate=1,
kernel_size=3,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_bias=False,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A vgg block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
dilation_rate: `int`, dilation rate for conv layers.
kernel_size: `int`, kernel size of conv layers.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_bias: `bool`, whether or not use bias in conv layers.
use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ConvBlock, self).__init__(**kwargs)
self._config_dict = {
'filters': filters,
'kernel_size': kernel_size,
'strides': strides,
'dilation_rate': dilation_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
conv_kwargs = {
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._conv0 = tf.keras.layers.Conv2D(
filters=self._config_dict['filters'],
kernel_size=self._config_dict['kernel_size'],
strides=self._config_dict['strides'],
dilation_rate=self._config_dict['dilation_rate'],
**conv_kwargs)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ConvBlock, self).build(input_shape)
def get_config(self):
return self._config_dict
def call(self, inputs, training=None):
x = self._conv0(inputs)
x = self._norm0(x)
x = self._activation_fn(x)
return x
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf.keras.layers.Layer):
"""A residual block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
use_bias=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""Initializes a residual block with BN after convolutions.
Args:
      filters: An `int` number of filters for the two convolutions (and for
        the projection shortcut, if used).
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
use_projection: A `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
Default to None.
activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
use_bias: A `bool`. If True, use bias in conv2d.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
super(ResBlock, self).__init__(**kwargs)
self._config_dict = {
'filters': filters,
'strides': strides,
'use_projection': use_projection,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
conv_kwargs = {
'filters': self._config_dict['filters'],
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
if self._config_dict['use_projection']:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._config_dict['filters'],
kernel_size=1,
strides=self._config_dict['strides'],
use_bias=self._config_dict['use_bias'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv1 = tf.keras.layers.Conv2D(
kernel_size=3,
strides=self._config_dict['strides'],
**conv_kwargs)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv2 = tf.keras.layers.Conv2D(
kernel_size=3,
strides=1,
**conv_kwargs)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ResBlock, self).build(input_shape)
def get_config(self):
return self._config_dict
def call(self, inputs, training=None):
shortcut = inputs
if self._config_dict['use_projection']:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
return self._activation_fn(x + shortcut)
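# A minimal usage sketch (shapes are illustrative assumptions):
#
#   block = ResBlock(filters=64, strides=1, use_projection=True)
#   y = block(tf.random.uniform([2, 56, 56, 64]))  # -> shape [2, 56, 56, 64]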
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RefUNet model."""
import tensorflow as tf
from official.projects.basnet.modeling import nn_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet.
  Boundary-Aware Network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""Residual Refinement Module of BASNet.
Args:
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
super(RefUnet, self).__init__(**kwargs)
self._config_dict = {
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
self._maxpool = tf.keras.layers.MaxPool2D(
pool_size=2,
strides=2,
padding='valid')
self._upsample = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._in_conv = conv_op(
filters=64,
padding='same',
**conv_kwargs)
self._en_convs = []
for _ in range(4):
self._en_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._bridge_convs = []
for _ in range(1):
self._bridge_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._de_convs = []
for _ in range(4):
self._de_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._out_conv = conv_op(
filters=1,
padding='same',
**conv_kwargs)
def call(self, inputs):
endpoints = {}
residual = inputs
x = self._in_conv(inputs)
# Top-down
for i, block in enumerate(self._en_convs):
x = block(x)
endpoints[str(i)] = x
x = self._maxpool(x)
# Bridge
for i, block in enumerate(self._bridge_convs):
x = block(x)
# Bottom-up
for i, block in enumerate(self._de_convs):
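      # The bilinear upsampling is done in float32 and cast back afterwards,
      # keeping the resize numerically stable under mixed precision.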
dtype = x.dtype
x = tf.cast(x, tf.float32)
x = self._upsample(x)
x = tf.cast(x, dtype)
x = self._concat([endpoints[str(3-i)], x])
x = block(x)
x = self._out_conv(x)
residual = tf.cast(residual, dtype=x.dtype)
output = self._sigmoid(x + residual)
self._output_specs = output.get_shape()
return output
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
return self._output_specs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export module for BASNet."""
import tensorflow as tf
from official.projects.basnet.tasks import basnet
from official.vision.beta.serving import semantic_segmentation
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class BASNetModule(semantic_segmentation.SegmentationModule):
"""BASNet Module."""
def _build_model(self):
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
return basnet.build_basnet_model(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def serve(self, images):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
      A dictionary holding the resized predicted masks under `predicted_masks`.
"""
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs, elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32
)
)
masks = self.inference_step(images)
keys = sorted(masks.keys())
output = tf.image.resize(
masks[keys[-1]],
self._input_image_size, method='bilinear')
return dict(predicted_masks=output)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Export binary for BASNet.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.basnet.serving import basnet
from official.vision.beta.serving import export_saved_model_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_saved_model_lib.export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
export_module=basnet.BASNetModule(
params=params,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')]),
export_checkpoint_subdir='checkpoint',
export_saved_model_subdir='saved_model')
if __name__ == '__main__':
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet task definition."""
from typing import Optional
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.basnet.configs import basnet as exp_cfg
from official.projects.basnet.evaluation import metrics as basnet_metrics
from official.projects.basnet.losses import basnet_losses
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
from official.vision.beta.dataloaders import segmentation_input
def build_basnet_model(
input_specs: tf.keras.layers.InputSpec,
model_config: exp_cfg.BASNetModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds BASNet model."""
norm_activation_config = model_config.norm_activation
backbone = basnet_model.BASNetEncoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
decoder = basnet_model.BASNetDecoder(
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
refinement = refunet.RefUnet(
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model = basnet_model.BASNetModel(backbone, decoder, refinement)
return model
@task_factory.register_task_cls(exp_cfg.BASNetTask)
class BASNetTask(base_task.Task):
"""A task for basnet."""
def build_model(self):
"""Builds basnet model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
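    # Equivalence sketch: for a weight tensor w,
    #   tf.keras.regularizers.l2(l)(w) == l * tf.reduce_sum(tf.square(w))
    #   tf.nn.l2_loss(w)               == tf.reduce_sum(tf.square(w)) / 2
    # so passing l2_weight_decay / 2.0 makes the two penalties match.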
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = build_basnet_model(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds BASNet input."""
ignore_label = self.task_config.losses.ignore_label
decoder = segmentation_input.Decoder()
parser = segmentation_input.Parser(
output_size=params.output_size,
crop_size=params.crop_size,
ignore_label=ignore_label,
aug_rand_hflip=params.aug_rand_hflip,
dtype=params.dtype)
reader = input_reader.InputReader(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self, label, model_outputs, aux_losses=None):
"""Hybrid loss proposed in BASNet.
Args:
label: label.
model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
basnet_loss_fn = basnet_losses.BASNetLoss()
total_loss = basnet_loss_fn(model_outputs, label['masks'])
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_metrics(self, training=False):
"""Gets streaming metrics for training/validation."""
evaluations = []
if training:
evaluations = []
else:
self.mae_metric = basnet_metrics.MAE()
self.maxf_metric = basnet_metrics.MaxFscore()
self.relaxf_metric = basnet_metrics.RelaxedFscore()
return evaluations
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
      # Casting the model output to float32 is necessary when the mixed
      # precision policy is mixed_float16 or mixed_bfloat16, so that the loss
      # is computed in float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs, label=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
# Apply gradient clipping.
if self.task_config.gradient_clip_norm > 0:
grads, _ = tf.clip_by_global_norm(
grads, self.task_config.gradient_clip_norm)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = 0
logs = {self.loss: loss}
levels = sorted(outputs.keys())
logs.update(
{self.mae_metric.name: (labels['masks'], outputs[levels[-1]])})
logs.update(
{self.maxf_metric.name: (labels['masks'], outputs[levels[-1]])})
logs.update(
{self.relaxf_metric.name: (labels['masks'], outputs[levels[-1]])})
return logs
def inference_step(self, inputs, model):
"""Performs the forward step."""
return model(inputs, training=False)
def aggregate_logs(self, state=None, step_outputs=None):
if state is None:
self.mae_metric.reset_states()
self.maxf_metric.reset_states()
self.relaxf_metric.reset_states()
state = self.mae_metric
self.mae_metric.update_state(
step_outputs[self.mae_metric.name][0],
step_outputs[self.mae_metric.name][1])
self.maxf_metric.update_state(
step_outputs[self.maxf_metric.name][0],
step_outputs[self.maxf_metric.name][1])
self.relaxf_metric.update_state(
step_outputs[self.relaxf_metric.name][0],
step_outputs[self.relaxf_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
result = {}
result['MAE'] = self.mae_metric.result()
result['maxF'] = self.maxf_metric.result()
result['relaxF'] = self.relaxf_metric.result()
return result
@@ -12,37 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM continuous finetuning+eval training driver."""
# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import train_utils
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def main(_):
# TODO(b/177863554): consolidate to nlp/train.py
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
from official.projects.basnet.configs import basnet as basnet_cfg
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
from official.projects.basnet.tasks import basnet as basenet_task
from official.vision.beta import train
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
app.run(train.main)
@@ -68,6 +68,9 @@ Note that the dataset is large (~1TB).
### Preprocess the data
Follow the instructions in [Data Preprocessing](data/preprocessing) to
preprocess the Criteo Terabyte dataset.
Data preprocessing steps are summarized below.
Integer feature processing steps, sequentially:
@@ -93,9 +96,9 @@ Training and eval datasets are expected to be saved in many tab-separated values
(TSV) files in the following format: label, numerical features, and
categorical features.
On each row of the TSV file, the first value is the label (either 0 or 1),
the next `num_dense_features` values are numerical features, and the
remaining `vocab_sizes` values are categorical features. Each i-th
categorical feature is expected to be an integer in
the range of `[0, vocab_sizes[i])`.
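For illustration, with hypothetical values `num_dense_features = 2` and
`vocab_sizes = [10, 5]`, a single tab-separated row could look like:

```
1	0.5	1.2	7	4
```

(a label of `1`, two numerical features, then two categorical feature ids).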
## Train and Evaluate
## Download and preprocess Criteo TB dataset
[Apache Beam](https://beam.apache.org) enables distributed preprocessing of the
dataset and can be run on
[Google Cloud Dataflow](https://cloud.google.com/dataflow/). The preprocessing
scripts can be run locally via DirectRunner provided that the local host has
enough CPU/Memory/Storage.
Install required packages.
```bash
python3 setup.py install
```
Set up the following environment variables, replacing `bucket-name` with the
name of your Cloud Storage bucket and `my-gcp-project` with your GCP project
name.
```bash
export STORAGE_BUCKET=gs://bucket-name
export PROJECT=my-gcp-project
export REGION=us-central1
```
Note: If running locally, the environment variables above are not needed, and
a local path can be used instead of `gs://bucket-name`. Also consider passing
a smaller `max_vocab_size` argument.
1. Download raw
[Criteo TB dataset](https://labs.criteo.com/2013/12/download-terabyte-click-logs/)
to a GCS bucket.
Organize the data in the following way:
    * The files `day_0.gz`, `day_1.gz`, ..., `day_22.gz` in
    `${STORAGE_BUCKET}/criteo_raw/train/`
    * The file `day_23.gz` in `${STORAGE_BUCKET}/criteo_raw/test/`
2. Shard the raw training/test data into multiple files.
```bash
python3 shard_rebalancer.py \
--input_path "${STORAGE_BUCKET}/criteo_raw/train/*" \
--output_path "${STORAGE_BUCKET}/criteo_raw_sharded/train/train" \
--num_output_files 1024 --filetype csv --runner DataflowRunner \
--project ${PROJECT} --region ${REGION}
```
```bash
python3 shard_rebalancer.py \
--input_path "${STORAGE_BUCKET}/criteo_raw/test/*" \
--output_path "${STORAGE_BUCKET}/criteo_raw_sharded/test/test" \
--num_output_files 64 --filetype csv --runner DataflowRunner \
--project ${PROJECT} --region ${REGION}
```
3. Generate vocabulary and preprocess the data.
Generate vocabulary:
```bash
python3 criteo_preprocess.py \
--input_path "${STORAGE_BUCKET}/criteo_raw_sharded/*/*" \
--output_path "${STORAGE_BUCKET}/criteo/" \
--temp_dir "${STORAGE_BUCKET}/criteo_vocab/" \
--vocab_gen_mode --runner DataflowRunner --max_vocab_size 5000000 \
--project ${PROJECT} --region ${REGION}
```
Preprocess training and test data:
```bash
python3 criteo_preprocess.py \
--input_path "${STORAGE_BUCKET}/criteo_raw_sharded/train/*" \
--output_path "${STORAGE_BUCKET}/criteo/train/train" \
--temp_dir "${STORAGE_BUCKET}/criteo_vocab/" \
--runner DataflowRunner --max_vocab_size 5000000 \
--project ${PROJECT} --region ${REGION}
```
```bash
python3 criteo_preprocess.py \
--input_path "${STORAGE_BUCKET}/criteo_raw_sharded/test/*" \
--output_path "${STORAGE_BUCKET}/criteo/test/test" \
--temp_dir "${STORAGE_BUCKET}/criteo_vocab/" \
--runner DataflowRunner --max_vocab_size 5000000 \
--project ${PROJECT} --region ${REGION}
```
4. (Optional) Re-balance the dataset.
```bash
python3 shard_rebalancer.py \
--input_path "${STORAGE_BUCKET}/criteo/train/*" \
--output_path "${STORAGE_BUCKET}/criteo_balanced/train/train" \
--num_output_files 8192 --filetype csv --runner DataflowRunner \
--project ${PROJECT} --region ${REGION}
```
```bash
python3 shard_rebalancer.py \
--input_path "${STORAGE_BUCKET}/criteo/test/*" \
--output_path "${STORAGE_BUCKET}/criteo_balanced/test/test" \
--num_output_files 1024 --filetype csv --runner DataflowRunner \
--project ${PROJECT} --region ${REGION}
```
At this point training and test data are in the buckets:
* `${STORAGE_BUCKET}/criteo_balanced/train/`
* `${STORAGE_BUCKET}/criteo_balanced/test/`
All other intermediate directories can be removed.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFX beam preprocessing pipeline for Criteo data.
Preprocessing util for criteo data. Transformations:
1. Fill missing features with zeros.
2. Set negative integer features to zeros.
3. Normalize integer features using log(x+1).
4. For categorical features (hex), convert to integer and take the value
modulo max_vocab_size.
Usage:
For raw Criteo data, this script should be run twice.
First run should set vocab_gen_mode to true. This run is used to generate
vocabulary files in the temp_dir location.
Second run should set vocab_gen_mode to false. It is necessary to point to the
same temp_dir used during the first run.
"""
import argparse
import datetime
import os
from absl import logging
import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.public import tfxio
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_path",
default=None,
required=True,
help="Input path. Be sure to set this to cover all data, to ensure "
"that sparse vocabs are complete.")
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Output path.")
parser.add_argument(
"--temp_dir",
default=None,
required=True,
help="Directory to store temporary metadata. Important because vocab "
"dictionaries will be stored here. Co-located with data, ideally.")
parser.add_argument(
"--csv_delimeter",
default="\t",
help="Delimeter string for input and output.")
parser.add_argument(
"--vocab_gen_mode",
action="store_true",
default=False,
help="If it is set, process full dataset and do not write CSV output. In "
"this mode, See temp_dir for vocab files. input_path should cover all "
"data, e.g. train, test, eval.")
parser.add_argument(
"--runner",
help="Runner for Apache Beam, needs to be one of {DirectRunner, "
"DataflowRunner}.",
default="DirectRunner")
parser.add_argument(
"--project",
default=None,
help="ID of your project. Ignored by DirectRunner.")
parser.add_argument(
"--region",
default=None,
help="Region. Ignored by DirectRunner.")
parser.add_argument(
"--max_vocab_size",
type=int,
default=10_000_000,
help="Max index range, categorical features convert to integer and take "
"value modulus the max_vocab_size")
args = parser.parse_args()
NUM_NUMERIC_FEATURES = 13
NUMERIC_FEATURE_KEYS = [
f"int-feature-{x + 1}" for x in range(NUM_NUMERIC_FEATURES)]
CATEGORICAL_FEATURE_KEYS = [
"categorical-feature-%d" % x for x in range(NUM_NUMERIC_FEATURES + 1, 40)]
LABEL_KEY = "clicked"
# Data is first preprocessed in pure Apache Beam using numpy.
# This fills in missing values and converts hexadecimal-encoded values.
# For the TF schema, we can thus specify the schema as FixedLenFeature
# for TensorFlow Transform.
FEATURE_SPEC = dict([(name, tf.io.FixedLenFeature([], dtype=tf.int64))
for name in CATEGORICAL_FEATURE_KEYS] +
[(name, tf.io.FixedLenFeature([], dtype=tf.float32))
for name in NUMERIC_FEATURE_KEYS] +
[(LABEL_KEY, tf.io.FixedLenFeature([], tf.float32))])
INPUT_METADATA = dataset_metadata.DatasetMetadata(
schema_utils.schema_from_feature_spec(FEATURE_SPEC))
def apply_vocab_fn(inputs):
"""Preprocessing fn for sparse features.
Applies vocab to bucketize sparse features. This function operates using
previously-created vocab files.
Pre-condition: Full vocab has been materialized.
Args:
inputs: Input features to transform.
Returns:
Output dict with transformed features.
"""
outputs = {}
outputs[LABEL_KEY] = inputs[LABEL_KEY]
for key in NUMERIC_FEATURE_KEYS:
outputs[key] = inputs[key]
for idx, key in enumerate(CATEGORICAL_FEATURE_KEYS):
vocab_fn = os.path.join(
args.temp_dir, "tftransform_tmp", "feature_{}_vocab".format(idx))
outputs[key] = tft.apply_vocabulary(inputs[key], vocab_fn)
return outputs
def compute_vocab_fn(inputs):
"""Preprocessing fn for sparse features.
This function computes unique IDs for the sparse features. We rely on implicit
behavior which writes the vocab files to the vocab_filename specified in
tft.compute_and_apply_vocabulary.
Pre-condition: Sparse features have been converted to integer and mod'ed with
args.max_vocab_size.
Args:
inputs: Input features to transform.
Returns:
Output dict with transformed features.
"""
outputs = {}
outputs[LABEL_KEY] = inputs[LABEL_KEY]
for key in NUMERIC_FEATURE_KEYS:
outputs[key] = inputs[key]
for idx, key in enumerate(CATEGORICAL_FEATURE_KEYS):
outputs[key] = tft.compute_and_apply_vocabulary(
x=inputs[key],
vocab_filename="feature_{}_vocab".format(idx))
return outputs
class FillMissing(beam.DoFn):
"""Fills missing elements with zero string value."""
def process(self, element):
elem_list = element.split(args.csv_delimeter)
out_list = []
for val in elem_list:
new_val = "0" if not val else val
out_list.append(new_val)
yield (args.csv_delimeter).join(out_list)
class NegsToZeroLog(beam.DoFn):
"""For int features, sets negative values to zero and takes log(x+1)."""
def process(self, element):
elem_list = element.split(args.csv_delimeter)
out_list = []
for i, val in enumerate(elem_list):
if i > 0 and i <= NUM_NUMERIC_FEATURES:
new_val = "0" if int(val) < 0 else val
new_val = np.log(int(new_val) + 1)
new_val = str(new_val)
else:
new_val = val
out_list.append(new_val)
yield (args.csv_delimeter).join(out_list)
class HexToIntModRange(beam.DoFn):
"""For categorical features, takes decimal value and mods with max value."""
def process(self, element):
elem_list = element.split(args.csv_delimeter)
out_list = []
for i, val in enumerate(elem_list):
if i > NUM_NUMERIC_FEATURES:
new_val = int(val, 16) % args.max_vocab_size
else:
new_val = val
out_list.append(str(new_val))
yield str.encode((args.csv_delimeter).join(out_list))
def transform_data(data_path, output_path):
"""Preprocesses Criteo data.
Two processing modes are supported. Raw data will require two passes.
If full vocab files already exist, only one pass is necessary.
Args:
data_path: File(s) to read.
output_path: Path to which output CSVs are written, if necessary.
"""
preprocessing_fn = compute_vocab_fn if args.vocab_gen_mode else apply_vocab_fn
gcp_project = args.project
region = args.region
job_name = (f"criteo-preprocessing-"
f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}")
  # Set up the Beam pipeline.
pipeline_options = None
if args.runner == "DataflowRunner":
options = {
"staging_location": os.path.join(output_path, "tmp", "staging"),
"temp_location": os.path.join(output_path, "tmp"),
"job_name": job_name,
"project": gcp_project,
"save_main_session": True,
"region": region,
"setup_file": "./setup.py",
}
pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
elif args.runner == "DirectRunner":
pipeline_options = beam.options.pipeline_options.DirectOptions(
direct_num_workers=os.cpu_count(),
direct_running_mode="multi_threading")
with beam.Pipeline(args.runner, options=pipeline_options) as pipeline:
with tft_beam.Context(temp_dir=args.temp_dir):
processed_lines = (
pipeline
# Read in TSV data.
| beam.io.ReadFromText(data_path, coder=beam.coders.StrUtf8Coder())
# Fill in missing elements with the defaults (zeros).
| "FillMissing" >> beam.ParDo(FillMissing())
# For numerical features, set negatives to zero. Then take log(x+1).
| "NegsToZeroLog" >> beam.ParDo(NegsToZeroLog())
# For categorical features, mod the values with vocab size.
| "HexToIntModRange" >> beam.ParDo(HexToIntModRange()))
# CSV reader: List the cols in order, as dataset schema is not ordered.
ordered_columns = [LABEL_KEY
] + NUMERIC_FEATURE_KEYS + CATEGORICAL_FEATURE_KEYS
csv_tfxio = tfxio.BeamRecordCsvTFXIO(
physical_format="text",
column_names=ordered_columns,
delimiter=args.csv_delimeter,
schema=INPUT_METADATA.schema)
converted_data = (
processed_lines
| "DecodeData" >> csv_tfxio.BeamSource())
raw_dataset = (converted_data, csv_tfxio.TensorAdapterConfig())
# The TFXIO output format is chosen for improved performance.
transformed_dataset, _ = (
raw_dataset | tft_beam.AnalyzeAndTransformDataset(
preprocessing_fn, output_record_batches=False))
      # The transformed metadata's schema is used below to build the CSV coder.
transformed_data, transformed_metadata = transformed_dataset
if not args.vocab_gen_mode:
# Write to CSV.
transformed_csv_coder = tft.coders.CsvCoder(
ordered_columns, transformed_metadata.schema,
delimiter=args.csv_delimeter)
_ = (
transformed_data
| "EncodeDataCsv" >> beam.Map(transformed_csv_coder.encode)
| "WriteDataCsv" >> beam.io.WriteToText(output_path))
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
transform_data(data_path=args.input_path,
output_path=args.output_path)
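To make the per-line cleanup above concrete, here is a minimal, Beam-free sketch that condenses FillMissing, NegsToZeroLog and HexToIntModRange into one function; the delimiter, the shortened feature count and the sample line are illustrative assumptions, not part of the pipeline.

import numpy as np

DELIM = "\t"        # stands in for args.csv_delimeter
NUM_NUMERIC = 2     # 13 in the real script; shortened for the example
MAX_VOCAB = 10_000_000

def clean_line(line: str) -> str:
  cols = line.split(DELIM)
  out = []
  for i, val in enumerate(cols):
    val = val or "0"                          # FillMissing
    if 1 <= i <= NUM_NUMERIC:                 # NegsToZeroLog
      val = str(np.log(max(int(val), 0) + 1))
    elif i > NUM_NUMERIC:                     # HexToIntModRange
      val = str(int(val, 16) % MAX_VOCAB)
    out.append(val)
  return DELIM.join(out)

# Label, one missing int, one negative int, one hex categorical:
print(clean_line("1\t\t-3\t68fd1e64"))  # -> "1", "0.0", "0.0", <small int id>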
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup configuration for criteo dataset preprocessing.
This is used while running Tensorflow transform on Cloud Dataflow.
"""
import setuptools
version = "0.1.0"
if __name__ == "__main__":
setuptools.setup(
name="criteo_preprocessing",
version=version,
install_requires=["tensorflow-transform"],
packages=setuptools.find_packages(),
)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rebalance a set of CSV/TFRecord shards to a target number of files.
"""
import argparse
import datetime
import os
import apache_beam as beam
import tensorflow as tf
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_path",
default=None,
required=True,
help="Input path.")
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Output path.")
parser.add_argument(
"--num_output_files",
type=int,
default=256,
help="Number of output file shards.")
parser.add_argument(
"--filetype",
default="tfrecord",
help="File type, needs to be one of {tfrecord, csv}.")
parser.add_argument(
"--project",
default=None,
help="ID (not name) of your project. Ignored by DirectRunner")
parser.add_argument(
"--runner",
help="Runner for Apache Beam, needs to be one of "
"{DirectRunner, DataflowRunner}.",
default="DirectRunner")
parser.add_argument(
"--region",
default=None,
help="region")
args = parser.parse_args()
def rebalance_data_shards():
"""Rebalances data shards."""
def csv_pipeline(pipeline: beam.Pipeline):
"""Rebalances CSV dataset.
Args:
pipeline: Beam pipeline object.
"""
_ = (
pipeline
| beam.io.ReadFromText(args.input_path)
| beam.io.WriteToText(args.output_path,
num_shards=args.num_output_files))
def tfrecord_pipeline(pipeline: beam.Pipeline):
"""Rebalances TFRecords dataset.
Args:
pipeline: Beam pipeline object.
"""
example_coder = beam.coders.ProtoCoder(tf.train.Example)
_ = (
pipeline
| beam.io.ReadFromTFRecord(args.input_path, coder=example_coder)
        | beam.io.WriteToTFRecord(args.output_path,
                                  file_name_suffix=".tfrecord",
                                  coder=example_coder,
                                  num_shards=args.num_output_files))
job_name = (
f"shard-rebalancer-{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}")
  # Set up the Beam pipeline.
options = {
"staging_location": os.path.join(args.output_path, "tmp", "staging"),
"temp_location": os.path.join(args.output_path, "tmp"),
"job_name": job_name,
"project": args.project,
"save_main_session": True,
"region": args.region,
}
opts = beam.pipeline.PipelineOptions(flags=[], **options)
with beam.Pipeline(args.runner, options=opts) as pipeline:
    if args.filetype == "tfrecord":
      tfrecord_pipeline(pipeline)
    elif args.filetype == "csv":
      csv_pipeline(pipeline)
    else:
      raise ValueError(f"Unsupported filetype: {args.filetype}")
if __name__ == "__main__":
rebalance_data_shards()
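As a quick sanity check of the rebalanced TFRecord output, one can count records across the new shards. A minimal sketch, assuming a hypothetical gs://bucket/out output prefix and WriteToTFRecord's default "<prefix>-SSSSS-of-NNNNN<suffix>" shard naming:

import tensorflow as tf

# Glob the rebalanced shards (the pattern is an assumption, see above).
files = tf.io.gfile.glob("gs://bucket/out-*-of-*")
# Count the records; len(files) should equal --num_output_files.
count = sum(1 for _ in tf.data.TFRecordDataset(files))
print(f"{len(files)} shards, {count} records")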
@@ -43,11 +43,6 @@ class RankingTrainer(base_trainer.Trainer):
   def train_loop_end(self) -> Dict[str, float]:
     """See base class."""
     self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
     logs = {}
     for metric in self.train_metrics + [self.train_loss]:
       logs[metric.name] = metric.result()
@@ -46,16 +46,13 @@ flags.DEFINE_bool('search_hints', True,
 flags.DEFINE_string('site_path', '/api_docs/python',
                     'Path prefix in the _toc.yaml')
-flags.DEFINE_bool('gen_report', False,
-                  'Generate an API report containing the health of the '
-                  'docstrings of the public API.')
 PROJECT_SHORT_NAME = 'tfnlp'
 PROJECT_FULL_NAME = 'TensorFlow Official Models - NLP Modeling Library'
-def gen_api_docs(code_url_prefix, site_path, output_dir, gen_report,
-                 project_short_name, project_full_name, search_hints):
+def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
+                 project_full_name, search_hints):
   """Generates api docs for the tensorflow docs package."""
   build_api_docs_lib.hide_module_model_and_layer_methods()
   del tfnlp.layers.MultiHeadAttention
@@ -68,7 +65,6 @@ def gen_api_docs(code_url_prefix, site_path, output_dir, gen_report,
       code_url_prefix=code_url_prefix,
       search_hints=search_hints,
       site_path=site_path,
-      gen_report=gen_report,
       callbacks=[public_api.explicit_package_contents_filter],
   )
@@ -84,7 +80,6 @@ def main(argv):
       code_url_prefix=FLAGS.code_url_prefix,
       site_path=FLAGS.site_path,
       output_dir=FLAGS.output_dir,
-      gen_report=FLAGS.gen_report,
       project_short_name=PROJECT_SHORT_NAME,
       project_full_name=PROJECT_FULL_NAME,
       search_hints=FLAGS.search_hints)
@@ -46,16 +46,12 @@ flags.DEFINE_bool('search_hints', True,
 flags.DEFINE_string('site_path', 'tfvision/api_docs/python',
                     'Path prefix in the _toc.yaml')
-flags.DEFINE_bool('gen_report', False,
-                  'Generate an API report containing the health of the '
-                  'docstrings of the public API.')
 PROJECT_SHORT_NAME = 'tfvision'
 PROJECT_FULL_NAME = 'TensorFlow Official Models - Vision Modeling Library'
-def gen_api_docs(code_url_prefix, site_path, output_dir, gen_report,
-                 project_short_name, project_full_name, search_hints):
+def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
+                 project_full_name, search_hints):
   """Generates api docs for the tensorflow docs package."""
   build_api_docs_lib.hide_module_model_and_layer_methods()
@@ -66,7 +62,6 @@ def gen_api_docs(code_url_prefix, site_path, output_dir, gen_report,
       code_url_prefix=code_url_prefix,
       search_hints=search_hints,
       site_path=site_path,
-      gen_report=gen_report,
       callbacks=[public_api.explicit_package_contents_filter],
   )
@@ -82,7 +77,6 @@ def main(argv):
       code_url_prefix=FLAGS.code_url_prefix,
       site_path=FLAGS.site_path,
       output_dir=FLAGS.output_dir,
-      gen_report=FLAGS.gen_report,
       project_short_name=PROJECT_SHORT_NAME,
       project_full_name=PROJECT_FULL_NAME,
       search_hints=FLAGS.search_hints)
@@ -46,14 +46,8 @@ py_test() {
   return "${exit_code}"
 }
-py2_test() {
-  local PY_BINARY=$(which python2)
-  py_test "$PY_BINARY"
-  return $?
-}
 py3_test() {
-  local PY_BINARY=$(which python3)
+  local PY_BINARY=python3.9
   py_test "$PY_BINARY"
   return $?
 }
@@ -61,7 +55,7 @@ py3_test() {
 test_result=0
 if [ "$#" -eq 0 ]; then
-  TESTS="lint py2_test py3_test"
+  TESTS="lint py3_test"
 else
   TESTS="$@"
 fi
@@ -54,9 +54,12 @@ depth, label smoothing and dropout.
 ### Common Settings and Notes
-*   We provide models based on two detection frameworks, [RetinaNet](https://arxiv.org/abs/1708.02002)
-    or [Mask R-CNN](https://arxiv.org/abs/1703.06870), and two backbones, [ResNet-FPN](https://arxiv.org/abs/1612.03144)
-    or [SpineNet](https://arxiv.org/abs/1912.05027).
+*   We provide models adopting [ResNet-FPN](https://arxiv.org/abs/1612.03144) and
+    [SpineNet](https://arxiv.org/abs/1912.05027) backbones based on detection frameworks:
+    *   [RetinaNet](https://arxiv.org/abs/1708.02002) and [RetinaNet-RS](https://arxiv.org/abs/2107.00057)
+    *   [Mask R-CNN](https://arxiv.org/abs/1703.06870)
+    *   [Cascade RCNN](https://arxiv.org/abs/1712.00726) and [Cascade RCNN-RS](https://arxiv.org/abs/2107.00057)
 *   Models are all trained on COCO train2017 and evaluated on COCO val2017.
 *   Training details:
     *   Models finetuned from ImageNet pretrained checkpoints adopt the 12 or 36
@@ -99,13 +102,22 @@ depth, label smoothing and dropout.
 ### Instance Segmentation Baselines
-#### Mask R-CNN (ImageNet pretrained)
+#### Mask R-CNN (Trained from scratch)
 | Backbone     | Resolution | Epochs | FLOPs (B) | Params (M) | Box AP | Mask AP | Download |
 | ------------ |:----------:| ------:|----------:|-----------:|-------:|--------:|---------:|
-| SpineNet-49  | 640x640    | 350    | 215.7     | 40.8       | 42.6   | 37.9    | config   |
+| ResNet50-FPN | 640x640    | 350    | 227.7     | 46.3       | 42.3   | 37.6    | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/r50fpn_640_coco_scratch_tpu4x4.yaml) |
+| SpineNet-49  | 640x640    | 350    | 215.7     | 40.8       | 42.6   | 37.9    | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_mrcnn_tpu.yaml) |
 | SpineNet-96  | 1024x1024  | 500    | 315.0     | 55.2       | 48.1   | 42.4    | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_mrcnn_tpu.yaml) |
 | SpineNet-143 | 1280x1280  | 500    | 498.8     | 79.2       | 49.3   | 43.4    | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_mrcnn_tpu.yaml) |
+#### Cascade RCNN-RS (Trained from scratch)
+backbone     | resolution | epochs | params (M) | box AP | mask AP | download
+------------ | :--------: | -----: | ---------: | -----: | ------: | -------:
+SpineNet-49  | 640x640    | 500    | 56.4       | 46.4    | 40.0   | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_cascadercnn_tpu.yaml)
+SpineNet-143 | 1280x1280  | 500    | 94.9       | 51.9    | 45.0   | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_cascadercnn_tpu.yaml)
 ## Semantic Segmentation
@@ -131,7 +143,7 @@ depth, label smoothing and dropout.
 ### Common Settings and Notes
-*   We provide models for video classification with two backbones:
+*   We provide models for video classification with two backbones:
     [SlowOnly](https://arxiv.org/abs/1812.03982) and 3D-ResNet (R3D) used in
     [Spatiotemporal Contrastive Video Representation Learning](https://arxiv.org/abs/2008.03800).
 *   Training and evaluation details: