Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0aa2134f
Commit
0aa2134f
authored
Aug 16, 2022
by
A. Unique TensorFlower
Browse files
Add data augmentation strategies to mitigate overfitting in ViTs.
PiperOrigin-RevId: 468079502
parent
19cbd05d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
52 additions
and
1 deletion
+52
-1
official/vision/configs/maskrcnn.py
official/vision/configs/maskrcnn.py
+2
-0
official/vision/dataloaders/maskrcnn_input.py
official/vision/dataloaders/maskrcnn_input.py
+31
-1
official/vision/ops/augment.py
official/vision/ops/augment.py
+17
-0
official/vision/ops/augment_test.py
official/vision/ops/augment_test.py
+1
-0
official/vision/tasks/maskrcnn.py
official/vision/tasks/maskrcnn.py
+1
-0
No files found.
official/vision/configs/maskrcnn.py
View file @
0aa2134f
...
@@ -36,6 +36,8 @@ class Parser(hyperparams.Config):
...
@@ -36,6 +36,8 @@ class Parser(hyperparams.Config):
aug_rand_hflip
:
bool
=
False
aug_rand_hflip
:
bool
=
False
aug_scale_min
:
float
=
1.0
aug_scale_min
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
skip_crowd_during_training
:
bool
=
True
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
max_num_instances
:
int
=
100
rpn_match_threshold
:
float
=
0.7
rpn_match_threshold
:
float
=
0.7
...
...
official/vision/dataloaders/maskrcnn_input.py
View file @
0aa2134f
...
@@ -14,13 +14,16 @@
...
@@ -14,13 +14,16 @@
"""Data parser and processing for Mask R-CNN."""
"""Data parser and processing for Mask R-CNN."""
# Import libraries
from
typing
import
Optional
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.configs
import
common
from
official.vision.dataloaders
import
parser
from
official.vision.dataloaders
import
parser
from
official.vision.dataloaders
import
utils
from
official.vision.dataloaders
import
utils
from
official.vision.ops
import
anchor
from
official.vision.ops
import
anchor
from
official.vision.ops
import
augment
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
preprocess_ops
from
official.vision.ops
import
preprocess_ops
...
@@ -42,6 +45,7 @@ class Parser(parser.Parser):
...
@@ -42,6 +45,7 @@ class Parser(parser.Parser):
aug_rand_hflip
=
False
,
aug_rand_hflip
=
False
,
aug_scale_min
=
1.0
,
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
,
aug_scale_max
=
1.0
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
skip_crowd_during_training
=
True
,
skip_crowd_during_training
=
True
,
max_num_instances
=
100
,
max_num_instances
=
100
,
include_mask
=
False
,
include_mask
=
False
,
...
@@ -73,6 +77,9 @@ class Parser(parser.Parser):
...
@@ -73,6 +77,9 @@ class Parser(parser.Parser):
data augmentation during training.
data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for
aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training.
data augmentation during training.
aug_type: An optional Augmentation object with params for AutoAugment or
RandAugment. The policy should not use rotation/translation/shear.
Only in-place augmentations can be used, since the augmenter is applied to the image only.
skip_crowd_during_training: `bool`, if True, skip annotations labeled with
skip_crowd_during_training: `bool`, if True, skip annotations labeled with
`is_crowd` equals to 1.
`is_crowd` equals to 1.
max_num_instances: `int` number of maximum number of instances in an
max_num_instances: `int` number of maximum number of instances in an
...
@@ -104,6 +111,26 @@ class Parser(parser.Parser):
...
@@ -104,6 +111,26 @@ class Parser(parser.Parser):
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_max
=
aug_scale_max
self
.
_aug_scale_max
=
aug_scale_max
if
aug_type
and
aug_type
.
type
:
if
aug_type
.
type
==
'autoaug'
:
self
.
_augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
elif
aug_type
.
type
==
'randaug'
:
self
.
_augmenter
=
augment
.
RandAugment
(
num_layers
=
aug_type
.
randaug
.
num_layers
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_type
.
type
))
else
:
self
.
_augmenter
=
None
# Mask.
# Mask.
self
.
_include_mask
=
include_mask
self
.
_include_mask
=
include_mask
self
.
_mask_crop_size
=
mask_crop_size
self
.
_mask_crop_size
=
mask_crop_size
...
@@ -167,6 +194,9 @@ class Parser(parser.Parser):
...
@@ -167,6 +194,9 @@ class Parser(parser.Parser):
# Gets original image and its size.
# Gets original image and its size.
image
=
data
[
'image'
]
image
=
data
[
'image'
]
if
self
.
_augmenter
is
not
None
:
image
=
self
.
_augmenter
.
distort
(
image
)
image_shape
=
tf
.
shape
(
image
)[
0
:
2
]
image_shape
=
tf
.
shape
(
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
# Normalizes image with mean and std pixel values.
...
...
official/vision/ops/augment.py
View file @
0aa2134f
...
@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment):
...
@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment):
'svhn'
:
self
.
policy_svhn
(),
'svhn'
:
self
.
policy_svhn
(),
'reduced_imagenet'
:
self
.
policy_reduced_imagenet
(),
'reduced_imagenet'
:
self
.
policy_reduced_imagenet
(),
'panoptic_deeplab_policy'
:
self
.
panoptic_deeplab_policy
(),
'panoptic_deeplab_policy'
:
self
.
panoptic_deeplab_policy
(),
'vit'
:
self
.
vit
(),
}
}
if
not
policies
:
if
not
policies
:
...
@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment):
...
@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment):
[(
'Sharpness'
,
0.2
,
0.2
),
(
'Equalize'
,
0.2
,
1.4
)]]
[(
'Sharpness'
,
0.2
,
0.2
),
(
'Equalize'
,
0.2
,
1.4
)]]
return
policy
return
policy
@staticmethod
def vit():
  """Autoaugment policy for a generic ViT.

  Returns:
    A list of nine sub-policies. Each sub-policy is a list of three
    `(op_name, number, number)` tuples: two leading ops followed by the same
    trailing `('Cutout', 0.8, 8)` entry.
  """
  # NOTE(review): the two numbers in each tuple are presumably
  # (probability, level), as in the other AutoAugment policies -- confirm
  # against the policy consumer.
  shared_cutout = ('Cutout', 0.8, 8)

  # Every sub-policy differs only in its first two ops, so only those pairs
  # are spelled out; the shared Cutout entry is appended when the list is built.
  leading_pairs = (
      (('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0)),
      (('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0)),
      (('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8)),
      (('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8)),
      (('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)),
      (('Sharpness', 0.4, 7), ('Invert', 0.6, 8)),
      (('Invert', 0.6, 4), ('Equalize', 1.0, 8)),
      (('Posterize', 0.6, 7), ('Posterize', 0.6, 6)),
      (('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)),
  )
  # A fresh outer list and fresh inner lists are created on every call.
  return [[first_op, second_op, shared_cutout]
          for first_op, second_op in leading_pairs]
@
staticmethod
@
staticmethod
def
policy_test
():
def
policy_test
():
"""Autoaugment test policy for debugging."""
"""Autoaugment test policy for debugging."""
...
...
official/vision/ops/augment_test.py
View file @
0aa2134f
...
@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'svhn'
,
'svhn'
,
'reduced_imagenet'
,
'reduced_imagenet'
,
'detection_v0'
,
'detection_v0'
,
'vit'
,
]
]
def
test_autoaugment
(
self
):
def
test_autoaugment
(
self
):
...
...
official/vision/tasks/maskrcnn.py
View file @
0aa2134f
...
@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task):
...
@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task):
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
aug_type
=
params
.
parser
.
aug_type
,
skip_crowd_during_training
=
params
.
parser
.
skip_crowd_during_training
,
skip_crowd_during_training
=
params
.
parser
.
skip_crowd_during_training
,
max_num_instances
=
params
.
parser
.
max_num_instances
,
max_num_instances
=
params
.
parser
.
max_num_instances
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment