Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
eeff8d87
Commit
eeff8d87
authored
Jun 30, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Jun 30, 2021
Browse files
Internal change
PiperOrigin-RevId: 382422745
parent
f24c95ae
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
155 additions
and
18 deletions
+155
-18
official/vision/beta/dataloaders/tfds_factory.py
official/vision/beta/dataloaders/tfds_factory.py
+71
-0
official/vision/beta/dataloaders/tfds_factory_test.py
official/vision/beta/dataloaders/tfds_factory_test.py
+78
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+2
-6
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+2
-6
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+2
-6
No files found.
official/vision/beta/dataloaders/tfds_factory.py
0 → 100644
View file @
eeff8d87
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFDS factory functions."""
from
official.vision.beta.dataloaders
import
decoder
as
base_decoder
from
official.vision.beta.dataloaders
import
tfds_detection_decoders
from
official.vision.beta.dataloaders
import
tfds_segmentation_decoders
from
official.vision.beta.dataloaders
import
tfds_classification_decoders
def
get_classification_decoder
(
tfds_name
:
str
)
->
base_decoder
.
Decoder
:
"""Gets classification decoder.
Args:
tfds_name: `str`, name of the tfds classification decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
[
tfds_name
]()
else
:
raise
ValueError
(
f
'TFDS Classification
{
tfds_name
}
is not supported'
)
return
decoder
def
get_detection_decoder
(
tfds_name
:
str
)
->
base_decoder
.
Decoder
:
"""Gets detection decoder.
Args:
tfds_name: `str`, name of the tfds detection decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if
tfds_name
in
tfds_detection_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_detection_decoders
.
TFDS_ID_TO_DECODER_MAP
[
tfds_name
]()
else
:
raise
ValueError
(
f
'TFDS Detection
{
tfds_name
}
is not supported'
)
return
decoder
def
get_segmentation_decoder
(
tfds_name
:
str
)
->
base_decoder
.
Decoder
:
"""Gets segmentation decoder.
Args:
tfds_name: `str`, name of the tfds segmentation decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if
tfds_name
in
tfds_segmentation_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_segmentation_decoders
.
TFDS_ID_TO_DECODER_MAP
[
tfds_name
]()
else
:
raise
ValueError
(
f
'TFDS Segmentation
{
tfds_name
}
is not supported'
)
return
decoder
official/vision/beta/dataloaders/tfds_factory_test.py
0 → 100644
View file @
eeff8d87
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfds factory functions."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
as
base_decoder
from
official.vision.beta.dataloaders
import
tfds_factory
class
TFDSFactoryTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
(
(
'imagenet2012'
),
(
'cifar10'
),
(
'cifar100'
),
)
def
test_classification_decoder
(
self
,
tfds_name
):
decoder
=
tfds_factory
.
get_classification_decoder
(
tfds_name
)
self
.
assertIsInstance
(
decoder
,
base_decoder
.
Decoder
)
@
parameterized
.
parameters
(
(
'flowers'
),
(
'coco'
),
)
def
test_doesnt_exit_classification_decoder
(
self
,
tfds_name
):
with
self
.
assertRaises
(
ValueError
):
_
=
tfds_factory
.
get_classification_decoder
(
tfds_name
)
@
parameterized
.
parameters
(
(
'coco'
),
(
'coco/2014'
),
(
'coco/2017'
),
)
def
test_detection_decoder
(
self
,
tfds_name
):
decoder
=
tfds_factory
.
get_detection_decoder
(
tfds_name
)
self
.
assertIsInstance
(
decoder
,
base_decoder
.
Decoder
)
@
parameterized
.
parameters
(
(
'pascal'
),
(
'cityscapes'
),
)
def
test_doesnt_exit_detection_decoder
(
self
,
tfds_name
):
with
self
.
assertRaises
(
ValueError
):
_
=
tfds_factory
.
get_detection_decoder
(
tfds_name
)
@
parameterized
.
parameters
(
(
'cityscapes'
),
(
'cityscapes/semantic_segmentation'
),
(
'cityscapes/semantic_segmentation_extra'
),
)
def
test_segmentation_decoder
(
self
,
tfds_name
):
decoder
=
tfds_factory
.
get_segmentation_decoder
(
tfds_name
)
self
.
assertIsInstance
(
decoder
,
base_decoder
.
Decoder
)
@
parameterized
.
parameters
(
(
'coco'
),
(
'imagenet'
),
)
def
test_doesnt_exit_segmentation_decoder
(
self
,
tfds_name
):
with
self
.
assertRaises
(
ValueError
):
_
=
tfds_factory
.
get_segmentation_decoder
(
tfds_name
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/tasks/image_classification.py
View file @
eeff8d87
...
@@ -24,7 +24,7 @@ from official.modeling import tf_utils
...
@@ -24,7 +24,7 @@ from official.modeling import tf_utils
from
official.vision.beta.configs
import
image_classification
as
exp_cfg
from
official.vision.beta.configs
import
image_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_
classification_decoders
from
official.vision.beta.dataloaders
import
tfds_
factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.modeling
import
factory
...
@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
...
@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
params
.
tfds_name
:
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_factory
.
get_classification_decoder
(
params
.
tfds_name
)
decoder
=
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
[
params
.
tfds_name
]()
else
:
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
else
:
decoder
=
classification_input
.
Decoder
(
decoder
=
classification_input
.
Decoder
(
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
...
...
official/vision/beta/tasks/retinanet.py
View file @
eeff8d87
...
@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
...
@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
retinanet_input
from
official.vision.beta.dataloaders
import
retinanet_input
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tfds_
detection_decoders
from
official.vision.beta.dataloaders
import
tfds_
factory
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.modeling
import
factory
...
@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
...
@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
"""Build input dataset."""
"""Build input dataset."""
if
params
.
tfds_name
:
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_detection_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_factory
.
get_detection_decoder
(
params
.
tfds_name
)
decoder
=
tfds_detection_decoders
.
TFDS_ID_TO_DECODER_MAP
[
params
.
tfds_name
]()
else
:
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
else
:
decoder_cfg
=
params
.
decoder
.
get
()
decoder_cfg
=
params
.
decoder
.
get
()
if
params
.
decoder
.
type
==
'simple_decoder'
:
if
params
.
decoder
.
type
==
'simple_decoder'
:
...
...
official/vision/beta/tasks/semantic_segmentation.py
View file @
eeff8d87
...
@@ -23,7 +23,7 @@ from official.core import task_factory
...
@@ -23,7 +23,7 @@ from official.core import task_factory
from
official.vision.beta.configs
import
semantic_segmentation
as
exp_cfg
from
official.vision.beta.configs
import
semantic_segmentation
as
exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
segmentation_input
from
official.vision.beta.dataloaders
import
segmentation_input
from
official.vision.beta.dataloaders
import
tfds_
segmentation_decoders
from
official.vision.beta.dataloaders
import
tfds_
factory
from
official.vision.beta.evaluation
import
segmentation_metrics
from
official.vision.beta.evaluation
import
segmentation_metrics
from
official.vision.beta.losses
import
segmentation_losses
from
official.vision.beta.losses
import
segmentation_losses
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.modeling
import
factory
...
@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
...
@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
ignore_label
=
self
.
task_config
.
losses
.
ignore_label
ignore_label
=
self
.
task_config
.
losses
.
ignore_label
if
params
.
tfds_name
:
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_segmentation_decoders
.
TFDS_ID_TO_DECODER_MAP
:
decoder
=
tfds_factory
.
get_segmentation_decoder
(
params
.
tfds_name
)
decoder
=
tfds_segmentation_decoders
.
TFDS_ID_TO_DECODER_MAP
[
params
.
tfds_name
]()
else
:
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
else
:
decoder
=
segmentation_input
.
Decoder
()
decoder
=
segmentation_input
.
Decoder
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment