Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8d9a16ce
Commit
8d9a16ce
authored
Dec 16, 2019
by
Yeqing Li
Committed by
A. Unique TensorFlower
Dec 16, 2019
Browse files
Internal change
PiperOrigin-RevId: 285844156
parent
913640d4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
26 additions
and
19 deletions
+26
-19
official/vision/detection/configs/base_config.py
official/vision/detection/configs/base_config.py
+4
-3
official/vision/detection/configs/maskrcnn_config.py
official/vision/detection/configs/maskrcnn_config.py
+1
-0
official/vision/detection/dataloader/maskrcnn_parser.py
official/vision/detection/dataloader/maskrcnn_parser.py
+17
-11
official/vision/detection/modeling/base_model.py
official/vision/detection/modeling/base_model.py
+1
-0
official/vision/detection/modeling/losses.py
official/vision/detection/modeling/losses.py
+2
-4
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+1
-1
No files found.
official/vision/detection/configs/base_config.py
View file @
8d9a16ce
...
@@ -30,10 +30,11 @@ REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
...
@@ -30,10 +30,11 @@ REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
BASE_CFG
=
{
BASE_CFG
=
{
'model_dir'
:
''
,
'model_dir'
:
''
,
'use_tpu'
:
True
,
'use_tpu'
:
True
,
'strategy_type'
:
'tpu'
,
'isolate_session_state'
:
False
,
'isolate_session_state'
:
False
,
'train'
:
{
'train'
:
{
'iterations_per_loop'
:
100
,
'iterations_per_loop'
:
100
,
'
train_
batch_size'
:
64
,
'batch_size'
:
64
,
'total_steps'
:
22500
,
'total_steps'
:
22500
,
'num_cores_per_replica'
:
None
,
'num_cores_per_replica'
:
None
,
'input_partition_dims'
:
None
,
'input_partition_dims'
:
None
,
...
@@ -57,13 +58,13 @@ BASE_CFG = {
...
@@ -57,13 +58,13 @@ BASE_CFG = {
'frozen_variable_prefix'
:
RESNET_FROZEN_VAR_PREFIX
,
'frozen_variable_prefix'
:
RESNET_FROZEN_VAR_PREFIX
,
'train_file_pattern'
:
''
,
'train_file_pattern'
:
''
,
'train_dataset_type'
:
'tfrecord'
,
'train_dataset_type'
:
'tfrecord'
,
'transpose_input'
:
Tru
e
,
'transpose_input'
:
Fals
e
,
'regularization_variable_regex'
:
REGULARIZATION_VAR_REGEX
,
'regularization_variable_regex'
:
REGULARIZATION_VAR_REGEX
,
'l2_weight_decay'
:
0.0001
,
'l2_weight_decay'
:
0.0001
,
'gradient_clip_norm'
:
0.0
,
'gradient_clip_norm'
:
0.0
,
},
},
'eval'
:
{
'eval'
:
{
'
eval_
batch_size'
:
8
,
'batch_size'
:
8
,
'eval_samples'
:
5000
,
'eval_samples'
:
5000
,
'min_eval_interval'
:
180
,
'min_eval_interval'
:
180
,
'eval_timeout'
:
None
,
'eval_timeout'
:
None
,
...
...
official/vision/detection/configs/maskrcnn_config.py
View file @
8d9a16ce
...
@@ -34,6 +34,7 @@ MASKRCNN_CFG.override({
...
@@ -34,6 +34,7 @@ MASKRCNN_CFG.override({
'maskrcnn_parser'
:
{
'maskrcnn_parser'
:
{
'use_bfloat16'
:
True
,
'use_bfloat16'
:
True
,
'output_size'
:
[
1024
,
1024
],
'output_size'
:
[
1024
,
1024
],
'num_channels'
:
3
,
'rpn_match_threshold'
:
0.7
,
'rpn_match_threshold'
:
0.7
,
'rpn_unmatched_threshold'
:
0.3
,
'rpn_unmatched_threshold'
:
0.3
,
'rpn_batch_size_per_im'
:
256
,
'rpn_batch_size_per_im'
:
256
,
...
...
official/vision/detection/dataloader/maskrcnn_parser.py
View file @
8d9a16ce
...
@@ -275,6 +275,10 @@ class Parser(object):
...
@@ -275,6 +275,10 @@ class Parser(object):
if
self
.
_use_bfloat16
:
if
self
.
_use_bfloat16
:
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
bfloat16
)
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
bfloat16
)
inputs
=
{
'image'
:
image
,
'image_info'
:
image_info
,
}
# Packs labels for model_fn outputs.
# Packs labels for model_fn outputs.
labels
=
{
labels
=
{
'anchor_boxes'
:
input_anchor
.
multilevel_boxes
,
'anchor_boxes'
:
input_anchor
.
multilevel_boxes
,
...
@@ -282,15 +286,16 @@ class Parser(object):
...
@@ -282,15 +286,16 @@ class Parser(object):
'rpn_score_targets'
:
rpn_score_targets
,
'rpn_score_targets'
:
rpn_score_targets
,
'rpn_box_targets'
:
rpn_box_targets
,
'rpn_box_targets'
:
rpn_box_targets
,
}
}
labels
[
'gt_boxes'
]
=
input_utils
.
pad_to_fixed_size
(
inputs
[
'gt_boxes'
]
=
input_utils
.
pad_to_fixed_size
(
boxes
,
boxes
,
self
.
_max_num_instances
,
-
1
)
self
.
_max_num_instances
,
labels
[
'gt_classes'
]
=
input_utils
.
pad_to_fixed_size
(
-
1
)
inputs
[
'gt_classes'
]
=
input_utils
.
pad_to_fixed_size
(
classes
,
self
.
_max_num_instances
,
-
1
)
classes
,
self
.
_max_num_instances
,
-
1
)
if
self
.
_include_mask
:
if
self
.
_include_mask
:
label
s
[
'gt_masks'
]
=
input_utils
.
pad_to_fixed_size
(
input
s
[
'gt_masks'
]
=
input_utils
.
pad_to_fixed_size
(
masks
,
self
.
_max_num_instances
,
-
1
)
masks
,
self
.
_max_num_instances
,
-
1
)
return
i
mage
,
labels
return
i
nputs
,
labels
def
_parse_eval_data
(
self
,
data
):
def
_parse_eval_data
(
self
,
data
):
"""Parses data for evaluation."""
"""Parses data for evaluation."""
...
@@ -348,11 +353,7 @@ class Parser(object):
...
@@ -348,11 +353,7 @@ class Parser(object):
self
.
_anchor_size
,
self
.
_anchor_size
,
(
image_height
,
image_width
))
(
image_height
,
image_width
))
labels
=
{
labels
=
{}
'source_id'
:
dataloader_utils
.
process_source_id
(
data
[
'source_id'
]),
'anchor_boxes'
:
input_anchor
.
multilevel_boxes
,
'image_info'
:
image_info
,
}
if
self
.
_mode
==
ModeKeys
.
PREDICT_WITH_GT
:
if
self
.
_mode
==
ModeKeys
.
PREDICT_WITH_GT
:
# Converts boxes from normalized coordinates to pixel coordinates.
# Converts boxes from normalized coordinates to pixel coordinates.
...
@@ -372,6 +373,11 @@ class Parser(object):
...
@@ -372,6 +373,11 @@ class Parser(object):
groundtruths
[
'source_id'
])
groundtruths
[
'source_id'
])
groundtruths
=
dataloader_utils
.
pad_groundtruths_to_fixed_size
(
groundtruths
=
dataloader_utils
.
pad_groundtruths_to_fixed_size
(
groundtruths
,
self
.
_max_num_instances
)
groundtruths
,
self
.
_max_num_instances
)
# TODO(yeqing): Remove the `groundtrtuh` layer key (no longer needed).
labels
[
'groundtruths'
]
=
groundtruths
labels
[
'groundtruths'
]
=
groundtruths
inputs
=
{
'image'
:
image
,
'image_info'
:
image_info
,
}
return
i
mage
,
labels
return
i
nputs
,
labels
official/vision/detection/modeling/base_model.py
View file @
8d9a16ce
...
@@ -99,6 +99,7 @@ class Model(object):
...
@@ -99,6 +99,7 @@ class Model(object):
params
.
train
.
learning_rate
)
params
.
train
.
learning_rate
)
self
.
_frozen_variable_prefix
=
params
.
train
.
frozen_variable_prefix
self
.
_frozen_variable_prefix
=
params
.
train
.
frozen_variable_prefix
self
.
_l2_weight_decay
=
params
.
train
.
l2_weight_decay
# Checkpoint restoration.
# Checkpoint restoration.
self
.
_checkpoint
=
params
.
train
.
checkpoint
.
as_dict
()
self
.
_checkpoint
=
params
.
train
.
checkpoint
.
as_dict
()
...
...
official/vision/detection/modeling/losses.py
View file @
8d9a16ce
...
@@ -147,6 +147,7 @@ class RpnBoxLoss(object):
...
@@ -147,6 +147,7 @@ class RpnBoxLoss(object):
"""Region Proposal Network box regression loss function."""
"""Region Proposal Network box regression loss function."""
def
__init__
(
self
,
params
):
def
__init__
(
self
,
params
):
self
.
_delta
=
params
.
huber_loss_delta
self
.
_huber_loss
=
tf
.
keras
.
losses
.
Huber
(
self
.
_huber_loss
=
tf
.
keras
.
losses
.
Huber
(
delta
=
params
.
huber_loss_delta
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
delta
=
params
.
huber_loss_delta
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
...
@@ -212,7 +213,7 @@ class FastrcnnClassLoss(object):
...
@@ -212,7 +213,7 @@ class FastrcnnClassLoss(object):
a scalar tensor representing total class loss.
a scalar tensor representing total class loss.
"""
"""
with
tf
.
name_scope
(
'fast_rcnn_loss'
):
with
tf
.
name_scope
(
'fast_rcnn_loss'
):
_
,
_
,
_
,
num_classes
=
class_outputs
.
get_shape
().
as_list
()
_
,
_
,
num_classes
=
class_outputs
.
get_shape
().
as_list
()
class_targets
=
tf
.
cast
(
class_targets
,
dtype
=
tf
.
int32
)
class_targets
=
tf
.
cast
(
class_targets
,
dtype
=
tf
.
int32
)
class_targets_one_hot
=
tf
.
one_hot
(
class_targets
,
num_classes
)
class_targets_one_hot
=
tf
.
one_hot
(
class_targets
,
num_classes
)
return
self
.
_fast_rcnn_class_loss
(
class_outputs
,
class_targets_one_hot
)
return
self
.
_fast_rcnn_class_loss
(
class_outputs
,
class_targets_one_hot
)
...
@@ -320,9 +321,6 @@ class FastrcnnBoxLoss(object):
...
@@ -320,9 +321,6 @@ class FastrcnnBoxLoss(object):
class
MaskrcnnLoss
(
object
):
class
MaskrcnnLoss
(
object
):
"""Mask R-CNN instance segmentation mask loss function."""
"""Mask R-CNN instance segmentation mask loss function."""
def
__init__
(
self
):
raise
ValueError
(
'Not TF 2.0 ready.'
)
def
__call__
(
self
,
mask_outputs
,
mask_targets
,
select_class_targets
):
def
__call__
(
self
,
mask_outputs
,
mask_targets
,
select_class_targets
):
"""Computes the mask loss of Mask-RCNN.
"""Computes the mask loss of Mask-RCNN.
...
...
official/vision/detection/modeling/retinanet_model.py
View file @
8d9a16ce
...
@@ -56,7 +56,6 @@ class RetinanetModel(base_model.Model):
...
@@ -56,7 +56,6 @@ class RetinanetModel(base_model.Model):
self
.
_generate_detections_fn
=
postprocess_ops
.
MultilevelDetectionGenerator
(
self
.
_generate_detections_fn
=
postprocess_ops
.
MultilevelDetectionGenerator
(
params
.
postprocess
)
params
.
postprocess
)
self
.
_l2_weight_decay
=
params
.
train
.
l2_weight_decay
self
.
_transpose_input
=
params
.
train
.
transpose_input
self
.
_transpose_input
=
params
.
train
.
transpose_input
assert
not
self
.
_transpose_input
,
'Transpose input is not supportted.'
assert
not
self
.
_transpose_input
,
'Transpose input is not supportted.'
# Input layer.
# Input layer.
...
@@ -134,6 +133,7 @@ class RetinanetModel(base_model.Model):
...
@@ -134,6 +133,7 @@ class RetinanetModel(base_model.Model):
return
self
.
_keras_model
return
self
.
_keras_model
def
post_processing
(
self
,
labels
,
outputs
):
def
post_processing
(
self
,
labels
,
outputs
):
# TODO(yeqing): Moves the output related part into build_outputs.
required_output_fields
=
[
'cls_outputs'
,
'box_outputs'
]
required_output_fields
=
[
'cls_outputs'
,
'box_outputs'
]
for
field
in
required_output_fields
:
for
field
in
required_output_fields
:
if
field
not
in
outputs
:
if
field
not
in
outputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment