Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
469326b6
Commit
469326b6
authored
May 03, 2021
by
Xianzhi Du
Committed by
A. Unique TensorFlower
May 03, 2021
Browse files
Internal change
PiperOrigin-RevId: 371847148
parent
3d1f1135
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
212 additions
and
15 deletions
+212
-15
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_casrcnn_tpu.yaml
...gs/experiments/maskrcnn/coco_spinenet143_casrcnn_tpu.yaml
+57
-0
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_mrcnn_tpu.yaml
...figs/experiments/maskrcnn/coco_spinenet143_mrcnn_tpu.yaml
+0
-0
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_casrcnn_tpu.yaml
...igs/experiments/maskrcnn/coco_spinenet49_casrcnn_tpu.yaml
+56
-0
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_mrcnn_tpu.yaml
...nfigs/experiments/maskrcnn/coco_spinenet49_mrcnn_tpu.yaml
+0
-0
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_casrcnn_tpu.yaml
...igs/experiments/maskrcnn/coco_spinenet96_casrcnn_tpu.yaml
+56
-0
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_mrcnn_tpu.yaml
...nfigs/experiments/maskrcnn/coco_spinenet96_mrcnn_tpu.yaml
+0
-0
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+22
-1
official/vision/beta/modeling/maskrcnn_model.py
official/vision/beta/modeling/maskrcnn_model.py
+20
-13
official/vision/beta/modeling/maskrcnn_model_test.py
official/vision/beta/modeling/maskrcnn_model_test.py
+1
-1
No files found.
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_casrcnn_tpu.yaml
0 → 100644
View file @
469326b6
# Expect to reach: box mAP: 51.6%, mask mAP: 44.5% on COCO
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
bfloat16'
task
:
init_checkpoint
:
null
train_data
:
global_batch_size
:
256
parser
:
aug_rand_hflip
:
true
aug_scale_min
:
0.1
aug_scale_max
:
2.5
losses
:
l2_weight_decay
:
0.00004
model
:
anchor
:
anchor_size
:
4.0
num_scales
:
3
min_level
:
3
max_level
:
7
input_size
:
[
1280
,
1280
,
3
]
backbone
:
spinenet
:
stochastic_depth_drop_rate
:
0.2
model_id
:
'
143'
type
:
'
spinenet'
decoder
:
type
:
'
identity'
detection_head
:
cascade_class_ensemble
:
true
class_agnostic_bbox_pred
:
true
rpn_head
:
num_convs
:
2
num_filters
:
256
roi_sampler
:
cascade_iou_thresholds
:
[
0.7
]
foreground_iou_threshold
:
0.6
norm_activation
:
norm_epsilon
:
0.001
norm_momentum
:
0.99
use_sync_bn
:
true
activation
:
'
swish'
detection_generator
:
pre_nms_top_k
:
1000
trainer
:
train_steps
:
162050
optimizer_config
:
learning_rate
:
type
:
'
stepwise'
stepwise
:
boundaries
:
[
148160
,
157420
]
values
:
[
0.32
,
0.032
,
0.0032
]
warmup
:
type
:
'
linear'
linear
:
warmup_steps
:
2000
warmup_learning_rate
:
0.0067
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_tpu.yaml
→
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_
mrcnn_
tpu.yaml
View file @
469326b6
File moved
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_casrcnn_tpu.yaml
0 → 100644
View file @
469326b6
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
bfloat16'
task
:
init_checkpoint
:
null
train_data
:
global_batch_size
:
256
parser
:
aug_rand_hflip
:
true
aug_scale_min
:
0.1
aug_scale_max
:
2.0
losses
:
l2_weight_decay
:
0.00004
model
:
anchor
:
anchor_size
:
3.0
num_scales
:
3
min_level
:
3
max_level
:
7
input_size
:
[
640
,
640
,
3
]
backbone
:
spinenet
:
stochastic_depth_drop_rate
:
0.2
model_id
:
'
49'
type
:
'
spinenet'
decoder
:
type
:
'
identity'
detection_head
:
cascade_class_ensemble
:
true
class_agnostic_bbox_pred
:
true
rpn_head
:
num_convs
:
2
num_filters
:
256
roi_sampler
:
cascade_iou_thresholds
:
[
0.7
]
foreground_iou_threshold
:
0.6
norm_activation
:
norm_epsilon
:
0.001
norm_momentum
:
0.99
use_sync_bn
:
true
activation
:
'
swish'
detection_generator
:
pre_nms_top_k
:
1000
trainer
:
train_steps
:
231000
optimizer_config
:
learning_rate
:
type
:
'
stepwise'
stepwise
:
boundaries
:
[
219450
,
226380
]
values
:
[
0.28
,
0.028
,
0.0028
]
warmup
:
type
:
'
linear'
linear
:
warmup_steps
:
2000
warmup_learning_rate
:
0.0067
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_tpu.yaml
→
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_
mrcnn_
tpu.yaml
View file @
469326b6
File moved
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_casrcnn_tpu.yaml
0 → 100644
View file @
469326b6
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
bfloat16'
task
:
init_checkpoint
:
null
train_data
:
global_batch_size
:
256
parser
:
aug_rand_hflip
:
true
aug_scale_min
:
0.1
aug_scale_max
:
2.0
losses
:
l2_weight_decay
:
0.00004
model
:
anchor
:
anchor_size
:
3.0
num_scales
:
3
min_level
:
3
max_level
:
7
input_size
:
[
1024
,
1024
,
3
]
backbone
:
spinenet
:
stochastic_depth_drop_rate
:
0.2
model_id
:
'
96'
type
:
'
spinenet'
decoder
:
type
:
'
identity'
detection_head
:
cascade_class_ensemble
:
true
class_agnostic_bbox_pred
:
true
rpn_head
:
num_convs
:
2
num_filters
:
256
roi_sampler
:
cascade_iou_thresholds
:
[
0.7
]
foreground_iou_threshold
:
0.6
norm_activation
:
norm_epsilon
:
0.001
norm_momentum
:
0.99
use_sync_bn
:
true
activation
:
'
swish'
detection_generator
:
pre_nms_top_k
:
1000
trainer
:
train_steps
:
231000
optimizer_config
:
learning_rate
:
type
:
'
stepwise'
stepwise
:
boundaries
:
[
219450
,
226380
]
values
:
[
0.32
,
0.032
,
0.0032
]
warmup
:
type
:
'
linear'
linear
:
warmup_steps
:
2000
warmup_learning_rate
:
0.0067
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_tpu.yaml
→
official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_
mrcnn_
tpu.yaml
View file @
469326b6
File moved
official/vision/beta/modeling/factory.py
View file @
469326b6
...
@@ -114,7 +114,28 @@ def build_maskrcnn(
...
@@ -114,7 +114,28 @@ def build_maskrcnn(
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
,
name
=
'detection_head'
)
if
roi_sampler_config
.
cascade_iou_thresholds
:
detection_head_cascade
=
[
detection_head
]
for
cascade_num
in
range
(
len
(
roi_sampler_config
.
cascade_iou_thresholds
)):
detection_head
=
instance_heads
.
DetectionHead
(
num_classes
=
model_config
.
num_classes
,
num_convs
=
detection_head_config
.
num_convs
,
num_filters
=
detection_head_config
.
num_filters
,
use_separable_conv
=
detection_head_config
.
use_separable_conv
,
num_fcs
=
detection_head_config
.
num_fcs
,
fc_dims
=
detection_head_config
.
fc_dims
,
class_agnostic_bbox_pred
=
detection_head_config
.
class_agnostic_bbox_pred
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
,
name
=
'detection_head_{}'
.
format
(
cascade_num
+
1
))
detection_head_cascade
.
append
(
detection_head
)
detection_head
=
detection_head_cascade
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
(
roi_generator_obj
=
roi_generator
.
MultilevelROIGenerator
(
pre_nms_top_k
=
roi_generator_config
.
pre_nms_top_k
,
pre_nms_top_k
=
roi_generator_config
.
pre_nms_top_k
,
...
...
official/vision/beta/modeling/maskrcnn_model.py
View file @
469326b6
...
@@ -31,7 +31,8 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -31,7 +31,8 @@ class MaskRCNNModel(tf.keras.Model):
backbone
:
tf
.
keras
.
Model
,
backbone
:
tf
.
keras
.
Model
,
decoder
:
tf
.
keras
.
Model
,
decoder
:
tf
.
keras
.
Model
,
rpn_head
:
tf
.
keras
.
layers
.
Layer
,
rpn_head
:
tf
.
keras
.
layers
.
Layer
,
detection_head
:
tf
.
keras
.
layers
.
Layer
,
detection_head
:
Union
[
tf
.
keras
.
layers
.
Layer
,
List
[
tf
.
keras
.
layers
.
Layer
]],
roi_generator
:
tf
.
keras
.
layers
.
Layer
,
roi_generator
:
tf
.
keras
.
layers
.
Layer
,
roi_sampler
:
Union
[
tf
.
keras
.
layers
.
Layer
,
roi_sampler
:
Union
[
tf
.
keras
.
layers
.
Layer
,
List
[
tf
.
keras
.
layers
.
Layer
]],
List
[
tf
.
keras
.
layers
.
Layer
]],
...
@@ -54,7 +55,7 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -54,7 +55,7 @@ class MaskRCNNModel(tf.keras.Model):
backbone: `tf.keras.Model`, the backbone network.
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
rpn_head: the RPN head.
detection_head: the detection head.
detection_head: the detection head
or a list of heads
.
roi_generator: the ROI generator.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
detection heads.
...
@@ -104,6 +105,9 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -104,6 +105,9 @@ class MaskRCNNModel(tf.keras.Model):
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
decoder
=
decoder
self
.
decoder
=
decoder
self
.
rpn_head
=
rpn_head
self
.
rpn_head
=
rpn_head
if
not
isinstance
(
detection_head
,
(
list
,
tuple
)):
self
.
detection_head
=
[
detection_head
]
else
:
self
.
detection_head
=
detection_head
self
.
detection_head
=
detection_head
self
.
roi_generator
=
roi_generator
self
.
roi_generator
=
roi_generator
if
not
isinstance
(
roi_sampler
,
(
list
,
tuple
)):
if
not
isinstance
(
roi_sampler
,
(
list
,
tuple
)):
...
@@ -191,7 +195,7 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -191,7 +195,7 @@ class MaskRCNNModel(tf.keras.Model):
gt_classes
=
gt_classes
,
gt_classes
=
gt_classes
,
training
=
training
,
training
=
training
,
model_outputs
=
model_outputs
,
model_outputs
=
model_outputs
,
layer
_num
=
cascade_num
,
cascade
_num
=
cascade_num
,
regression_weights
=
regression_weights
)
regression_weights
=
regression_weights
)
all_class_outputs
.
append
(
class_outputs
)
all_class_outputs
.
append
(
class_outputs
)
...
@@ -266,7 +270,7 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -266,7 +270,7 @@ class MaskRCNNModel(tf.keras.Model):
return
model_outputs
return
model_outputs
def
_run_frcnn_head
(
self
,
features
,
rois
,
gt_boxes
,
gt_classes
,
training
,
def
_run_frcnn_head
(
self
,
features
,
rois
,
gt_boxes
,
gt_classes
,
training
,
model_outputs
,
layer
_num
,
regression_weights
):
model_outputs
,
cascade
_num
,
regression_weights
):
"""Runs the frcnn head that does both class and box prediction.
"""Runs the frcnn head that does both class and box prediction.
Args:
Args:
...
@@ -279,7 +283,7 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -279,7 +283,7 @@ class MaskRCNNModel(tf.keras.Model):
classes. It is padded with -1s to indicate the invalid classes.
classes. It is padded with -1s to indicate the invalid classes.
training: `bool`, if model is training or being evaluated.
training: `bool`, if model is training or being evaluated.
model_outputs: `dict`, used for storing outputs used for eval and losses.
model_outputs: `dict`, used for storing outputs used for eval and losses.
layer
_num: `int`, the current frcnn layer in the cascade.
cascade
_num: `int`, the current frcnn layer in the cascade.
regression_weights: `list`, weights used for l1 loss in bounding box
regression_weights: `list`, weights used for l1 loss in bounding box
regression.
regression.
...
@@ -305,7 +309,7 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -305,7 +309,7 @@ class MaskRCNNModel(tf.keras.Model):
if
training
and
gt_boxes
is
not
None
:
if
training
and
gt_boxes
is
not
None
:
rois
=
tf
.
stop_gradient
(
rois
)
rois
=
tf
.
stop_gradient
(
rois
)
current_roi_sampler
=
self
.
roi_sampler
[
layer
_num
]
current_roi_sampler
=
self
.
roi_sampler
[
cascade
_num
]
rois
,
matched_gt_boxes
,
matched_gt_classes
,
matched_gt_indices
=
(
rois
,
matched_gt_boxes
,
matched_gt_classes
,
matched_gt_indices
=
(
current_roi_sampler
(
rois
,
gt_boxes
,
gt_classes
))
current_roi_sampler
(
rois
,
gt_boxes
,
gt_classes
))
# Create bounding box training targets.
# Create bounding box training targets.
...
@@ -317,10 +321,11 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -317,10 +321,11 @@ class MaskRCNNModel(tf.keras.Model):
tf
.
expand_dims
(
tf
.
equal
(
matched_gt_classes
,
0
),
axis
=-
1
),
tf
.
expand_dims
(
tf
.
equal
(
matched_gt_classes
,
0
),
axis
=-
1
),
[
1
,
1
,
4
]),
tf
.
zeros_like
(
box_targets
),
box_targets
)
[
1
,
1
,
4
]),
tf
.
zeros_like
(
box_targets
),
box_targets
)
model_outputs
.
update
({
model_outputs
.
update
({
'class_targets_{}'
.
format
(
layer
_num
)
'class_targets_{}'
.
format
(
cascade
_num
)
if
layer
_num
else
'class_targets'
:
if
cascade
_num
else
'class_targets'
:
matched_gt_classes
,
matched_gt_classes
,
'box_targets_{}'
.
format
(
layer_num
)
if
layer_num
else
'box_targets'
:
'box_targets_{}'
.
format
(
cascade_num
)
if
cascade_num
else
'box_targets'
:
box_targets
,
box_targets
,
})
})
...
@@ -328,12 +333,14 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -328,12 +333,14 @@ class MaskRCNNModel(tf.keras.Model):
roi_features
=
self
.
roi_aligner
(
features
,
rois
)
roi_features
=
self
.
roi_aligner
(
features
,
rois
)
# Run frcnn head to get class and bbox predictions.
# Run frcnn head to get class and bbox predictions.
class_outputs
,
box_outputs
=
self
.
detection_head
(
roi_features
)
current_detection_head
=
self
.
detection_head
[
cascade_num
]
class_outputs
,
box_outputs
=
current_detection_head
(
roi_features
)
model_outputs
.
update
({
model_outputs
.
update
({
'class_outputs_{}'
.
format
(
layer_num
)
if
layer_num
else
'class_outputs'
:
'class_outputs_{}'
.
format
(
cascade_num
)
if
cascade_num
else
'class_outputs'
:
class_outputs
,
class_outputs
,
'box_outputs_{}'
.
format
(
layer
_num
)
if
layer
_num
else
'box_outputs'
:
'box_outputs_{}'
.
format
(
cascade
_num
)
if
cascade
_num
else
'box_outputs'
:
box_outputs
,
box_outputs
,
})
})
return
(
class_outputs
,
box_outputs
,
model_outputs
,
matched_gt_boxes
,
return
(
class_outputs
,
box_outputs
,
model_outputs
,
matched_gt_boxes
,
...
...
official/vision/beta/modeling/maskrcnn_model_test.py
View file @
469326b6
...
@@ -373,7 +373,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -373,7 +373,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
backbone
=
backbone
,
backbone
=
backbone
,
decoder
=
decoder
,
decoder
=
decoder
,
rpn_head
=
rpn_head
,
rpn_head
=
rpn_head
,
detection_head
=
detection_head
)
detection_head
=
[
detection_head
]
)
if
include_mask
:
if
include_mask
:
expect_checkpoint_items
[
'mask_head'
]
=
mask_head
expect_checkpoint_items
[
'mask_head'
]
=
mask_head
self
.
assertAllEqual
(
expect_checkpoint_items
,
model
.
checkpoint_items
)
self
.
assertAllEqual
(
expect_checkpoint_items
,
model
.
checkpoint_items
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment