Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3f1ca33a
Commit
3f1ca33a
authored
Aug 16, 2022
by
A. Unique TensorFlower
Browse files
Add data augmentation strategies to mitigate overfitting in ViTs.
PiperOrigin-RevId: 468079502
parent
aa04c2de
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
52 additions
and
1 deletion
+52
-1
official/vision/configs/maskrcnn.py
official/vision/configs/maskrcnn.py
+2
-0
official/vision/dataloaders/maskrcnn_input.py
official/vision/dataloaders/maskrcnn_input.py
+31
-1
official/vision/ops/augment.py
official/vision/ops/augment.py
+17
-0
official/vision/ops/augment_test.py
official/vision/ops/augment_test.py
+1
-0
official/vision/tasks/maskrcnn.py
official/vision/tasks/maskrcnn.py
+1
-0
No files found.
official/vision/configs/maskrcnn.py
View file @
3f1ca33a
...
...
@@ -36,6 +36,8 @@ class Parser(hyperparams.Config):
aug_rand_hflip
:
bool
=
False
aug_scale_min
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
rpn_match_threshold
:
float
=
0.7
...
...
official/vision/dataloaders/maskrcnn_input.py
View file @
3f1ca33a
...
...
@@ -14,13 +14,16 @@
"""Data parser and processing for Mask R-CNN."""
# Import libraries
from
typing
import
Optional
# Import libraries
import
tensorflow
as
tf
from
official.vision.configs
import
common
from
official.vision.dataloaders
import
parser
from
official.vision.dataloaders
import
utils
from
official.vision.ops
import
anchor
from
official.vision.ops
import
augment
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
preprocess_ops
...
...
@@ -42,6 +45,7 @@ class Parser(parser.Parser):
aug_rand_hflip
=
False
,
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
skip_crowd_during_training
=
True
,
max_num_instances
=
100
,
include_mask
=
False
,
...
...
@@ -73,6 +77,9 @@ class Parser(parser.Parser):
data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training.
aug_type: An optional Augmentation object with params for AutoAugment.
The AutoAug policy should not use rotation/translation/shear.
Only in-place augmentations can be used.
skip_crowd_during_training: `bool`, if True, skip annotations labeled with
`is_crowd` equals to 1.
max_num_instances: `int` number of maximum number of instances in an
...
...
@@ -104,6 +111,26 @@ class Parser(parser.Parser):
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_max
=
aug_scale_max
if
aug_type
and
aug_type
.
type
:
if
aug_type
.
type
==
'autoaug'
:
self
.
_augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
elif
aug_type
.
type
==
'randaug'
:
self
.
_augmenter
=
augment
.
RandAugment
(
num_layers
=
aug_type
.
randaug
.
num_layers
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_type
.
type
))
else
:
self
.
_augmenter
=
None
# Mask.
self
.
_include_mask
=
include_mask
self
.
_mask_crop_size
=
mask_crop_size
...
...
@@ -167,6 +194,9 @@ class Parser(parser.Parser):
# Gets original image and its size.
image
=
data
[
'image'
]
if
self
.
_augmenter
is
not
None
:
image
=
self
.
_augmenter
.
distort
(
image
)
image_shape
=
tf
.
shape
(
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
...
...
official/vision/ops/augment.py
View file @
3f1ca33a
...
...
@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment):
'svhn'
:
self
.
policy_svhn
(),
'reduced_imagenet'
:
self
.
policy_reduced_imagenet
(),
'panoptic_deeplab_policy'
:
self
.
panoptic_deeplab_policy
(),
'vit'
:
self
.
vit
(),
}
if
not
policies
:
...
...
@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment):
[(
'Sharpness'
,
0.2
,
0.2
),
(
'Equalize'
,
0.2
,
1.4
)]]
return
policy
@
staticmethod
def
vit
():
"""Autoaugment policy for a generic ViT."""
policy
=
[
[(
'Sharpness'
,
0.4
,
1.4
),
(
'Brightness'
,
0.2
,
2.0
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Equalize'
,
0.0
,
1.8
),
(
'Contrast'
,
0.2
,
2.0
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Sharpness'
,
0.2
,
1.8
),
(
'Color'
,
0.2
,
1.8
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Solarize'
,
0.2
,
1.4
),
(
'Equalize'
,
0.6
,
1.8
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Sharpness'
,
0.2
,
0.2
),
(
'Equalize'
,
0.2
,
1.4
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Sharpness'
,
0.4
,
7
),
(
'Invert'
,
0.6
,
8
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Invert'
,
0.6
,
4
),
(
'Equalize'
,
1.0
,
8
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Posterize'
,
0.6
,
7
),
(
'Posterize'
,
0.6
,
6
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Solarize'
,
0.6
,
5
),
(
'AutoContrast'
,
0.6
,
5
),
(
'Cutout'
,
0.8
,
8
)],
]
return
policy
@
staticmethod
def
policy_test
():
"""Autoaugment test policy for debugging."""
...
...
official/vision/ops/augment_test.py
View file @
3f1ca33a
...
...
@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'svhn'
,
'reduced_imagenet'
,
'detection_v0'
,
'vit'
,
]
def
test_autoaugment
(
self
):
...
...
official/vision/tasks/maskrcnn.py
View file @
3f1ca33a
...
...
@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task):
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
aug_type
=
params
.
parser
.
aug_type
,
skip_crowd_during_training
=
params
.
parser
.
skip_crowd_during_training
,
max_num_instances
=
params
.
parser
.
max_num_instances
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment