Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
26cfa492
Commit
26cfa492
authored
Oct 28, 2019
by
Yeqing Li
Committed by
A. Unique TensorFlower
Oct 28, 2019
Browse files
Update dataloaders.
PiperOrigin-RevId: 277104025
parent
d03efbf3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
2 deletions
+55
-2
official/vision/detection/dataloader/factory.py
official/vision/detection/dataloader/factory.py
+53
-1
official/vision/detection/utils/input_utils.py
official/vision/detection/utils/input_utils.py
+2
-1
No files found.
official/vision/detection/dataloader/factory.py
View file @
26cfa492
...
@@ -18,8 +18,9 @@ from __future__ import absolute_import
...
@@ -18,8 +18,9 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
official.vision.detection.dataloader
import
maskrcnn_parser
from
official.vision.detection.dataloader
import
retinanet_parser
from
official.vision.detection.dataloader
import
retinanet_parser
from
official.vision.detection.dataloader
import
shapemask_parser
def
parser_generator
(
params
,
mode
):
def
parser_generator
(
params
,
mode
):
"""Generator function for various dataset parser."""
"""Generator function for various dataset parser."""
...
@@ -44,6 +45,57 @@ def parser_generator(params, mode):
...
@@ -44,6 +45,57 @@ def parser_generator(params, mode):
max_num_instances
=
parser_params
.
max_num_instances
,
max_num_instances
=
parser_params
.
max_num_instances
,
use_bfloat16
=
parser_params
.
use_bfloat16
,
use_bfloat16
=
parser_params
.
use_bfloat16
,
mode
=
mode
)
mode
=
mode
)
elif
params
.
architecture
.
parser
==
'maskrcnn_parser'
:
anchor_params
=
params
.
anchor
parser_params
=
params
.
maskrcnn_parser
parser_fn
=
maskrcnn_parser
.
Parser
(
output_size
=
parser_params
.
output_size
,
min_level
=
anchor_params
.
min_level
,
max_level
=
anchor_params
.
max_level
,
num_scales
=
anchor_params
.
num_scales
,
aspect_ratios
=
anchor_params
.
aspect_ratios
,
anchor_size
=
anchor_params
.
anchor_size
,
rpn_match_threshold
=
parser_params
.
rpn_match_threshold
,
rpn_unmatched_threshold
=
parser_params
.
rpn_unmatched_threshold
,
rpn_batch_size_per_im
=
parser_params
.
rpn_batch_size_per_im
,
rpn_fg_fraction
=
parser_params
.
rpn_fg_fraction
,
aug_rand_hflip
=
parser_params
.
aug_rand_hflip
,
aug_scale_min
=
parser_params
.
aug_scale_min
,
aug_scale_max
=
parser_params
.
aug_scale_max
,
skip_crowd_during_training
=
parser_params
.
skip_crowd_during_training
,
max_num_instances
=
parser_params
.
max_num_instances
,
include_mask
=
parser_params
.
include_mask
,
mask_crop_size
=
parser_params
.
mask_crop_size
,
use_bfloat16
=
parser_params
.
use_bfloat16
,
mode
=
mode
)
elif
params
.
architecture
.
parser
==
'shapemask_parser'
:
anchor_params
=
params
.
anchor
parser_params
=
params
.
shapemask_parser
parser_fn
=
shapemask_parser
.
Parser
(
output_size
=
parser_params
.
output_size
,
min_level
=
anchor_params
.
min_level
,
max_level
=
anchor_params
.
max_level
,
num_scales
=
anchor_params
.
num_scales
,
aspect_ratios
=
anchor_params
.
aspect_ratios
,
anchor_size
=
anchor_params
.
anchor_size
,
use_category
=
parser_params
.
use_category
,
outer_box_scale
=
parser_params
.
outer_box_scale
,
box_jitter_scale
=
parser_params
.
box_jitter_scale
,
num_sampled_masks
=
parser_params
.
num_sampled_masks
,
mask_crop_size
=
parser_params
.
mask_crop_size
,
mask_min_level
=
parser_params
.
mask_min_level
,
mask_max_level
=
parser_params
.
mask_max_level
,
upsample_factor
=
parser_params
.
upsample_factor
,
match_threshold
=
parser_params
.
match_threshold
,
unmatched_threshold
=
parser_params
.
unmatched_threshold
,
aug_rand_hflip
=
parser_params
.
aug_rand_hflip
,
aug_scale_min
=
parser_params
.
aug_scale_min
,
aug_scale_max
=
parser_params
.
aug_scale_max
,
skip_crowd_during_training
=
parser_params
.
skip_crowd_during_training
,
max_num_instances
=
parser_params
.
max_num_instances
,
use_bfloat16
=
parser_params
.
use_bfloat16
,
mask_train_class
=
parser_params
.
mask_train_class
,
mode
=
mode
)
else
:
else
:
raise
ValueError
(
'Parser %s is not supported.'
%
params
.
architecture
.
parser
)
raise
ValueError
(
'Parser %s is not supported.'
%
params
.
architecture
.
parser
)
...
...
official/vision/detection/utils/input_utils.py
View file @
26cfa492
...
@@ -346,7 +346,8 @@ def resize_and_crop_masks(masks,
...
@@ -346,7 +346,8 @@ def resize_and_crop_masks(masks,
masks: `Tensor` of shape [N, H, W, 1] representing the scaled masks.
masks: `Tensor` of shape [N, H, W, 1] representing the scaled masks.
"""
"""
mask_size
=
tf
.
shape
(
input
=
masks
)[
1
:
3
]
mask_size
=
tf
.
shape
(
input
=
masks
)[
1
:
3
]
scaled_size
=
tf
.
cast
(
image_scale
*
mask_size
,
tf
.
int32
)
scaled_size
=
tf
.
cast
(
image_scale
*
tf
.
cast
(
mask_size
,
image_scale
.
dtype
),
tf
.
int32
)
scaled_masks
=
tf
.
image
.
resize
(
scaled_masks
=
tf
.
image
.
resize
(
masks
,
scaled_size
,
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
masks
,
scaled_size
,
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
offset
=
tf
.
cast
(
offset
,
tf
.
int32
)
offset
=
tf
.
cast
(
offset
,
tf
.
int32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment