Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b7af5b2d
Commit
b7af5b2d
authored
Apr 08, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Apr 08, 2021
Browse files
Internal change.
PiperOrigin-RevId: 367522154
parent
da8a5778
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
13 deletions
+32
-13
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+3
-1
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+4
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+19
-11
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+6
-1
No files found.
official/vision/beta/configs/common.py
View file @
b7af5b2d
...
@@ -16,8 +16,10 @@
...
@@ -16,8 +16,10 @@
"""Common configurations."""
"""Common configurations."""
# Import libraries
# Import libraries
import
dataclasses
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
...
@@ -30,7 +32,7 @@ class NormActivation(hyperparams.Config):
...
@@ -30,7 +32,7 @@ class NormActivation(hyperparams.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
PseudoLabelDataConfig
(
hyperparams
.
Config
):
class
PseudoLabelDataConfig
(
cfg
.
Data
Config
):
"""Psuedo Label input config for training."""
"""Psuedo Label input config for training."""
input_path
:
str
=
''
input_path
:
str
=
''
data_ratio
:
float
=
1.0
# Per-batch ratio of pseudo-labeled to labeled data
data_ratio
:
float
=
1.0
# Per-batch ratio of pseudo-labeled to labeled data
...
...
official/vision/beta/configs/image_classification.py
View file @
b7af5b2d
...
@@ -37,6 +37,8 @@ class DataConfig(cfg.DataConfig):
...
@@ -37,6 +37,8 @@ class DataConfig(cfg.DataConfig):
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'
randaug_magnitude
:
Optional
[
int
]
=
10
randaug_magnitude
:
Optional
[
int
]
=
10
file_type
:
str
=
'tfrecord'
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -75,6 +77,8 @@ class ImageClassificationTask(cfg.TaskConfig):
...
@@ -75,6 +77,8 @@ class ImageClassificationTask(cfg.TaskConfig):
evaluation
:
Evaluation
=
Evaluation
()
evaluation
:
Evaluation
=
Evaluation
()
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
str
=
'all'
# all or backbone
init_checkpoint_modules
:
str
=
'all'
# all or backbone
model_output_keys
:
Optional
[
List
[
int
]]
=
dataclasses
.
field
(
default_factory
=
list
)
@
exp_factory
.
register_config_factory
(
'image_classification'
)
@
exp_factory
.
register_config_factory
(
'image_classification'
)
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
b7af5b2d
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Classification decoder and parser."""
"""Classification decoder and parser."""
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -29,14 +29,16 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
...
@@ -29,14 +29,16 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class
Decoder
(
decoder
.
Decoder
):
class
Decoder
(
decoder
.
Decoder
):
"""A tf.Example decoder for classification task."""
"""A tf.Example decoder for classification task."""
def
__init__
(
self
):
def
__init__
(
self
,
image_field_key
:
str
=
'image/encoded'
,
label_field_key
:
str
=
'image/class/label'
):
self
.
_keys_to_features
=
{
self
.
_keys_to_features
=
{
'image/encoded'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
image_field_key
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
'image/class/label'
:
(
label_field_key
:
(
tf
.
io
.
FixedLenFeature
((),
tf
.
int64
,
default_value
=-
1
))
tf
.
io
.
FixedLenFeature
((),
tf
.
int64
,
default_value
=-
1
))
}
}
def
decode
(
self
,
serialized_example
):
def
decode
(
self
,
serialized_example
:
tf
.
train
.
Example
)
->
Dict
[
str
,
tf
.
Tensor
]:
return
tf
.
io
.
parse_single_example
(
return
tf
.
io
.
parse_single_example
(
serialized_example
,
self
.
_keys_to_features
)
serialized_example
,
self
.
_keys_to_features
)
...
@@ -47,6 +49,8 @@ class Parser(parser.Parser):
...
@@ -47,6 +49,8 @@ class Parser(parser.Parser):
def
__init__
(
self
,
def
__init__
(
self
,
output_size
:
List
[
int
],
output_size
:
List
[
int
],
num_classes
:
float
,
num_classes
:
float
,
image_field_key
:
str
=
'image/encoded'
,
label_field_key
:
str
=
'image/class/label'
,
aug_rand_hflip
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_policy
:
Optional
[
str
]
=
None
,
aug_policy
:
Optional
[
str
]
=
None
,
randaug_magnitude
:
Optional
[
int
]
=
10
,
randaug_magnitude
:
Optional
[
int
]
=
10
,
...
@@ -57,6 +61,8 @@ class Parser(parser.Parser):
...
@@ -57,6 +61,8 @@ class Parser(parser.Parser):
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
num_classes: `float`, number of classes.
image_field_key: A `str` of the key name to encoded image in TFExample.
label_field_key: A `str` of the key name to label in TFExample.
aug_rand_hflip: `bool`, if True, augment training with random
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
...
@@ -67,6 +73,9 @@ class Parser(parser.Parser):
...
@@ -67,6 +73,9 @@ class Parser(parser.Parser):
self
.
_output_size
=
output_size
self
.
_output_size
=
output_size
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_num_classes
=
num_classes
self
.
_num_classes
=
num_classes
self
.
_image_field_key
=
image_field_key
self
.
_label_field_key
=
label_field_key
if
dtype
==
'float32'
:
if
dtype
==
'float32'
:
self
.
_dtype
=
tf
.
float32
self
.
_dtype
=
tf
.
float32
elif
dtype
==
'float16'
:
elif
dtype
==
'float16'
:
...
@@ -89,9 +98,8 @@ class Parser(parser.Parser):
...
@@ -89,9 +98,8 @@ class Parser(parser.Parser):
def
_parse_train_data
(
self
,
decoded_tensors
):
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
"""Parses data for training."""
label
=
tf
.
cast
(
decoded_tensors
[
'image/class/label'
],
dtype
=
tf
.
int32
)
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_bytes
=
decoded_tensors
[
'image/encoded'
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Crops image.
# Crops image.
...
@@ -126,8 +134,8 @@ class Parser(parser.Parser):
...
@@ -126,8 +134,8 @@ class Parser(parser.Parser):
def
_parse_eval_data
(
self
,
decoded_tensors
):
def
_parse_eval_data
(
self
,
decoded_tensors
):
"""Parses data for evaluation."""
"""Parses data for evaluation."""
label
=
tf
.
cast
(
decoded_tensors
[
'image/class/label'
],
dtype
=
tf
.
int32
)
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
image_bytes
=
decoded_tensors
[
'image/encoded'
]
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Center crops and resizes image.
# Center crops and resizes image.
...
...
official/vision/beta/tasks/image_classification.py
View file @
b7af5b2d
...
@@ -80,6 +80,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -80,6 +80,8 @@ class ImageClassificationTask(base_task.Task):
num_classes
=
self
.
task_config
.
model
.
num_classes
num_classes
=
self
.
task_config
.
model
.
num_classes
input_size
=
self
.
task_config
.
model
.
input_size
input_size
=
self
.
task_config
.
model
.
input_size
image_field_key
=
self
.
task_config
.
train_data
.
image_field_key
label_field_key
=
self
.
task_config
.
train_data
.
label_field_key
if
params
.
tfds_name
:
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
...
@@ -88,11 +90,14 @@ class ImageClassificationTask(base_task.Task):
...
@@ -88,11 +90,14 @@ class ImageClassificationTask(base_task.Task):
else
:
else
:
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
else
:
decoder
=
classification_input
.
Decoder
()
decoder
=
classification_input
.
Decoder
(
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
)
parser
=
classification_input
.
Parser
(
parser
=
classification_input
.
Parser
(
output_size
=
input_size
[:
2
],
output_size
=
input_size
[:
2
],
num_classes
=
num_classes
,
num_classes
=
num_classes
,
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
aug_policy
=
params
.
aug_policy
,
aug_policy
=
params
.
aug_policy
,
randaug_magnitude
=
params
.
randaug_magnitude
,
randaug_magnitude
=
params
.
randaug_magnitude
,
dtype
=
params
.
dtype
)
dtype
=
params
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment