Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5746b95d
Commit
5746b95d
authored
Oct 27, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Oct 27, 2020
Browse files
Internal change
PiperOrigin-RevId: 339275945
parent
b7cc196e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
16 deletions
+20
-16
official/vision/beta/configs/semantic_segmentation.py
official/vision/beta/configs/semantic_segmentation.py
+8
-8
official/vision/beta/configs/semantic_segmentation_test.py
official/vision/beta/configs/semantic_segmentation_test.py
+2
-2
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+1
-1
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+9
-5
No files found.
official/vision/beta/configs/semantic_segmentation.py
View file @
5746b95d
...
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Image
segmentation configuration definition."""
"""
Semantic
segmentation configuration definition."""
import
os
from
typing
import
List
,
Union
,
Optional
import
dataclasses
...
...
@@ -50,8 +50,8 @@ class SegmentationHead(hyperparams.Config):
@
dataclasses
.
dataclass
class
Image
SegmentationModel
(
hyperparams
.
Config
):
"""
Image
segmentation model config."""
class
Semantic
SegmentationModel
(
hyperparams
.
Config
):
"""
Semantic
segmentation model config."""
num_classes
:
int
=
0
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
min_level
:
int
=
3
...
...
@@ -73,9 +73,9 @@ class Losses(hyperparams.Config):
@
dataclasses
.
dataclass
class
Image
SegmentationTask
(
cfg
.
TaskConfig
):
class
Semantic
SegmentationTask
(
cfg
.
TaskConfig
):
"""The model config."""
model
:
Image
SegmentationModel
=
Image
SegmentationModel
()
model
:
Semantic
SegmentationModel
=
Semantic
SegmentationModel
()
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
losses
:
Losses
=
Losses
()
...
...
@@ -89,7 +89,7 @@ class ImageSegmentationTask(cfg.TaskConfig):
def
semantic_segmentation
()
->
cfg
.
ExperimentConfig
:
"""Semantic segmentation general."""
return
cfg
.
ExperimentConfig
(
task
=
Image
SegmentationModel
(),
task
=
Semantic
SegmentationModel
(),
trainer
=
cfg
.
TrainerConfig
(),
restrictions
=
[
'task.train_data.is_training != None'
,
...
...
@@ -109,8 +109,8 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
eval_batch_size
=
8
steps_per_epoch
=
PASCAL_TRAIN_EXAMPLES
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
Image
SegmentationTask
(
model
=
Image
SegmentationModel
(
task
=
Semantic
SegmentationTask
(
model
=
Semantic
SegmentationModel
(
num_classes
=
21
,
# TODO(arashwan): test changing size to 513 to match deeplab.
input_size
=
[
512
,
512
,
3
],
...
...
official/vision/beta/configs/semantic_segmentation_test.py
View file @
5746b95d
...
...
@@ -31,9 +31,9 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
def
test_semantic_segmentation_configs
(
self
,
config_name
):
config
=
exp_factory
.
get_exp_config
(
config_name
)
self
.
assertIsInstance
(
config
,
cfg
.
ExperimentConfig
)
self
.
assertIsInstance
(
config
.
task
,
exp_cfg
.
Image
SegmentationTask
)
self
.
assertIsInstance
(
config
.
task
,
exp_cfg
.
Semantic
SegmentationTask
)
self
.
assertIsInstance
(
config
.
task
.
model
,
exp_cfg
.
Image
SegmentationModel
)
exp_cfg
.
Semantic
SegmentationModel
)
self
.
assertIsInstance
(
config
.
task
.
train_data
,
exp_cfg
.
DataConfig
)
config
.
task
.
train_data
.
is_training
=
None
with
self
.
assertRaises
(
KeyError
):
...
...
official/vision/beta/modeling/factory.py
View file @
5746b95d
...
...
@@ -241,7 +241,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
def
build_segmentation_model
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
segmentation_cfg
.
Image
SegmentationModel
,
model_config
:
segmentation_cfg
.
Semantic
SegmentationModel
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds Segmentation model."""
backbone
=
backbones
.
factory
.
build_backbone
(
...
...
official/vision/beta/tasks/semantic_segmentation.py
View file @
5746b95d
...
...
@@ -28,9 +28,9 @@ from official.vision.beta.losses import segmentation_losses
from
official.vision.beta.modeling
import
factory
@
task_factory
.
register_task_cls
(
exp_cfg
.
Image
SegmentationTask
)
class
Image
SegmentationTask
(
base_task
.
Task
):
"""A task for
image
classification."""
@
task_factory
.
register_task_cls
(
exp_cfg
.
Semantic
SegmentationTask
)
class
Semantic
SegmentationTask
(
base_task
.
Task
):
"""A task for
semantic
classification."""
def
build_model
(
self
):
"""Builds classification model."""
...
...
@@ -219,8 +219,12 @@ class ImageSegmentationTask(base_task.Task):
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
if
self
.
task_config
.
validation_data
.
resize_eval_groundtruth
:
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
else
:
loss
=
0
logs
=
{
self
.
loss
:
loss
}
logs
.
update
({
self
.
miou_metric
.
name
:
(
labels
,
outputs
)})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment