Commit bab70e6b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by saberkun
Browse files

Open source DeepMAC architecture.

PiperOrigin-RevId: 366382220
parent 2e1a0602
# Mask R-CNN with deep mask heads
This project brings insights from the DeepMAC model into the Mask-RCNN
architecture. Please see the paper
[The surprising impact of mask-head architecture on novel class segmentation](https://arxiv.org/abs/2104.00613)
for more details.
## Code structure
* This folder contains forks of a few Mask R-CNN files and repurposes them to
support deep mask heads.
* To see the benefits of using deep mask heads, it is important to train the
mask head with only groundtruth boxes. This is configured via the
`task.model.use_gt_boxes_for_masks` flag.
* Architecture of the mask head can be changed via the config value
`task.model.mask_head.convnet_variant`. Supported values are `"default"`,
`"hourglass20"`, `"hourglass52"`, and `"hourglass100"`.
* The flag `task.model.mask_head.class_agnostic` trains the model in class
agnostic mode and `task.allowed_mask_class_ids` controls which classes are
allowed to have masks during training.
* Majority of experiments and ablations from the paper are perfomed with the
[DeepMAC model]() in the Object Detection API code base.
## Prerequisites
### Prepare dataset
Use [create_coco_tf_record.py](../../data/create_coco_tf_record.py) to create
the COCO dataset. The data needs to be store in a
[Google cloud storage bucket](https://cloud.google.com/storage/docs/creating-buckets)
so that it can be accessed by the TPU.
### Start a TPU v3-32 instance
See [TPU Quickstart](https://cloud.google.com/tpu/docs/quickstart) for
instructions. An example command would look like:
```shell
ctpu up --name <tpu-name> --zone <zone> --tpu-size=v3-32 --tf-version nightly
```
This model requires TF version `>= 2.5`. Currently, that is only available via a
`nightly` build on Cloud.
### Install requirements
SSH into the TPU host with `gcloud compute ssh <tpu-name>` and execute the
following.
```shell
$ git clone https://github.com/tensorflow/models.git
$ cd models
$ pip3 install -r official/requirements.txt
```
## Training Models
The configurations can be found in the `configs/experiments` directory. You can
launch a training job by executing.
```shell
$ export CONFIG=./official/vision/beta/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_voc_r50.yaml
$ export MODEL_DIR="gs://<path-for-checkpoints>"
$ export ANNOTAION_FILE="gs://<path-to-coco-annotation-json>"
$ export TRAIN_DATA="gs://<path-to-train-data>"
$ export EVAL_DATA="gs://<path-to-eval-data>"
# Overrides to access data. These can also be changed in the config file.
$ export OVERRIDES="task.validation_data.input_path=${EVAL_DATA},\
task.train_data.input_path=${TRAIN_DATA},\
task.annotation_file=${ANNOTAION_FILE},\
runtime.distribution_strategy=tpu"
$ python3 -m official.vision.beta.projects.deepmac_maskrcnn.train \
--logtostderr \
--mode=train_and_eval \
--experiment=deep_mask_head_rcnn_resnetfpn_coco \
--model_dir=$MODEL_DIR \
--config_file=$CONFIG \
--params_override=$OVERRIDES\
--tpu=<tpu-name>
```
`CONFIG_FILE` can be any file in the `configs/experiments` directory.
**Note:** The default eval batch size of 32 discards some samples during
validation. For accurate vaidation statistics, launch a dedicated eval job on
TPU `v3-8` and set batch size to 8.
## Configurations
In the following table, we report the Mask mAP of our models on the non-VOC
classes when only training with masks for the VOC calsses. Performance is
measured on the `coco-val2017` set.
Backbone | Mask head | Config name | Mask mAP
:--------- | :----------- | :--------------------------------------- | -------:
ResNet-50 | Default | `deep_mask_head_rcnn_voc_r50.yaml` | 25.9
ResNet-50 | Hourglass-52 | `deep_mask_head_rcnn_voc_r50_hg52.yaml` | 33.1
ResNet-101 | Hourglass-52 | `deep_mask_head_rcnn_voc_r101_hg52.yaml` | 34.4
## See also
* [DeepMAC model](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/deepmac.md)
in the Object Detection API code base.
* Project website - [git.io/deepmac](https://git.io/deepmac)
## Citation
```
@misc{birodkar2021surprising,
title={The surprising impact of mask-head architecture on novel class segmentation},
author={Vighnesh Birodkar and Zhichao Lu and Siyang Li and Vivek Rathod and Jonathan Huang},
year={2021},
eprint={2104.00613},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Imports to configure Mask R-CNN with deep mask heads."""
# pylint: disable=unused-import
from official.vision.beta.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for Mask R-CNN with deep mask heads."""
import os
from typing import Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.configs import maskrcnn as maskrcnn_config
from official.vision.beta.configs import retinanet as retinanet_config
@dataclasses.dataclass
class DeepMaskHead(maskrcnn_config.MaskHead):
convnet_variant: str = 'default'
@dataclasses.dataclass
class DeepMaskHeadRCNN(maskrcnn_config.MaskRCNN):
mask_head: Optional[DeepMaskHead] = DeepMaskHead()
use_gt_boxes_for_masks: bool = False
@dataclasses.dataclass
class DeepMaskHeadRCNNTask(maskrcnn_config.MaskRCNNTask):
"""Configuration for the deep mask head R-CNN task."""
model: DeepMaskHeadRCNN = DeepMaskHeadRCNN()
@exp_factory.register_config_factory('deep_mask_head_rcnn_resnetfpn_coco')
def deep_mask_head_rcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Mask R-CNN with deep mask heads."""
global_batch_size = 64
steps_per_epoch = int(retinanet_config.COCO_TRAIN_EXAMPLES /
global_batch_size)
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=DeepMaskHeadRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=DeepMaskHeadRCNN(
num_classes=91,
input_size=[1024, 1024, 3],
include_mask=True), # pytype: disable=wrong-keyword-args
losses=maskrcnn_config.Losses(l2_weight_decay=0.00004),
train_data=maskrcnn_config.DataConfig(
input_path=os.path.join(
maskrcnn_config.COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=global_batch_size,
parser=maskrcnn_config.Parser(
aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
validation_data=maskrcnn_config.DataConfig(
input_path=os.path.join(
maskrcnn_config.COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=8)), # pytype: disable=wrong-keyword-args
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // 8,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [15000, 20000],
'values': [0.12, 0.012, 0.0012],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that the config is set correctly."""
import tensorflow as tf
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn
class DeepMaskHeadRcnnConfigTest(tf.test.TestCase):
def test_config(self):
config = deep_mask_head_rcnn.deep_mask_head_rcnn_resnetfpn_coco()
self.assertIsInstance(config.task, deep_mask_head_rcnn.DeepMaskHeadRCNNTask)
if __name__ == '__main__':
tf.test.main()
task:
# VOC class taken from
# models/official/vision/detection/utils/class_utils.py
allowed_mask_class_ids: [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72]
per_category_metrics: true
model:
mask_head:
class_agnostic: true
convnet_variant: 'hourglass52'
num_filters: 64
mask_roi_aligner:
crop_size: 32
use_gt_boxes_for_masks: true
backbone:
type: 'resnet'
resnet:
model_id: 101
init_checkpoint: 'gs://tf_model_garden/official/resnet101_imagenet/ckpt-62400'
train_data:
global_batch_size: 64
validation_data:
global_batch_size: 32
trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [50000, 65000]
type: 'stepwise'
train_steps: 70000
validation_steps: 156 # 5000 / 32
task:
# VOC class taken from
# models/official/vision/detection/utils/class_utils.py
allowed_mask_class_ids: [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72]
per_category_metrics: true
model:
mask_head:
class_agnostic: true
use_gt_boxes_for_masks: true
backbone:
type: 'resnet'
resnet:
model_id: 50
init_checkpoint: 'gs://tf_model_garden/official/resnet50_imagenet/ckpt-28080'
train_data:
global_batch_size: 64
validation_data:
global_batch_size: 32
trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [50000, 65000]
type: 'stepwise'
train_steps: 70000
validation_steps: 156 # 5000 / 32
task:
# VOC class taken from
# models/official/vision/detection/utils/class_utils.py
allowed_mask_class_ids: [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72]
per_category_metrics: true
model:
mask_head:
class_agnostic: true
convnet_variant: 'hourglass52'
num_filters: 64
mask_roi_aligner:
crop_size: 32
use_gt_boxes_for_masks: true
backbone:
type: 'resnet'
resnet:
model_id: 50
init_checkpoint: 'gs://tf_model_garden/official/resnet50_imagenet/ckpt-28080'
train_data:
global_batch_size: 64
validation_data:
global_batch_size: 32
trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [50000, 65000]
type: 'stepwise'
train_steps: 70000
validation_steps: 156 # 5000 / 32
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Instance prediction heads."""
# Import libraries
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import hourglass_network
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskHead(tf.keras.layers.Layer):
"""Creates a mask head."""
def __init__(self,
num_classes,
upsample_factor=2,
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
class_agnostic=False,
convnet_variant='default',
**kwargs):
"""Initializes a mask head.
Args:
num_classes: An `int` of the number of classes.
upsample_factor: An `int` that indicates the upsample factor to generate
the final predicted masks. It should be >= 1.
num_convs: An `int` number that represents the number of the intermediate
convolution layers before the mask prediction layers.
num_filters: An `int` number that represents the number of filters of the
intermediate convolution layers.
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
class_agnostic: A `bool`. If set, we use a single channel mask head that
is shared between all classes.
convnet_variant: A `str` denoting the architecture of network used in the
head. Supported options are 'default', 'hourglass20', 'hourglass52'
and 'hourglass100'.
**kwargs: Additional keyword arguments to be passed.
"""
super(DeepMaskHead, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'upsample_factor': upsample_factor,
'num_convs': num_convs,
'num_filters': num_filters,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'class_agnostic': class_agnostic,
'convnet_variant': convnet_variant,
}
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def _get_conv_op_and_kwargs(self):
conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
else tf.keras.layers.Conv2D)
conv_kwargs = {
'filters': self._config_dict['num_filters'],
'kernel_size': 3,
'padding': 'same',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
return conv_op, conv_kwargs
def _get_bn_op_and_kwargs(self):
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
}
return bn_op, bn_kwargs
def build(self, input_shape):
"""Creates the variables of the head."""
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
self._build_convnet_variant()
self._deconv = tf.keras.layers.Conv2DTranspose(
filters=self._config_dict['num_filters'],
kernel_size=self._config_dict['upsample_factor'],
strides=self._config_dict['upsample_factor'],
padding='valid',
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='mask-upsampling')
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)
if self._config_dict['class_agnostic']:
num_filters = 1
else:
num_filters = self._config_dict['num_classes']
conv_kwargs = {
'filters': num_filters,
'kernel_size': 1,
'padding': 'valid',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
self._mask_regressor = conv_op(name='mask-logits', **conv_kwargs)
super(DeepMaskHead, self).build(input_shape)
def call(self, inputs, training=None):
"""Forward pass of mask branch for the Mask-RCNN model.
Args:
inputs: A `list` of two tensors where
inputs[0]: A `tf.Tensor` of shape [batch_size, num_instances,
roi_height, roi_width, roi_channels], representing the ROI features.
inputs[1]: A `tf.Tensor` of shape [batch_size, num_instances],
representing the classes of the ROIs.
training: A `bool` indicating whether it is in `training` mode.
Returns:
mask_outputs: A `tf.Tensor` of shape
[batch_size, num_instances, roi_height * upsample_factor,
roi_width * upsample_factor], representing the mask predictions.
"""
roi_features, roi_classes = inputs
batch_size, num_rois, height, width, filters = (
roi_features.get_shape().as_list())
if batch_size is None:
batch_size = tf.shape(roi_features)[0]
x = tf.reshape(roi_features, [-1, height, width, filters])
x = self._call_convnet_variant(x)
x = self._deconv(x)
x = self._deconv_bn(x)
x = self._activation(x)
logits = self._mask_regressor(x)
mask_height = height * self._config_dict['upsample_factor']
mask_width = width * self._config_dict['upsample_factor']
if self._config_dict['class_agnostic']:
logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack(
[batch_indices, mask_indices, class_gather_indices],
axis=2)
mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
return mask_outputs
def _build_convnet_variant(self):
variant = self._config_dict['convnet_variant']
if variant == 'default':
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._convs = []
self._conv_norms = []
for i in range(self._config_dict['num_convs']):
conv_name = 'mask-conv_{}'.format(i)
self._convs.append(conv_op(name=conv_name, **conv_kwargs))
bn_name = 'mask-conv-bn_{}'.format(i)
self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
elif variant == 'hourglass20':
logging.info('Using hourglass 20 network.')
self._hourglass = hourglass_network.hourglass_20(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass52':
logging.info('Using hourglass 52 network.')
self._hourglass = hourglass_network.hourglass_52(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass100':
logging.info('Using hourglass 100 network.')
self._hourglass = hourglass_network.hourglass_100(
self._config_dict['num_filters'], initial_downsample=False)
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def _call_convnet_variant(self, x):
variant = self._config_dict['convnet_variant']
if variant == 'default':
for conv, bn in zip(self._convs, self._conv_norms):
x = conv(x)
x = bn(x)
x = self._activation(x)
return x
elif variant == 'hourglass20':
return self._hourglass(x)[-1]
elif variant == 'hourglass52':
return self._hourglass(x)[-1]
elif variant == 'hourglass100':
return self._hourglass(x)[-1]
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for instance_heads.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
class MaskHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(1, 1, False),
(1, 2, False),
(2, 1, False),
(2, 2, False),
)
def test_forward(self, upsample_factor, num_convs, use_sync_bn):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
upsample_factor=upsample_factor,
num_convs=num_convs,
num_filters=16,
use_separable_conv=False,
activation='relu',
use_sync_bn=use_sync_bn,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
)
roi_features = np.random.rand(2, 10, 14, 14, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(
masks.numpy().shape,
[2, 10, 14 * upsample_factor, 14 * upsample_factor])
def test_serialize_deserialize(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
upsample_factor=2,
num_convs=1,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
)
config = mask_head.get_config()
new_mask_head = deep_instance_heads.DeepMaskHead.from_config(config)
self.assertAllEqual(
mask_head.get_config(), new_mask_head.get_config())
def test_forward_class_agnostic(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
class_agnostic=True
)
roi_features = np.random.rand(2, 10, 14, 14, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(masks.numpy().shape, [2, 10, 28, 28])
def test_instance_head_hourglass(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
class_agnostic=True,
convnet_variant='hourglass20',
num_filters=32,
upsample_factor=2
)
roi_features = np.random.rand(2, 10, 16, 16, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(masks.numpy().shape, [2, 10, 32, 32])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask R-CNN model."""
# Import libraries
from absl import logging
import tensorflow as tf
from official.vision.beta.ops import box_ops
def resize_as(source, size):
source = tf.transpose(source, (0, 2, 3, 1))
source = tf.image.resize(source, (size, size))
return tf.transpose(source, (0, 3, 1, 2))
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskRCNNModel(tf.keras.Model):
"""The Mask R-CNN model."""
def __init__(self,
backbone,
decoder,
rpn_head,
detection_head,
roi_generator,
roi_sampler,
roi_aligner,
detection_generator,
mask_head=None,
mask_sampler=None,
mask_roi_aligner=None,
use_gt_boxes_for_masks=False,
**kwargs):
"""Initializes the Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head.
roi_generator: the ROI generator.
roi_sampler: the ROI sampler.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI alginer for mask prediction.
use_gt_boxes_for_masks: bool, if set, crop using groundtruth boxes
instead of proposals for training mask head
**kwargs: keyword arguments to be passed.
"""
super(DeepMaskRCNNModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'decoder': decoder,
'rpn_head': rpn_head,
'detection_head': detection_head,
'roi_generator': roi_generator,
'roi_sampler': roi_sampler,
'roi_aligner': roi_aligner,
'detection_generator': detection_generator,
'mask_head': mask_head,
'mask_sampler': mask_sampler,
'mask_roi_aligner': mask_roi_aligner,
'use_gt_boxes_for_masks': use_gt_boxes_for_masks
}
self.backbone = backbone
self.decoder = decoder
self.rpn_head = rpn_head
self.detection_head = detection_head
self.roi_generator = roi_generator
self.roi_sampler = roi_sampler
self.roi_aligner = roi_aligner
self.detection_generator = detection_generator
self._include_mask = mask_head is not None
self.mask_head = mask_head
if self._include_mask and mask_sampler is None:
raise ValueError('`mask_sampler` is not provided in Mask R-CNN.')
self.mask_sampler = mask_sampler
if self._include_mask and mask_roi_aligner is None:
raise ValueError('`mask_roi_aligner` is not provided in Mask R-CNN.')
self.mask_roi_aligner = mask_roi_aligner
def call(self,
images,
image_shape,
anchor_boxes=None,
gt_boxes=None,
gt_classes=None,
gt_masks=None,
training=None):
model_outputs = {}
# Feature extraction.
features = self.backbone(images)
if self.decoder:
features = self.decoder(features)
# Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(features)
model_outputs.update({
'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores
})
# Generate RoIs.
rois, _ = self.roi_generator(
rpn_boxes, rpn_scores, anchor_boxes, image_shape, training)
if training:
rois = tf.stop_gradient(rois)
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
self.roi_sampler(rois, gt_boxes, gt_classes))
# Assign target for the 2nd stage classification.
box_targets = box_ops.encode_boxes(
matched_gt_boxes, rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]),
tf.zeros_like(box_targets),
box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
# RoI align.
roi_features = self.roi_aligner(features, rois)
# Detection head.
raw_scores, raw_boxes = self.detection_head(roi_features)
if training:
model_outputs.update({
'class_outputs': raw_scores,
'box_outputs': raw_boxes,
})
else:
# Post-processing.
detections = self.detection_generator(
raw_boxes, raw_scores, rois, image_shape)
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections'],
})
if not self._include_mask:
return model_outputs
if training:
if self._config_dict['use_gt_boxes_for_masks']:
mask_size = (
self.mask_roi_aligner._config_dict['crop_size'] * # pylint:disable=protected-access
self.mask_head._config_dict['upsample_factor'] # pylint:disable=protected-access
)
gt_masks = resize_as(source=gt_masks, size=mask_size)
logging.info('Using GT class and mask targets.')
model_outputs.update({
'mask_class_targets': gt_classes,
'mask_targets': gt_masks,
})
else:
rois, roi_classes, roi_masks = self.mask_sampler(
rois,
matched_gt_boxes,
matched_gt_classes,
matched_gt_indices,
gt_masks)
roi_masks = tf.stop_gradient(roi_masks)
model_outputs.update({
'mask_class_targets': roi_classes,
'mask_targets': roi_masks,
})
else:
rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
if training and self._config_dict['use_gt_boxes_for_masks']:
logging.info('Using GT mask roi features.')
mask_roi_features = self.mask_roi_aligner(features, gt_boxes)
raw_masks = self.mask_head([mask_roi_features, gt_classes])
else:
mask_roi_features = self.mask_roi_aligner(features, rois)
raw_masks = self.mask_head([mask_roi_features, roi_classes])
# Mask head.
if training:
model_outputs.update({
'mask_outputs': raw_masks,
})
else:
model_outputs.update({
'detection_masks': tf.math.sigmoid(raw_masks),
})
return model_outputs
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(
backbone=self.backbone,
rpn_head=self.rpn_head,
detection_head=self.detection_head)
if self.decoder is not None:
items.update(decoder=self.decoder)
if self._include_mask:
items.update(mask_head=self.mask_head)
return items
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for maskrcnn_model.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner
from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.ops import anchor
from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False, False,),
(False, True,),
(True, False,),
(True, True,),
)
def test_forward(self, use_gt_boxes_for_masks, training):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=image_size).multilevel_boxes
num_anchors_per_location = len(aspect_ratios) * num_scales
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=min_level,
max_level=max_level,
input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
model = maskrcnn_model.DeepMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
use_gt_boxes_for_masks=use_gt_boxes_for_masks)
gt_boxes = tf.zeros((2, 16, 4), dtype=tf.float32)
gt_masks = tf.zeros((2, 16, 32, 32))
gt_classes = tf.zeros((2, 16), dtype=tf.int32)
results = model(images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=training)
self.assertIn('rpn_boxes', results)
self.assertIn('rpn_scores', results)
if training:
self.assertIn('class_targets', results)
self.assertIn('box_targets', results)
self.assertIn('class_outputs', results)
self.assertIn('box_outputs', results)
self.assertIn('mask_outputs', results)
self.assertEqual(results['mask_targets'].shape,
results['mask_outputs'].shape)
else:
self.assertIn('detection_boxes', results)
self.assertIn('detection_scores', results)
self.assertIn('detection_classes', results)
self.assertIn('num_detections', results)
self.assertIn('detection_masks', results)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask R-CNN variant with support for deep mask heads."""
import tensorflow as tf
from official.core import task_factory
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner
from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deep_mask_head_rcnn_config
from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model as deep_maskrcnn_model
from official.vision.beta.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
from official.vision.beta.tasks import maskrcnn
# Taken from modeling/factory.py
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Mask R-CNN model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
rpn_head_config = model_config.rpn_head
roi_generator_config = model_config.roi_generator
roi_sampler_config = model_config.roi_sampler
roi_aligner_config = model_config.roi_aligner
detection_head_config = model_config.detection_head
generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
rpn_head = dense_prediction_heads.RPNHead(
min_level=model_config.min_level,
max_level=model_config.max_level,
num_anchors_per_location=num_anchors_per_location,
num_convs=rpn_head_config.num_convs,
num_filters=rpn_head_config.num_filters,
use_separable_conv=rpn_head_config.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
detection_head = instance_heads.DetectionHead(
num_classes=model_config.num_classes,
num_convs=detection_head_config.num_convs,
num_filters=detection_head_config.num_filters,
use_separable_conv=detection_head_config.use_separable_conv,
num_fcs=detection_head_config.num_fcs,
fc_dims=detection_head_config.fc_dims,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
roi_generator_obj = roi_generator.MultilevelROIGenerator(
pre_nms_top_k=roi_generator_config.pre_nms_top_k,
pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
pre_nms_min_size_threshold=(
roi_generator_config.pre_nms_min_size_threshold),
nms_iou_threshold=roi_generator_config.nms_iou_threshold,
num_proposals=roi_generator_config.num_proposals,
test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
test_pre_nms_score_threshold=(
roi_generator_config.test_pre_nms_score_threshold),
test_pre_nms_min_size_threshold=(
roi_generator_config.test_pre_nms_min_size_threshold),
test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
test_num_proposals=roi_generator_config.test_num_proposals,
use_batched_nms=roi_generator_config.use_batched_nms)
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
num_sampled_rois=roi_sampler_config.num_sampled_rois,
foreground_fraction=roi_sampler_config.foreground_fraction,
foreground_iou_threshold=roi_sampler_config.foreground_iou_threshold,
background_iou_high_threshold=(
roi_sampler_config.background_iou_high_threshold),
background_iou_low_threshold=(
roi_sampler_config.background_iou_low_threshold))
roi_aligner_obj = roi_aligner.MultilevelROIAligner(
crop_size=roi_aligner_config.crop_size,
sample_offset=roi_aligner_config.sample_offset)
detection_generator_obj = detection_generator.DetectionGenerator(
apply_nms=True,
pre_nms_top_k=generator_config.pre_nms_top_k,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections,
use_batched_nms=generator_config.use_batched_nms)
if model_config.include_mask:
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=model_config.num_classes,
upsample_factor=model_config.mask_head.upsample_factor,
num_convs=model_config.mask_head.num_convs,
num_filters=model_config.mask_head.num_filters,
use_separable_conv=model_config.mask_head.use_separable_conv,
activation=model_config.norm_activation.activation,
norm_momentum=model_config.norm_activation.norm_momentum,
norm_epsilon=model_config.norm_activation.norm_epsilon,
kernel_regularizer=l2_regularizer,
class_agnostic=model_config.mask_head.class_agnostic,
convnet_variant=model_config.mask_head.convnet_variant)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=(
model_config.mask_roi_aligner.crop_size *
model_config.mask_head.upsample_factor),
num_sampled_masks=model_config.mask_sampler.num_sampled_masks)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
crop_size=model_config.mask_roi_aligner.crop_size,
sample_offset=model_config.mask_roi_aligner.sample_offset)
else:
mask_head = None
mask_sampler_obj = None
mask_roi_aligner_obj = None
model = deep_maskrcnn_model.DeepMaskRCNNModel(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator_obj,
roi_sampler=roi_sampler_obj,
roi_aligner=roi_aligner_obj,
detection_generator=detection_generator_obj,
mask_head=mask_head,
mask_sampler=mask_sampler_obj,
mask_roi_aligner=mask_roi_aligner_obj,
use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
return model
@task_factory.register_task_cls(deep_mask_head_rcnn_config.DeepMaskHeadRCNNTask)
class DeepMaskHeadRCNNTask(maskrcnn.MaskRCNNTask):
"""Mask R-CNN with support for deep mask heads."""
def build_model(self):
"""Build Mask R-CNN model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = build_maskrcnn(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
from absl import logging
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.vision.beta.projects.deepmac_maskrcnn.common import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
logging.info('Training with task %s', task)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment