Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8ea058b9
Commit
8ea058b9
authored
May 28, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 313693539
parent
980b27d5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
18 additions
and
6 deletions
+18
-6
official/vision/detection/configs/factory.py
official/vision/detection/configs/factory.py
+5
-1
official/vision/detection/dataloader/factory.py
official/vision/detection/dataloader/factory.py
+1
-0
official/vision/detection/dataloader/shapemask_parser.py
official/vision/detection/dataloader/shapemask_parser.py
+1
-0
official/vision/detection/main.py
official/vision/detection/main.py
+3
-3
official/vision/detection/modeling/factory.py
official/vision/detection/modeling/factory.py
+3
-0
official/vision/detection/modeling/losses.py
official/vision/detection/modeling/losses.py
+4
-2
official/vision/detection/ops/postprocess_ops.py
official/vision/detection/ops/postprocess_ops.py
+1
-0
No files found.
official/vision/detection/configs/factory.py
View file @
8ea058b9
...
@@ -14,9 +14,10 @@
...
@@ -14,9 +14,10 @@
# ==============================================================================
# ==============================================================================
"""Factory to provide model configs."""
"""Factory to provide model configs."""
from
official.modeling.hyperparams
import
params_dict
from
official.vision.detection.configs
import
maskrcnn_config
from
official.vision.detection.configs
import
maskrcnn_config
from
official.vision.detection.configs
import
retinanet_config
from
official.vision.detection.configs
import
retinanet_config
from
official.
modeling.hyperparams
import
params_dict
from
official.
vision.detection.configs
import
shapemask_config
def
config_generator
(
model
):
def
config_generator
(
model
):
...
@@ -27,6 +28,9 @@ def config_generator(model):
...
@@ -27,6 +28,9 @@ def config_generator(model):
elif
model
==
'mask_rcnn'
:
elif
model
==
'mask_rcnn'
:
default_config
=
maskrcnn_config
.
MASKRCNN_CFG
default_config
=
maskrcnn_config
.
MASKRCNN_CFG
restrictions
=
maskrcnn_config
.
MASKRCNN_RESTRICTIONS
restrictions
=
maskrcnn_config
.
MASKRCNN_RESTRICTIONS
elif
model
==
'shapemask'
:
default_config
=
shapemask_config
.
SHAPEMASK_CFG
restrictions
=
shapemask_config
.
SHAPEMASK_RESTRICTIONS
else
:
else
:
raise
ValueError
(
'Model %s is not supported.'
%
model
)
raise
ValueError
(
'Model %s is not supported.'
%
model
)
...
...
official/vision/detection/dataloader/factory.py
View file @
8ea058b9
...
@@ -22,6 +22,7 @@ from official.vision.detection.dataloader import maskrcnn_parser
...
@@ -22,6 +22,7 @@ from official.vision.detection.dataloader import maskrcnn_parser
from
official.vision.detection.dataloader
import
retinanet_parser
from
official.vision.detection.dataloader
import
retinanet_parser
from
official.vision.detection.dataloader
import
shapemask_parser
from
official.vision.detection.dataloader
import
shapemask_parser
def
parser_generator
(
params
,
mode
):
def
parser_generator
(
params
,
mode
):
"""Generator function for various dataset parser."""
"""Generator function for various dataset parser."""
if
params
.
architecture
.
parser
==
'retinanet_parser'
:
if
params
.
architecture
.
parser
==
'retinanet_parser'
:
...
...
official/vision/detection/dataloader/shapemask_parser.py
View file @
8ea058b9
...
@@ -419,6 +419,7 @@ class Parser(object):
...
@@ -419,6 +419,7 @@ class Parser(object):
inputs
=
{
inputs
=
{
'image'
:
image
,
'image'
:
image
,
'image_info'
:
image_info
,
'mask_boxes'
:
sampled_boxes
,
'mask_boxes'
:
sampled_boxes
,
'mask_outer_boxes'
:
mask_outer_boxes
,
'mask_outer_boxes'
:
mask_outer_boxes
,
'mask_classes'
:
sampled_classes
,
'mask_classes'
:
sampled_classes
,
...
...
official/vision/detection/main.py
View file @
8ea058b9
...
@@ -54,7 +54,7 @@ flags.DEFINE_string(
...
@@ -54,7 +54,7 @@ flags.DEFINE_string(
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'model'
,
default
=
'retinanet'
,
'model'
,
default
=
'retinanet'
,
help
=
'Model to run: `retinanet`
or
`mask_rcnn`.'
)
help
=
'Model to run: `retinanet`
,
`mask_rcnn`
or `shapemask`
.'
)
flags
.
DEFINE_string
(
'training_file_pattern'
,
None
,
flags
.
DEFINE_string
(
'training_file_pattern'
,
None
,
'Location of the train data.'
)
'Location of the train data.'
)
...
@@ -75,7 +75,7 @@ def run_executor(params,
...
@@ -75,7 +75,7 @@ def run_executor(params,
eval_input_fn
=
None
,
eval_input_fn
=
None
,
callbacks
=
None
,
callbacks
=
None
,
prebuilt_strategy
=
None
):
prebuilt_strategy
=
None
):
"""Runs
Retinanet
model on distribution strategy defined by the user."""
"""Runs
the object detection
model on distribution strategy defined by the user."""
if
params
.
architecture
.
use_bfloat16
:
if
params
.
architecture
.
use_bfloat16
:
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
...
@@ -203,7 +203,7 @@ def run(callbacks=None):
...
@@ -203,7 +203,7 @@ def run(callbacks=None):
params
.
lock
()
params
.
lock
()
pp
=
pprint
.
PrettyPrinter
()
pp
=
pprint
.
PrettyPrinter
()
params_str
=
pp
.
pformat
(
params
.
as_dict
())
params_str
=
pp
.
pformat
(
params
.
as_dict
())
logging
.
info
(
'Model Parameters:
{}'
.
format
(
params_str
)
)
logging
.
info
(
'Model Parameters:
%s'
,
params_str
)
train_input_fn
=
None
train_input_fn
=
None
eval_input_fn
=
None
eval_input_fn
=
None
...
...
official/vision/detection/modeling/factory.py
View file @
8ea058b9
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
from
official.vision.detection.modeling
import
maskrcnn_model
from
official.vision.detection.modeling
import
maskrcnn_model
from
official.vision.detection.modeling
import
retinanet_model
from
official.vision.detection.modeling
import
retinanet_model
from
official.vision.detection.modeling
import
shapemask_model
def
model_generator
(
params
):
def
model_generator
(
params
):
...
@@ -25,6 +26,8 @@ def model_generator(params):
...
@@ -25,6 +26,8 @@ def model_generator(params):
model_fn
=
retinanet_model
.
RetinanetModel
(
params
)
model_fn
=
retinanet_model
.
RetinanetModel
(
params
)
elif
params
.
type
==
'mask_rcnn'
:
elif
params
.
type
==
'mask_rcnn'
:
model_fn
=
maskrcnn_model
.
MaskrcnnModel
(
params
)
model_fn
=
maskrcnn_model
.
MaskrcnnModel
(
params
)
elif
params
.
type
==
'shapemask'
:
model_fn
=
shapemask_model
.
ShapeMaskModel
(
params
)
else
:
else
:
raise
ValueError
(
'Model %s is not supported.'
%
params
.
type
)
raise
ValueError
(
'Model %s is not supported.'
%
params
.
type
)
...
...
official/vision/detection/modeling/losses.py
View file @
8ea058b9
...
@@ -411,8 +411,10 @@ class RetinanetClassLoss(object):
...
@@ -411,8 +411,10 @@ class RetinanetClassLoss(object):
bs
,
height
,
width
,
_
,
_
=
cls_targets_one_hot
.
get_shape
().
as_list
()
bs
,
height
,
width
,
_
,
_
=
cls_targets_one_hot
.
get_shape
().
as_list
()
cls_targets_one_hot
=
tf
.
reshape
(
cls_targets_one_hot
,
cls_targets_one_hot
=
tf
.
reshape
(
cls_targets_one_hot
,
[
bs
,
height
,
width
,
-
1
])
[
bs
,
height
,
width
,
-
1
])
loss
=
focal_loss
(
cls_outputs
,
cls_targets_one_hot
,
loss
=
focal_loss
(
tf
.
cast
(
cls_outputs
,
dtype
=
tf
.
float32
),
self
.
_focal_loss_alpha
,
self
.
_focal_loss_gamma
,
tf
.
cast
(
cls_targets_one_hot
,
dtype
=
tf
.
float32
),
self
.
_focal_loss_alpha
,
self
.
_focal_loss_gamma
,
num_positives
)
num_positives
)
ignore_loss
=
tf
.
where
(
ignore_loss
=
tf
.
where
(
...
...
official/vision/detection/ops/postprocess_ops.py
View file @
8ea058b9
...
@@ -288,6 +288,7 @@ def _generate_detections_batched(boxes,
...
@@ -288,6 +288,7 @@ def _generate_detections_batched(boxes,
pad_per_class
=
False
,)
pad_per_class
=
False
,)
# De-normalizes box cooridinates.
# De-normalizes box cooridinates.
nmsed_boxes
*=
normalizer
nmsed_boxes
*=
normalizer
nmsed_classes
=
tf
.
cast
(
nmsed_classes
,
tf
.
int32
)
return
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
return
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment