Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8fa62b84
Commit
8fa62b84
authored
Apr 16, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Apr 16, 2021
Browse files
Internal change to image classification.
PiperOrigin-RevId: 368957441
parent
c2e19c97
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
112 additions
and
44 deletions
+112
-44
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+36
-1
official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml
...periments/image_classification/imagenet_resnet50_gpu.yaml
+4
-4
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i160.yaml
...ments/image_classification/imagenet_resnetrs101_i160.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i192.yaml
...ments/image_classification/imagenet_resnetrs101_i192.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i192.yaml
...ments/image_classification/imagenet_resnetrs152_i192.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i224.yaml
...ments/image_classification/imagenet_resnetrs152_i224.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i256.yaml
...ments/image_classification/imagenet_resnetrs152_i256.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs200_i256.yaml
...ments/image_classification/imagenet_resnetrs200_i256.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs270_i256.yaml
...ments/image_classification/imagenet_resnetrs270_i256.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml
...ments/image_classification/imagenet_resnetrs350_i256.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i320.yaml
...ments/image_classification/imagenet_resnetrs350_i320.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs420_i320.yaml
...ments/image_classification/imagenet_resnetrs420_i320.yaml
+4
-2
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs50_i160.yaml
...iments/image_classification/imagenet_resnetrs50_i160.yaml
+4
-2
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+9
-4
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+17
-11
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+2
-2
No files found.
official/vision/beta/configs/common.py
View file @
8fa62b84
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
# Lint as: python3
# Lint as: python3
"""Common configurations."""
"""Common configurations."""
from
typing
import
Optional
# Import libraries
# Import libraries
import
dataclasses
import
dataclasses
...
@@ -23,6 +24,37 @@ from official.core import config_definitions as cfg
...
@@ -23,6 +24,37 @@ from official.core import config_definitions as cfg
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
@
dataclasses
.
dataclass
class
RandAugment
(
hyperparams
.
Config
):
"""Configuration for RandAugment."""
num_layers
:
int
=
2
magnitude
:
float
=
10
cutout_const
:
float
=
40
translate_const
:
float
=
10
@
dataclasses
.
dataclass
class
AutoAugment
(
hyperparams
.
Config
):
"""Configuration for AutoAugment."""
augmentation_name
:
str
=
'v0'
cutout_const
:
float
=
100
translate_const
:
float
=
250
@
dataclasses
.
dataclass
class
Augmentation
(
hyperparams
.
OneOfConfig
):
"""Configuration for input data augmentation.
Attributes:
type: 'str', type of augmentation be used, one of the fields below.
randaug: RandAugment config.
autoaug: AutoAugment config.
"""
type
:
Optional
[
str
]
=
None
randaug
:
RandAugment
=
RandAugment
()
autoaug
:
AutoAugment
=
AutoAugment
()
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
NormActivation
(
hyperparams
.
Config
):
class
NormActivation
(
hyperparams
.
Config
):
activation
:
str
=
'relu'
activation
:
str
=
'relu'
...
@@ -35,5 +67,8 @@ class NormActivation(hyperparams.Config):
...
@@ -35,5 +67,8 @@ class NormActivation(hyperparams.Config):
class
PseudoLabelDataConfig
(
cfg
.
DataConfig
):
class
PseudoLabelDataConfig
(
cfg
.
DataConfig
):
"""Psuedo Label input config for training."""
"""Psuedo Label input config for training."""
input_path
:
str
=
''
input_path
:
str
=
''
data_ratio
:
float
=
1.0
# Per-batch ratio of pseudo-labeled to labeled data
data_ratio
:
float
=
1.0
# Per-batch ratio of pseudo-labeled to labeled data.
aug_rand_hflip
:
bool
=
True
aug_type
:
Optional
[
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
file_type
:
str
=
'tfrecord'
file_type
:
str
=
'tfrecord'
official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml
View file @
8fa62b84
...
@@ -12,19 +12,19 @@ task:
...
@@ -12,19 +12,19 @@ task:
model_id
:
50
model_id
:
50
losses
:
losses
:
l2_weight_decay
:
0.0001
l2_weight_decay
:
0.0001
one_hot
:
T
rue
one_hot
:
t
rue
label_smoothing
:
0.1
label_smoothing
:
0.1
train_data
:
train_data
:
input_path
:
'
imagenet-2012-tfrecord/train*'
input_path
:
'
imagenet-2012-tfrecord/train*'
is_training
:
T
rue
is_training
:
t
rue
global_batch_size
:
2048
global_batch_size
:
2048
dtype
:
'
float16'
dtype
:
'
float16'
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
F
alse
is_training
:
f
alse
global_batch_size
:
2048
global_batch_size
:
2048
dtype
:
'
float16'
dtype
:
'
float16'
drop_remainder
:
F
alse
drop_remainder
:
f
alse
trainer
:
trainer
:
train_steps
:
56160
train_steps
:
56160
validation_steps
:
25
validation_steps
:
25
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i160.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i192.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i192.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i224.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i256.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs200_i256.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs270_i256.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i320.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs420_i320.yaml
View file @
8fa62b84
...
@@ -28,8 +28,10 @@ task:
...
@@ -28,8 +28,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
15
type
:
'
randaug'
randaug
:
magnitude
:
15
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs50_i160.yaml
View file @
8fa62b84
...
@@ -29,8 +29,10 @@ task:
...
@@ -29,8 +29,10 @@ task:
is_training
:
true
is_training
:
true
global_batch_size
:
4096
global_batch_size
:
4096
dtype
:
'
bfloat16'
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
aug_type
:
randaug_magnitude
:
10
type
:
'
randaug'
randaug
:
magnitude
:
10
validation_data
:
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
is_training
:
false
...
...
official/vision/beta/configs/image_classification.py
View file @
8fa62b84
...
@@ -34,12 +34,17 @@ class DataConfig(cfg.DataConfig):
...
@@ -34,12 +34,17 @@ class DataConfig(cfg.DataConfig):
dtype
:
str
=
'float32'
dtype
:
str
=
'float32'
shuffle_buffer_size
:
int
=
10000
shuffle_buffer_size
:
int
=
10000
cycle_length
:
int
=
10
cycle_length
:
int
=
10
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'
aug_rand_hflip
:
bool
=
True
randaug_magnitude
:
Optional
[
int
]
=
10
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
file_type
:
str
=
'tfrecord'
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
label_field_key
:
str
=
'image/class/label'
# Keep for backward compatibility.
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'.
randaug_magnitude
:
Optional
[
int
]
=
10
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
ImageClassificationModel
(
hyperparams
.
Config
):
class
ImageClassificationModel
(
hyperparams
.
Config
):
...
@@ -198,8 +203,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
...
@@ -198,8 +203,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
is_training
=
True
,
global_batch_size
=
train_batch_size
,
global_batch_size
=
train_batch_size
,
aug_
policy
=
'randaug'
,
aug_
type
=
common
.
Augmentation
(
randaug_
magnitude
=
10
),
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
10
)
))
,
validation_data
=
DataConfig
(
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
is_training
=
False
,
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
8fa62b84
...
@@ -17,6 +17,7 @@ from typing import Dict, List, Optional
...
@@ -17,6 +17,7 @@ from typing import Dict, List, Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.configs
import
common
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
augment
...
@@ -52,8 +53,7 @@ class Parser(parser.Parser):
...
@@ -52,8 +53,7 @@ class Parser(parser.Parser):
image_field_key
:
str
=
'image/encoded'
,
image_field_key
:
str
=
'image/encoded'
,
label_field_key
:
str
=
'image/class/label'
,
label_field_key
:
str
=
'image/class/label'
,
aug_rand_hflip
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_policy
:
Optional
[
str
]
=
None
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
randaug_magnitude
:
Optional
[
int
]
=
10
,
dtype
:
str
=
'float32'
):
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
"""Initializes parameters for parsing annotations in the dataset.
...
@@ -65,8 +65,8 @@ class Parser(parser.Parser):
...
@@ -65,8 +65,8 @@ class Parser(parser.Parser):
label_field_key: A `str` of the key name to label in TFExample.
label_field_key: A `str` of the key name to label in TFExample.
aug_rand_hflip: `bool`, if True, augment training with random
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
horizontal flip.
aug_
policy: `str`, augmentation policies. None, 'a
uto
a
ug
', or 'randaug'.
aug_
type: An optional Augmentation object to choose from A
uto
A
ug
ment and
randaug_magnitude: `int`, magnitude of the r
and
a
ugment
policy
.
R
and
A
ugment.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
or 'bfloat16'.
"""
"""
...
@@ -84,15 +84,21 @@ class Parser(parser.Parser):
...
@@ -84,15 +84,21 @@ class Parser(parser.Parser):
self
.
_dtype
=
tf
.
bfloat16
self
.
_dtype
=
tf
.
bfloat16
else
:
else
:
raise
ValueError
(
'dtype {!r} is not supported!'
.
format
(
dtype
))
raise
ValueError
(
'dtype {!r} is not supported!'
.
format
(
dtype
))
if
aug_policy
:
if
aug_type
:
if
aug_policy
==
'autoaug'
:
if
aug_type
.
type
==
'autoaug'
:
self
.
_augmenter
=
augment
.
AutoAugment
()
self
.
_augmenter
=
augment
.
AutoAugment
(
elif
aug_policy
==
'randaug'
:
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
elif
aug_type
.
type
==
'randaug'
:
self
.
_augmenter
=
augment
.
RandAugment
(
self
.
_augmenter
=
augment
.
RandAugment
(
num_layers
=
2
,
magnitude
=
randaug_magnitude
)
num_layers
=
aug_type
.
randaug
.
num_layers
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
'Augmentation policy {} not supported.'
.
format
(
aug_policy
))
aug_type
.
type
))
else
:
else
:
self
.
_augmenter
=
None
self
.
_augmenter
=
None
...
...
official/vision/beta/tasks/image_classification.py
View file @
8fa62b84
...
@@ -100,8 +100,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -100,8 +100,8 @@ class ImageClassificationTask(base_task.Task):
num_classes
=
num_classes
,
num_classes
=
num_classes
,
image_field_key
=
image_field_key
,
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
label_field_key
=
label_field_key
,
aug_
policy
=
params
.
aug_
policy
,
aug_
rand_hflip
=
params
.
aug_
rand_hflip
,
randaug_magnitude
=
params
.
randaug_magnitud
e
,
aug_type
=
params
.
aug_typ
e
,
dtype
=
params
.
dtype
)
dtype
=
params
.
dtype
)
reader
=
input_reader_factory
.
input_reader_generator
(
reader
=
input_reader_factory
.
input_reader_generator
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment