Unverified Commit 440e0eec authored by Stephen Wu's avatar Stephen Wu Committed by GitHub
Browse files

Merge branch 'master' into RTESuperGLUE

parents 51364cdf 9815ea67
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDS detection decoders."""
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
class MSCOCODecoder(decoder.Decoder):
"""A tf.Example decoder for tfds coco datasets."""
def decode(self, serialized_example):
"""Decode the serialized example.
Args:
serialized_example: a dictonary example produced by tfds.
Returns:
decoded_tensors: a dictionary of tensors with the following fields:
- source_id: a string scalar tensor.
- image: a uint8 tensor of shape [None, None, 3].
- height: an integer scalar tensor.
- width: an integer scalar tensor.
- groundtruth_classes: a int64 tensor of shape [None].
- groundtruth_is_crowd: a bool tensor of shape [None].
- groundtruth_area: a float32 tensor of shape [None].
- groundtruth_boxes: a float32 tensor of shape [None, 4].
"""
decoded_tensors = {
'source_id': tf.strings.as_string(serialized_example['image/id']),
'image': serialized_example['image'],
'height': tf.cast(tf.shape(serialized_example['image'])[0], tf.int64),
'width': tf.cast(tf.shape(serialized_example['image'])[1], tf.int64),
'groundtruth_classes': serialized_example['objects']['label'],
'groundtruth_is_crowd': serialized_example['objects']['is_crowd'],
'groundtruth_area': tf.cast(
serialized_example['objects']['area'], tf.float32),
'groundtruth_boxes': serialized_example['objects']['bbox'],
}
return decoded_tensors
TFDS_ID_TO_DECODER_MAP = {
'coco/2017': MSCOCODecoder,
'coco/2014': MSCOCODecoder,
'coco': MSCOCODecoder
}
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDS Semantic Segmentation decoders."""
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
class CityScapesDecorder(decoder.Decoder):
"""A tf.Example decoder for tfds cityscapes datasets."""
def __init__(self):
# Original labels to trainable labels map, 255 is the ignore class.
self._label_map = {
-1: 255,
0: 255,
1: 255,
2: 255,
3: 255,
4: 255,
5: 255,
6: 255,
7: 0,
8: 1,
9: 255,
10: 255,
11: 2,
12: 3,
13: 4,
14: 255,
15: 255,
16: 255,
17: 5,
18: 255,
19: 6,
20: 7,
21: 8,
22: 9,
23: 10,
24: 11,
25: 12,
26: 13,
27: 14,
28: 15,
29: 255,
30: 255,
31: 16,
32: 17,
33: 18,
}
def decode(self, serialized_example):
# Convert labels according to the self._label_map
label = serialized_example['segmentation_label']
for original_label in self._label_map:
label = tf.where(label == original_label,
self._label_map[original_label] * tf.ones_like(label),
label)
sample_dict = {
'image/encoded':
tf.io.encode_jpeg(serialized_example['image_left'], quality=100),
'image/height': serialized_example['image_left'].shape[0],
'image/width': serialized_example['image_left'].shape[1],
'image/segmentation/class/encoded':
tf.io.encode_png(label),
}
return sample_dict
TFDS_ID_TO_DECODER_MAP = {
'cityscapes': CityScapesDecorder,
'cityscapes/semantic_segmentation': CityScapesDecorder,
'cityscapes/semantic_segmentation_extra': CityScapesDecorder,
}
...@@ -44,7 +44,7 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -44,7 +44,7 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
num_classes=num_classes, name=name, dtype=dtype) num_classes=num_classes, name=name, dtype=dtype)
def update_state(self, y_true, y_pred): def update_state(self, y_true, y_pred):
"""Updates metic state. """Updates metric state.
Args: Args:
y_true: `dict`, dictionary with the following name, and key values. y_true: `dict`, dictionary with the following name, and key values.
...@@ -122,4 +122,3 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -122,4 +122,3 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
super(MeanIoU, self).update_state( super(MeanIoU, self).update_state(
flatten_masks, flatten_predictions, flatten_masks, flatten_predictions,
tf.cast(flatten_valid_masks, tf.float32)) tf.cast(flatten_valid_masks, tf.float32))
...@@ -83,6 +83,7 @@ def build_video_classification_model( ...@@ -83,6 +83,7 @@ def build_video_classification_model(
num_classes: int, num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model.""" """Builds the video classification model."""
input_specs_dict = {'image': input_specs}
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, model_config=model_config,
...@@ -91,7 +92,7 @@ def build_video_classification_model( ...@@ -91,7 +92,7 @@ def build_video_classification_model(
model = video_classification_model.VideoClassificationModel( model = video_classification_model.VideoClassificationModel(
backbone=backbone, backbone=backbone,
num_classes=num_classes, num_classes=num_classes,
input_specs=input_specs, input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate, dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints, aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
......
...@@ -74,6 +74,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -74,6 +74,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
out_filters, out_filters,
se_ratio, se_ratio,
divisible_by=1, divisible_by=1,
use_3d_input=False,
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
...@@ -89,6 +90,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -89,6 +90,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
excitation layer. excitation layer.
divisible_by: `int` ensures all inner dimensions are divisible by this divisible_by: `int` ensures all inner dimensions are divisible by this
number. number.
use_3d_input: `bool` 2D image or 3D input type.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None. Default to None.
...@@ -105,15 +107,22 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -105,15 +107,22 @@ class SqueezeExcitation(tf.keras.layers.Layer):
self._out_filters = out_filters self._out_filters = out_filters
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._divisible_by = divisible_by self._divisible_by = divisible_by
self._use_3d_input = use_3d_input
self._activation = activation self._activation = activation
self._gating_activation = gating_activation self._gating_activation = gating_activation
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._spatial_axis = [1, 2] if not use_3d_input:
self._spatial_axis = [1, 2]
else:
self._spatial_axis = [1, 2, 3]
else: else:
self._spatial_axis = [2, 3] if not use_3d_input:
self._spatial_axis = [2, 3]
else:
self._spatial_axis = [2, 3, 4]
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
self._gating_activation_fn = tf_utils.get_activation(gating_activation) self._gating_activation_fn = tf_utils.get_activation(gating_activation)
...@@ -150,6 +159,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -150,6 +159,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
'out_filters': self._out_filters, 'out_filters': self._out_filters,
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'divisible_by': self._divisible_by, 'divisible_by': self._divisible_by,
'use_3d_input': self._use_3d_input,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Build video classification models.""" """Build video classification models."""
# Import libraries from typing import Mapping
import tensorflow as tf import tensorflow as tf
layers = tf.keras.layers layers = tf.keras.layers
...@@ -24,11 +24,11 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -24,11 +24,11 @@ class VideoClassificationModel(tf.keras.Model):
"""A video classification class builder.""" """A video classification class builder."""
def __init__(self, def __init__(self,
backbone, backbone: tf.keras.Model,
num_classes, num_classes: int,
input_specs=layers.InputSpec(shape=[None, None, None, None, 3]), input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
dropout_rate=0.0, dropout_rate: float = 0.0,
aggregate_endpoints=False, aggregate_endpoints: bool = False,
kernel_initializer='random_uniform', kernel_initializer='random_uniform',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
...@@ -49,6 +49,10 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -49,6 +49,10 @@ class VideoClassificationModel(tf.keras.Model):
None. None.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
if not input_specs:
input_specs = {
'image': layers.InputSpec(shape=[None, None, None, None, 3])
}
self._self_setattr_tracking = False self._self_setattr_tracking = False
self._config_dict = { self._config_dict = {
'backbone': backbone, 'backbone': backbone,
...@@ -65,8 +69,10 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -65,8 +69,10 @@ class VideoClassificationModel(tf.keras.Model):
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._backbone = backbone self._backbone = backbone
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = {
endpoints = backbone(inputs) k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}
endpoints = backbone(inputs['image'])
if aggregate_endpoints: if aggregate_endpoints:
pooled_feats = [] pooled_feats = []
......
...@@ -53,7 +53,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -53,7 +53,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
model = video_classification_model.VideoClassificationModel( model = video_classification_model.VideoClassificationModel(
backbone=backbone, backbone=backbone,
num_classes=num_classes, num_classes=num_classes,
input_specs=input_specs, input_specs={'image': input_specs},
dropout_rate=0.2, dropout_rate=0.2,
aggregate_endpoints=aggregate_endpoints, aggregate_endpoints=aggregate_endpoints,
) )
......
...@@ -55,7 +55,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -55,7 +55,7 @@ class DetectionModule(export_base.ExportModule):
return self._model return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds detection model inputs for serving."""
model_params = self._params.task.model model_params = self._params.task.model
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
...@@ -89,7 +89,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -89,7 +89,7 @@ class DetectionModule(export_base.ExportModule):
Args: Args:
images: uint8 Tensor of shape [batch_size, None, None, 3] images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns: Returns:
Tensor holding classification output logits. Tensor holding detection output logits.
""" """
model_params = self._params.task.model model_params = self._params.task.model
with tf.device('cpu:0'): with tf.device('cpu:0'):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Test for image classification export lib.""" """Test for image detection export lib."""
import io import io
import os import os
...@@ -41,7 +41,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -41,7 +41,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _export_from_module(self, module, input_type, batch_size, save_directory): def _export_from_module(self, module, input_type, batch_size, save_directory):
if input_type == 'image_tensor': if input_type == 'image_tensor':
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[batch_size, 640, 640, 3], dtype=tf.uint8) shape=[batch_size, None, None, 3], dtype=tf.uint8)
signatures = { signatures = {
'serving_default': 'serving_default':
module.inference_from_image_tensors.get_concrete_function( module.inference_from_image_tensors.get_concrete_function(
...@@ -68,18 +68,19 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -68,18 +68,19 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
save_directory, save_directory,
signatures=signatures) signatures=signatures)
def _get_dummy_input(self, input_type, batch_size): def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type.""" """Get dummy input for the given input type."""
h, w = image_size
if input_type == 'image_tensor': if input_type == 'image_tensor':
return tf.zeros((batch_size, 640, 640, 3), dtype=np.uint8) return tf.zeros((batch_size, h, w, 3), dtype=np.uint8)
elif input_type == 'image_bytes': elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((640, 640, 3), dtype=np.uint8)) image = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
byte_io = io.BytesIO() byte_io = io.BytesIO()
image.save(byte_io, 'PNG') image.save(byte_io, 'PNG')
return [byte_io.getvalue() for b in range(batch_size)] return [byte_io.getvalue() for b in range(batch_size)]
elif input_type == 'tf_example': elif input_type == 'tf_example':
image_tensor = tf.zeros((640, 640, 3), dtype=tf.uint8) image_tensor = tf.zeros((h, w, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy() encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example( example = tf.train.Example(
features=tf.train.Features( features=tf.train.Features(
...@@ -91,21 +92,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -91,21 +92,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
return [example for b in range(batch_size)] return [example for b in range(batch_size)]
@parameterized.parameters( @parameterized.parameters(
('image_tensor', 'fasterrcnn_resnetfpn_coco'), ('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
('image_bytes', 'fasterrcnn_resnetfpn_coco'), ('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tf_example', 'fasterrcnn_resnetfpn_coco'), ('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'maskrcnn_resnetfpn_coco'), ('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_bytes', 'maskrcnn_resnetfpn_coco'), ('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
('tf_example', 'maskrcnn_resnetfpn_coco'), ('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco'), ('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
('image_bytes', 'retinanet_resnetfpn_coco'), ('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
('tf_example', 'retinanet_resnetfpn_coco'), ('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
('tf_example', 'retinanet_spinenet_coco', [640, 384]),
) )
def test_export(self, input_type, experiment_name): def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
batch_size = 1 batch_size = 1
experiment_name = 'fasterrcnn_resnetfpn_coco'
module = self._get_detection_module(experiment_name) module = self._get_detection_module(experiment_name)
model = module.build_model() model = module.build_model()
...@@ -118,9 +121,9 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -118,9 +121,9 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001'))) os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir) imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default'] detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type, batch_size) images = self._get_dummy_input(input_type, batch_size, image_size)
processed_images, anchor_boxes, image_shape = module._build_inputs( processed_images, anchor_boxes, image_shape = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8)) tf.zeros((224, 224, 3), dtype=tf.uint8))
...@@ -134,7 +137,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -134,7 +137,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
image_shape=image_shape, image_shape=image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
outputs = classification_fn(tf.constant(images)) outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(), self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
......
...@@ -73,7 +73,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -73,7 +73,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
_decode_image, _decode_image,
elems=input_tensor, elems=input_tensor,
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.uint8), shape=[None, None, 3], dtype=tf.uint8),
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self._run_inference_on_image_tensors(images)
......
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
"""Image classification task definition.""" """Image classification task definition."""
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.configs import image_classification as exp_cfg from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.dataloaders import classification_input from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import dataset_fn from official.vision.beta.dataloaders import tfds_classification_decoders
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -67,7 +68,8 @@ class ImageClassificationTask(base_task.Task): ...@@ -67,7 +68,8 @@ class ImageClassificationTask(base_task.Task):
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -78,7 +80,15 @@ class ImageClassificationTask(base_task.Task): ...@@ -78,7 +80,15 @@ class ImageClassificationTask(base_task.Task):
num_classes = self.task_config.model.num_classes num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size input_size = self.task_config.model.input_size
decoder = classification_input.Decoder() if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder = classification_input.Decoder()
parser = classification_input.Parser( parser = classification_input.Parser(
output_size=input_size[:2], output_size=input_size[:2],
num_classes=num_classes, num_classes=num_classes,
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.vision.beta.configs import maskrcnn as exp_cfg from official.vision.beta.configs import maskrcnn as exp_cfg
from official.vision.beta.dataloaders import maskrcnn_input from official.vision.beta.dataloaders import maskrcnn_input
from official.vision.beta.dataloaders import tf_example_decoder from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import dataset_fn
from official.vision.beta.dataloaders import tf_example_label_map_decoder from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.losses import maskrcnn_losses from official.vision.beta.losses import maskrcnn_losses
...@@ -100,7 +100,8 @@ class MaskRCNNTask(base_task.Task): ...@@ -100,7 +100,8 @@ class MaskRCNNTask(base_task.Task):
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment