Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7a5be20b
Commit
7a5be20b
authored
Jan 14, 2021
by
A. Unique TensorFlower
Browse files
Implement thresholding instance masks with a configurable value.
PiperOrigin-RevId: 351890707
parent
374cff58
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
5 deletions
+16
-5
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+2
-0
official/vision/beta/dataloaders/tf_example_decoder.py
official/vision/beta/dataloaders/tf_example_decoder.py
+6
-1
official/vision/beta/dataloaders/tf_example_label_map_decoder.py
...l/vision/beta/dataloaders/tf_example_label_map_decoder.py
+4
-2
official/vision/beta/tasks/maskrcnn.py
official/vision/beta/tasks/maskrcnn.py
+4
-2
No files found.
official/vision/beta/configs/maskrcnn.py
View file @
7a5be20b
...
@@ -31,11 +31,13 @@ from official.vision.beta.configs import decoders
...
@@ -31,11 +31,13 @@ from official.vision.beta.configs import decoders
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TfExampleDecoder
(
hyperparams
.
Config
):
class
TfExampleDecoder
(
hyperparams
.
Config
):
regenerate_source_id
:
bool
=
False
regenerate_source_id
:
bool
=
False
mask_binarize_threshold
:
Optional
[
float
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TfExampleDecoderLabelMap
(
hyperparams
.
Config
):
class
TfExampleDecoderLabelMap
(
hyperparams
.
Config
):
regenerate_source_id
:
bool
=
False
regenerate_source_id
:
bool
=
False
mask_binarize_threshold
:
Optional
[
float
]
=
None
label_map
:
str
=
''
label_map
:
str
=
''
...
...
official/vision/beta/dataloaders/tf_example_decoder.py
View file @
7a5be20b
...
@@ -34,7 +34,8 @@ class TfExampleDecoder(decoder.Decoder):
...
@@ -34,7 +34,8 @@ class TfExampleDecoder(decoder.Decoder):
def
__init__
(
self
,
def
__init__
(
self
,
include_mask
=
False
,
include_mask
=
False
,
regenerate_source_id
=
False
):
regenerate_source_id
=
False
,
mask_binarize_threshold
=
None
):
self
.
_include_mask
=
include_mask
self
.
_include_mask
=
include_mask
self
.
_regenerate_source_id
=
regenerate_source_id
self
.
_regenerate_source_id
=
regenerate_source_id
self
.
_keys_to_features
=
{
self
.
_keys_to_features
=
{
...
@@ -50,6 +51,7 @@ class TfExampleDecoder(decoder.Decoder):
...
@@ -50,6 +51,7 @@ class TfExampleDecoder(decoder.Decoder):
'image/object/area'
:
tf
.
io
.
VarLenFeature
(
tf
.
float32
),
'image/object/area'
:
tf
.
io
.
VarLenFeature
(
tf
.
float32
),
'image/object/is_crowd'
:
tf
.
io
.
VarLenFeature
(
tf
.
int64
),
'image/object/is_crowd'
:
tf
.
io
.
VarLenFeature
(
tf
.
int64
),
}
}
self
.
_mask_binarize_threshold
=
mask_binarize_threshold
if
include_mask
:
if
include_mask
:
self
.
_keys_to_features
.
update
({
self
.
_keys_to_features
.
update
({
'image/object/mask'
:
tf
.
io
.
VarLenFeature
(
tf
.
string
),
'image/object/mask'
:
tf
.
io
.
VarLenFeature
(
tf
.
string
),
...
@@ -151,6 +153,9 @@ class TfExampleDecoder(decoder.Decoder):
...
@@ -151,6 +153,9 @@ class TfExampleDecoder(decoder.Decoder):
if
self
.
_include_mask
:
if
self
.
_include_mask
:
masks
=
self
.
_decode_masks
(
parsed_tensors
)
masks
=
self
.
_decode_masks
(
parsed_tensors
)
if
self
.
_mask_binarize_threshold
is
not
None
:
masks
=
tf
.
cast
(
masks
>
self
.
_mask_binarize_threshold
,
tf
.
float32
)
decoded_tensors
=
{
decoded_tensors
=
{
'source_id'
:
source_id
,
'source_id'
:
source_id
,
'image'
:
image
,
'image'
:
image
,
...
...
official/vision/beta/dataloaders/tf_example_label_map_decoder.py
View file @
7a5be20b
...
@@ -27,9 +27,11 @@ from official.vision.beta.dataloaders import tf_example_decoder
...
@@ -27,9 +27,11 @@ from official.vision.beta.dataloaders import tf_example_decoder
class
TfExampleDecoderLabelMap
(
tf_example_decoder
.
TfExampleDecoder
):
class
TfExampleDecoderLabelMap
(
tf_example_decoder
.
TfExampleDecoder
):
"""Tensorflow Example proto decoder."""
"""Tensorflow Example proto decoder."""
def
__init__
(
self
,
label_map
,
include_mask
=
False
,
regenerate_source_id
=
False
):
def
__init__
(
self
,
label_map
,
include_mask
=
False
,
regenerate_source_id
=
False
,
mask_binarize_threshold
=
None
):
super
(
TfExampleDecoderLabelMap
,
self
).
__init__
(
super
(
TfExampleDecoderLabelMap
,
self
).
__init__
(
include_mask
=
include_mask
,
regenerate_source_id
=
regenerate_source_id
)
include_mask
=
include_mask
,
regenerate_source_id
=
regenerate_source_id
,
mask_binarize_threshold
=
mask_binarize_threshold
)
self
.
_keys_to_features
.
update
({
self
.
_keys_to_features
.
update
({
'image/object/class/text'
:
tf
.
io
.
VarLenFeature
(
tf
.
string
),
'image/object/class/text'
:
tf
.
io
.
VarLenFeature
(
tf
.
string
),
})
})
...
...
official/vision/beta/tasks/maskrcnn.py
View file @
7a5be20b
...
@@ -110,12 +110,14 @@ class MaskRCNNTask(base_task.Task):
...
@@ -110,12 +110,14 @@ class MaskRCNNTask(base_task.Task):
if
params
.
decoder
.
type
==
'simple_decoder'
:
if
params
.
decoder
.
type
==
'simple_decoder'
:
decoder
=
tf_example_decoder
.
TfExampleDecoder
(
decoder
=
tf_example_decoder
.
TfExampleDecoder
(
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
,
mask_binarize_threshold
=
decoder_cfg
.
mask_binarize_threshold
)
elif
params
.
decoder
.
type
==
'label_map_decoder'
:
elif
params
.
decoder
.
type
==
'label_map_decoder'
:
decoder
=
tf_example_label_map_decoder
.
TfExampleDecoderLabelMap
(
decoder
=
tf_example_label_map_decoder
.
TfExampleDecoderLabelMap
(
label_map
=
decoder_cfg
.
label_map
,
label_map
=
decoder_cfg
.
label_map
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
,
mask_binarize_threshold
=
decoder_cfg
.
mask_binarize_threshold
)
else
:
else
:
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment