Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
440e0eec
Unverified
Commit
440e0eec
authored
Feb 10, 2021
by
Stephen Wu
Committed by
GitHub
Feb 10, 2021
Browse files
Merge branch 'master' into RTESuperGLUE
parents
51364cdf
9815ea67
Changes
55
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
218 additions
and
42 deletions
+218
-42
official/vision/beta/dataloaders/tfds_detection_decoders.py
official/vision/beta/dataloaders/tfds_detection_decoders.py
+60
-0
official/vision/beta/dataloaders/tfds_segmentation_decoders.py
...ial/vision/beta/dataloaders/tfds_segmentation_decoders.py
+86
-0
official/vision/beta/evaluation/__init__.py
official/vision/beta/evaluation/__init__.py
+0
-0
official/vision/beta/evaluation/segmentation_metrics.py
official/vision/beta/evaluation/segmentation_metrics.py
+1
-2
official/vision/beta/losses/__init__.py
official/vision/beta/losses/__init__.py
+0
-0
official/vision/beta/modeling/__init__.py
official/vision/beta/modeling/__init__.py
+0
-0
official/vision/beta/modeling/factory_3d.py
official/vision/beta/modeling/factory_3d.py
+2
-1
official/vision/beta/modeling/heads/__init__.py
official/vision/beta/modeling/heads/__init__.py
+0
-0
official/vision/beta/modeling/layers/__init__.py
official/vision/beta/modeling/layers/__init__.py
+0
-0
official/vision/beta/modeling/layers/nn_layers.py
official/vision/beta/modeling/layers/nn_layers.py
+12
-2
official/vision/beta/modeling/video_classification_model.py
official/vision/beta/modeling/video_classification_model.py
+14
-8
official/vision/beta/modeling/video_classification_model_test.py
...l/vision/beta/modeling/video_classification_model_test.py
+1
-1
official/vision/beta/ops/__init__.py
official/vision/beta/ops/__init__.py
+0
-0
official/vision/beta/projects/__init__.py
official/vision/beta/projects/__init__.py
+0
-0
official/vision/beta/serving/__init__.py
official/vision/beta/serving/__init__.py
+0
-0
official/vision/beta/serving/detection.py
official/vision/beta/serving/detection.py
+2
-2
official/vision/beta/serving/detection_test.py
official/vision/beta/serving/detection_test.py
+23
-20
official/vision/beta/serving/export_base.py
official/vision/beta/serving/export_base.py
+1
-1
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+13
-3
official/vision/beta/tasks/maskrcnn.py
official/vision/beta/tasks/maskrcnn.py
+3
-2
No files found.
official/vision/beta/dataloaders/tfds_detection_decoders.py
0 → 100644
View file @
440e0eec
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDS detection decoders."""
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
class
MSCOCODecoder
(
decoder
.
Decoder
):
"""A tf.Example decoder for tfds coco datasets."""
def
decode
(
self
,
serialized_example
):
"""Decode the serialized example.
Args:
serialized_example: a dictonary example produced by tfds.
Returns:
decoded_tensors: a dictionary of tensors with the following fields:
- source_id: a string scalar tensor.
- image: a uint8 tensor of shape [None, None, 3].
- height: an integer scalar tensor.
- width: an integer scalar tensor.
- groundtruth_classes: a int64 tensor of shape [None].
- groundtruth_is_crowd: a bool tensor of shape [None].
- groundtruth_area: a float32 tensor of shape [None].
- groundtruth_boxes: a float32 tensor of shape [None, 4].
"""
decoded_tensors
=
{
'source_id'
:
tf
.
strings
.
as_string
(
serialized_example
[
'image/id'
]),
'image'
:
serialized_example
[
'image'
],
'height'
:
tf
.
cast
(
tf
.
shape
(
serialized_example
[
'image'
])[
0
],
tf
.
int64
),
'width'
:
tf
.
cast
(
tf
.
shape
(
serialized_example
[
'image'
])[
1
],
tf
.
int64
),
'groundtruth_classes'
:
serialized_example
[
'objects'
][
'label'
],
'groundtruth_is_crowd'
:
serialized_example
[
'objects'
][
'is_crowd'
],
'groundtruth_area'
:
tf
.
cast
(
serialized_example
[
'objects'
][
'area'
],
tf
.
float32
),
'groundtruth_boxes'
:
serialized_example
[
'objects'
][
'bbox'
],
}
return
decoded_tensors
TFDS_ID_TO_DECODER_MAP
=
{
'coco/2017'
:
MSCOCODecoder
,
'coco/2014'
:
MSCOCODecoder
,
'coco'
:
MSCOCODecoder
}
official/vision/beta/dataloaders/tfds_segmentation_decoders.py
0 → 100644
View file @
440e0eec
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDS Semantic Segmentation decoders."""
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
class
CityScapesDecorder
(
decoder
.
Decoder
):
"""A tf.Example decoder for tfds cityscapes datasets."""
def
__init__
(
self
):
# Original labels to trainable labels map, 255 is the ignore class.
self
.
_label_map
=
{
-
1
:
255
,
0
:
255
,
1
:
255
,
2
:
255
,
3
:
255
,
4
:
255
,
5
:
255
,
6
:
255
,
7
:
0
,
8
:
1
,
9
:
255
,
10
:
255
,
11
:
2
,
12
:
3
,
13
:
4
,
14
:
255
,
15
:
255
,
16
:
255
,
17
:
5
,
18
:
255
,
19
:
6
,
20
:
7
,
21
:
8
,
22
:
9
,
23
:
10
,
24
:
11
,
25
:
12
,
26
:
13
,
27
:
14
,
28
:
15
,
29
:
255
,
30
:
255
,
31
:
16
,
32
:
17
,
33
:
18
,
}
def
decode
(
self
,
serialized_example
):
# Convert labels according to the self._label_map
label
=
serialized_example
[
'segmentation_label'
]
for
original_label
in
self
.
_label_map
:
label
=
tf
.
where
(
label
==
original_label
,
self
.
_label_map
[
original_label
]
*
tf
.
ones_like
(
label
),
label
)
sample_dict
=
{
'image/encoded'
:
tf
.
io
.
encode_jpeg
(
serialized_example
[
'image_left'
],
quality
=
100
),
'image/height'
:
serialized_example
[
'image_left'
].
shape
[
0
],
'image/width'
:
serialized_example
[
'image_left'
].
shape
[
1
],
'image/segmentation/class/encoded'
:
tf
.
io
.
encode_png
(
label
),
}
return
sample_dict
TFDS_ID_TO_DECODER_MAP
=
{
'cityscapes'
:
CityScapesDecorder
,
'cityscapes/semantic_segmentation'
:
CityScapesDecorder
,
'cityscapes/semantic_segmentation_extra'
:
CityScapesDecorder
,
}
official/vision/beta/evaluation/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/evaluation/segmentation_metrics.py
View file @
440e0eec
...
...
@@ -44,7 +44,7 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
num_classes
=
num_classes
,
name
=
name
,
dtype
=
dtype
)
def
update_state
(
self
,
y_true
,
y_pred
):
"""Updates metic state.
"""Updates met
r
ic state.
Args:
y_true: `dict`, dictionary with the following name, and key values.
...
...
@@ -122,4 +122,3 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
super
(
MeanIoU
,
self
).
update_state
(
flatten_masks
,
flatten_predictions
,
tf
.
cast
(
flatten_valid_masks
,
tf
.
float32
))
official/vision/beta/losses/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/modeling/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/modeling/factory_3d.py
View file @
440e0eec
...
...
@@ -83,6 +83,7 @@ def build_video_classification_model(
num_classes
:
int
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds the video classification model."""
input_specs_dict
=
{
'image'
:
input_specs
}
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
model_config
=
model_config
,
...
...
@@ -91,7 +92,7 @@ def build_video_classification_model(
model
=
video_classification_model
.
VideoClassificationModel
(
backbone
=
backbone
,
num_classes
=
num_classes
,
input_specs
=
input_specs
,
input_specs
=
input_specs
_dict
,
dropout_rate
=
model_config
.
dropout_rate
,
aggregate_endpoints
=
model_config
.
aggregate_endpoints
,
kernel_regularizer
=
l2_regularizer
)
...
...
official/vision/beta/modeling/heads/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/modeling/layers/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/modeling/layers/nn_layers.py
View file @
440e0eec
...
...
@@ -74,6 +74,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
out_filters
,
se_ratio
,
divisible_by
=
1
,
use_3d_input
=
False
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
...
...
@@ -89,6 +90,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
excitation layer.
divisible_by: `int` ensures all inner dimensions are divisible by this
number.
use_3d_input: `bool` 2D image or 3D input type.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
...
...
@@ -105,15 +107,22 @@ class SqueezeExcitation(tf.keras.layers.Layer):
self
.
_out_filters
=
out_filters
self
.
_se_ratio
=
se_ratio
self
.
_divisible_by
=
divisible_by
self
.
_use_3d_input
=
use_3d_input
self
.
_activation
=
activation
self
.
_gating_activation
=
gating_activation
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_spatial_axis
=
[
1
,
2
]
if
not
use_3d_input
:
self
.
_spatial_axis
=
[
1
,
2
]
else
:
self
.
_spatial_axis
=
[
1
,
2
,
3
]
else
:
self
.
_spatial_axis
=
[
2
,
3
]
if
not
use_3d_input
:
self
.
_spatial_axis
=
[
2
,
3
]
else
:
self
.
_spatial_axis
=
[
2
,
3
,
4
]
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
self
.
_gating_activation_fn
=
tf_utils
.
get_activation
(
gating_activation
)
...
...
@@ -150,6 +159,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
'out_filters'
:
self
.
_out_filters
,
'se_ratio'
:
self
.
_se_ratio
,
'divisible_by'
:
self
.
_divisible_by
,
'use_3d_input'
:
self
.
_use_3d_input
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
...
...
official/vision/beta/modeling/video_classification_model.py
View file @
440e0eec
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Build video classification models."""
# Import libraries
from
typing
import
Mapping
import
tensorflow
as
tf
layers
=
tf
.
keras
.
layers
...
...
@@ -24,11 +24,11 @@ class VideoClassificationModel(tf.keras.Model):
"""A video classification class builder."""
def
__init__
(
self
,
backbone
,
num_classes
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
None
,
3
]),
dropout_rate
=
0.0
,
aggregate_endpoints
=
False
,
backbone
:
tf
.
keras
.
Model
,
num_classes
:
int
,
input_specs
:
Mapping
[
str
,
tf
.
keras
.
layers
.
InputSpec
]
=
None
,
dropout_rate
:
float
=
0.0
,
aggregate_endpoints
:
bool
=
False
,
kernel_initializer
=
'random_uniform'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
...
...
@@ -49,6 +49,10 @@ class VideoClassificationModel(tf.keras.Model):
None.
**kwargs: keyword arguments to be passed.
"""
if
not
input_specs
:
input_specs
=
{
'image'
:
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
None
,
3
])
}
self
.
_self_setattr_tracking
=
False
self
.
_config_dict
=
{
'backbone'
:
backbone
,
...
...
@@ -65,8 +69,10 @@ class VideoClassificationModel(tf.keras.Model):
self
.
_bias_regularizer
=
bias_regularizer
self
.
_backbone
=
backbone
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
endpoints
=
backbone
(
inputs
)
inputs
=
{
k
:
tf
.
keras
.
Input
(
shape
=
v
.
shape
[
1
:])
for
k
,
v
in
input_specs
.
items
()
}
endpoints
=
backbone
(
inputs
[
'image'
])
if
aggregate_endpoints
:
pooled_feats
=
[]
...
...
official/vision/beta/modeling/video_classification_model_test.py
View file @
440e0eec
...
...
@@ -53,7 +53,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
model
=
video_classification_model
.
VideoClassificationModel
(
backbone
=
backbone
,
num_classes
=
num_classes
,
input_specs
=
input_specs
,
input_specs
=
{
'image'
:
input_specs
}
,
dropout_rate
=
0.2
,
aggregate_endpoints
=
aggregate_endpoints
,
)
...
...
official/vision/beta/ops/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/projects/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/serving/__init__.py
0 → 100644
View file @
440e0eec
official/vision/beta/serving/detection.py
View file @
440e0eec
...
...
@@ -55,7 +55,7 @@ class DetectionModule(export_base.ExportModule):
return
self
.
_model
def
_build_inputs
(
self
,
image
):
"""Builds
classifica
tion model inputs for serving."""
"""Builds
detec
tion model inputs for serving."""
model_params
=
self
.
_params
.
task
.
model
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
...
...
@@ -89,7 +89,7 @@ class DetectionModule(export_base.ExportModule):
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding
classifica
tion output logits.
Tensor holding
detec
tion output logits.
"""
model_params
=
self
.
_params
.
task
.
model
with
tf
.
device
(
'cpu:0'
):
...
...
official/vision/beta/serving/detection_test.py
View file @
440e0eec
...
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for image
classifica
tion export lib."""
"""Test for image
detec
tion export lib."""
import
io
import
os
...
...
@@ -41,7 +41,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def
_export_from_module
(
self
,
module
,
input_type
,
batch_size
,
save_directory
):
if
input_type
==
'image_tensor'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
,
640
,
640
,
3
],
dtype
=
tf
.
uint8
)
shape
=
[
batch_size
,
None
,
None
,
3
],
dtype
=
tf
.
uint8
)
signatures
=
{
'serving_default'
:
module
.
inference_from_image_tensors
.
get_concrete_function
(
...
...
@@ -68,18 +68,19 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
save_directory
,
signatures
=
signatures
)
def
_get_dummy_input
(
self
,
input_type
,
batch_size
):
def
_get_dummy_input
(
self
,
input_type
,
batch_size
,
image_size
):
"""Get dummy input for the given input type."""
h
,
w
=
image_size
if
input_type
==
'image_tensor'
:
return
tf
.
zeros
((
batch_size
,
640
,
640
,
3
),
dtype
=
np
.
uint8
)
return
tf
.
zeros
((
batch_size
,
h
,
w
,
3
),
dtype
=
np
.
uint8
)
elif
input_type
==
'image_bytes'
:
image
=
Image
.
fromarray
(
np
.
zeros
((
640
,
640
,
3
),
dtype
=
np
.
uint8
))
image
=
Image
.
fromarray
(
np
.
zeros
((
h
,
w
,
3
),
dtype
=
np
.
uint8
))
byte_io
=
io
.
BytesIO
()
image
.
save
(
byte_io
,
'PNG'
)
return
[
byte_io
.
getvalue
()
for
b
in
range
(
batch_size
)]
elif
input_type
==
'tf_example'
:
image_tensor
=
tf
.
zeros
((
640
,
640
,
3
),
dtype
=
tf
.
uint8
)
image_tensor
=
tf
.
zeros
((
h
,
w
,
3
),
dtype
=
tf
.
uint8
)
encoded_jpeg
=
tf
.
image
.
encode_jpeg
(
tf
.
constant
(
image_tensor
)).
numpy
()
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
...
...
@@ -91,21 +92,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
return
[
example
for
b
in
range
(
batch_size
)]
@
parameterized
.
parameters
(
(
'image_tensor'
,
'fasterrcnn_resnetfpn_coco'
),
(
'image_bytes'
,
'fasterrcnn_resnetfpn_coco'
),
(
'tf_example'
,
'fasterrcnn_resnetfpn_coco'
),
(
'image_tensor'
,
'maskrcnn_resnetfpn_coco'
),
(
'image_bytes'
,
'maskrcnn_resnetfpn_coco'
),
(
'tf_example'
,
'maskrcnn_resnetfpn_coco'
),
(
'image_tensor'
,
'retinanet_resnetfpn_coco'
),
(
'image_bytes'
,
'retinanet_resnetfpn_coco'
),
(
'tf_example'
,
'retinanet_resnetfpn_coco'
),
(
'image_tensor'
,
'fasterrcnn_resnetfpn_coco'
,
[
384
,
384
]),
(
'image_bytes'
,
'fasterrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'fasterrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_tensor'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_bytes'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
384
]),
(
'tf_example'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_tensor'
,
'retinanet_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_bytes'
,
'retinanet_resnetfpn_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'retinanet_resnetfpn_coco'
,
[
384
,
640
]),
(
'image_tensor'
,
'retinanet_resnetfpn_coco'
,
[
384
,
384
]),
(
'image_bytes'
,
'retinanet_spinenet_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'retinanet_spinenet_coco'
,
[
640
,
384
]),
)
def
test_export
(
self
,
input_type
,
experiment_name
):
def
test_export
(
self
,
input_type
,
experiment_name
,
image_size
):
tmp_dir
=
self
.
get_temp_dir
()
batch_size
=
1
experiment_name
=
'fasterrcnn_resnetfpn_coco'
module
=
self
.
_get_detection_module
(
experiment_name
)
model
=
module
.
build_model
()
...
...
@@ -118,9 +121,9 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.data-00000-of-00001'
)))
imported
=
tf
.
saved_model
.
load
(
tmp_dir
)
classifica
tion_fn
=
imported
.
signatures
[
'serving_default'
]
detec
tion_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
,
batch_size
)
images
=
self
.
_get_dummy_input
(
input_type
,
batch_size
,
image_size
)
processed_images
,
anchor_boxes
,
image_shape
=
module
.
_build_inputs
(
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
))
...
...
@@ -134,7 +137,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
image_shape
=
image_shape
,
anchor_boxes
=
anchor_boxes
,
training
=
False
)
outputs
=
classifica
tion_fn
(
tf
.
constant
(
images
))
outputs
=
detec
tion_fn
(
tf
.
constant
(
images
))
self
.
assertAllClose
(
outputs
[
'num_detections'
].
numpy
(),
expected_outputs
[
'num_detections'
].
numpy
())
...
...
official/vision/beta/serving/export_base.py
View file @
440e0eec
...
...
@@ -73,7 +73,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
_decode_image
,
elems
=
input_tensor
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
uint8
),
shape
=
[
None
,
None
,
3
],
dtype
=
tf
.
uint8
),
parallel_iterations
=
32
))
images
=
tf
.
stack
(
images
)
return
self
.
_run_inference_on_image_tensors
(
images
)
...
...
official/vision/beta/tasks/image_classification.py
View file @
440e0eec
...
...
@@ -16,13 +16,14 @@
"""Image classification task definition."""
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.modeling
import
tf_utils
from
official.vision.beta.configs
import
image_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.dataloaders
import
dataset_fn
from
official.vision.beta.dataloaders
import
tfds_classification_decoders
from
official.vision.beta.modeling
import
factory
...
...
@@ -67,7 +68,8 @@ class ImageClassificationTask(base_task.Task):
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
assert
"Only 'all' or 'backbone' can be used to initialize the model."
raise
ValueError
(
"Only 'all' or 'backbone' can be used to initialize the model."
)
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
...
...
@@ -78,7 +80,15 @@ class ImageClassificationTask(base_task.Task):
num_classes
=
self
.
task_config
.
model
.
num_classes
input_size
=
self
.
task_config
.
model
.
input_size
decoder
=
classification_input
.
Decoder
()
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
[
params
.
tfds_name
]()
else
:
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
decoder
=
classification_input
.
Decoder
()
parser
=
classification_input
.
Parser
(
output_size
=
input_size
[:
2
],
num_classes
=
num_classes
,
...
...
official/vision/beta/tasks/maskrcnn.py
View file @
440e0eec
...
...
@@ -17,13 +17,13 @@
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.vision.beta.configs
import
maskrcnn
as
exp_cfg
from
official.vision.beta.dataloaders
import
maskrcnn_input
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
dataset_fn
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.losses
import
maskrcnn_losses
...
...
@@ -100,7 +100,8 @@ class MaskRCNNTask(base_task.Task):
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
assert
"Only 'all' or 'backbone' can be used to initialize the model."
raise
ValueError
(
"Only 'all' or 'backbone' can be used to initialize the model."
)
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment