ModelZoo / ResNet50_tensorflow / Commits / 8bc64372

Commit 8bc64372, authored Nov 23, 2020 by Abdullah Rashwan, committed by A. Unique TensorFlower on Nov 23, 2020

Internal change

PiperOrigin-RevId: 343888044
Parent: 50efd367

Showing 16 changed files with 197 additions and 109 deletions.
official/vision/beta/configs/backbones.py (+4, -1)
official/vision/beta/configs/decoders.py (+1, -0)
official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml (+4, -0)
official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml (+0, -15)
official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet50_pascal_tpu.yaml (+0, -14)
official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet101_pascal_tpu.yaml (+0, -17)
official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet50_pascal_tpu.yaml (+0, -16)
official/vision/beta/configs/semantic_segmentation.py (+26, -12)
official/vision/beta/dataloaders/segmentation_input.py (+21, -2)
official/vision/beta/losses/segmentation_losses.py (+17, -3)
official/vision/beta/modeling/backbones/resnet_deeplab.py (+90, -15)
official/vision/beta/modeling/decoders/aspp.py (+12, -0)
official/vision/beta/modeling/decoders/aspp_test.py (+1, -0)
official/vision/beta/modeling/decoders/factory.py (+1, -0)
official/vision/beta/modeling/heads/segmentation_heads.py (+8, -3)
official/vision/beta/tasks/semantic_segmentation.py (+12, -11)
official/vision/beta/configs/backbones.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Backbones configurations."""
-from typing import Optional
+from typing import Optional, List

 # Import libraries
 import dataclasses

@@ -36,6 +36,9 @@ class DilatedResNet(hyperparams.Config):
   """DilatedResNet config."""
   model_id: int = 50
   output_stride: int = 16
+  multigrid: Optional[List[int]] = None
+  stem_type: str = 'v0'
+  last_stage_repeats: int = 1


 @dataclasses.dataclass
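The three new fields mirror the backbone changes later in this commit: multigrid scales per-block dilation in the last stage, stem_type selects the stem ('v0' keeps the single 7x7 conv, 'v1' is a DeepLab-style stem built from 3x3 convs), and last_stage_repeats stacks the final stage. A minimal standalone sketch of how the extended config might be filled in (a plain dataclass stands in for hyperparams.Config; the instantiation values are illustrative, taken from the experiment configs below):

# Sketch only: plain-dataclass stand-in for the hyperparams.Config above.
from typing import List, Optional
import dataclasses


@dataclasses.dataclass
class DilatedResNet:
  model_id: int = 50
  output_stride: int = 16
  multigrid: Optional[List[int]] = None  # per-block dilation multipliers
  stem_type: str = 'v0'                  # 'v0': 7x7 stem, 'v1': three 3x3 convs
  last_stage_repeats: int = 1            # how often the last stage is stacked


# The ResNet-101 DeepLab backbone used by the updated experiment configs:
backbone = DilatedResNet(
    model_id=101, output_stride=16, multigrid=[1, 2, 4], stem_type='v1')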
official/vision/beta/configs/decoders.py

@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
   dilation_rates: List[int] = dataclasses.field(default_factory=list)
   dropout_rate: float = 0.0
   num_filters: int = 256
+  pool_kernel_size: Optional[List[int]] = None  # Use global average pooling.


 @dataclasses.dataclass
official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml

@@ -1,3 +1,4 @@
+# Top1 accuracy 80.36%
 runtime:
   distribution_strategy: 'tpu'
   mixed_precision_dtype: 'bfloat16'
@@ -10,6 +11,9 @@ task:
     dilated_resnet:
       model_id: 101
       output_stride: 16
+      stem_type: 'v1'
+      multigrid: [1, 2, 4]
+      last_stage_repeats: 1
   norm_activation:
     activation: 'swish'
 losses:
official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml (deleted, mode 100644 → 0)

@@ -1,15 +0,0 @@
-# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
-runtime:
-  distribution_strategy: 'tpu'
-  mixed_precision_dtype: 'float32'
-task:
-  model:
-    backbone:
-      type: 'dilated_resnet'
-      dilated_resnet:
-        model_id: 101
-        output_stride: 8
-    norm_activation:
-      activation: 'swish'
-  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
-  init_checkpoint_modules: 'backbone'
official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet50_pascal_tpu.yaml (deleted, mode 100644 → 0)

@@ -1,14 +0,0 @@
-runtime:
-  distribution_strategy: 'tpu'
-  mixed_precision_dtype: 'float32'
-task:
-  model:
-    backbone:
-      type: 'dilated_resnet'
-      dilated_resnet:
-        model_id: 50
-        output_stride: 8
-    norm_activation:
-      activation: 'swish'
-  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
-  init_checkpoint_modules: 'backbone'
official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet101_pascal_tpu.yaml (deleted, mode 100644 → 0)

@@ -1,17 +0,0 @@
-# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16.
-runtime:
-  distribution_strategy: 'tpu'
-  mixed_precision_dtype: 'float32'
-task:
-  model:
-    backbone:
-      type: 'dilated_resnet'
-      dilated_resnet:
-        model_id: 101
-        output_stride: 16
-    head:
-      feature_fusion: 'deeplabv3plus'
-      low_level: 2
-      low_level_num_filters: 48
-  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
-  init_checkpoint_modules: 'backbone'
official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet50_pascal_tpu.yaml (deleted, mode 100644 → 0)

@@ -1,16 +0,0 @@
-runtime:
-  distribution_strategy: 'tpu'
-  mixed_precision_dtype: 'float32'
-task:
-  model:
-    backbone:
-      type: 'dilated_resnet'
-      dilated_resnet:
-        model_id: 50
-        output_stride: 16
-    head:
-      feature_fusion: 'deeplabv3plus'
-      low_level: 2
-      low_level_num_filters: 48
-  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
-  init_checkpoint_modules: 'backbone'
official/vision/beta/configs/semantic_segmentation.py

@@ -32,6 +32,8 @@ from official.vision.beta.configs import decoders
 @dataclasses.dataclass
 class DataConfig(cfg.DataConfig):
   """Input config for training."""
+  output_size: List[int] = dataclasses.field(default_factory=list)
+  train_on_crops: bool = False
   input_path: str = ''
   global_batch_size: int = 0
   is_training: bool = True
@@ -42,6 +44,7 @@ class DataConfig(cfg.DataConfig):
   groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
   aug_scale_min: float = 1.0
   aug_scale_max: float = 1.0
+  aug_rand_hflip: bool = True
   drop_remainder: bool = True
@@ -73,11 +76,12 @@ class SemanticSegmentationModel(hyperparams.Config):
 @dataclasses.dataclass
 class Losses(hyperparams.Config):
-  label_smoothing: float = 0.1
+  label_smoothing: float = 0.0
   ignore_label: int = 255
   class_weights: List[float] = dataclasses.field(default_factory=list)
   l2_weight_decay: float = 0.0
   use_groundtruth_dimension: bool = True
+  top_k_percent_pixels: float = 1.0


 @dataclasses.dataclass
@@ -115,18 +119,20 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
   train_batch_size = 16
   eval_batch_size = 8
   steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
-  output_stride = 8
+  output_stride = 16
   aspp_dilation_rates = [12, 24, 36]  # [6, 12, 18] if output_stride = 16
+  multigrid = [1, 2, 4]
+  stem_type = 'v1'
   level = int(np.math.log2(output_stride))
   config = cfg.ExperimentConfig(
       task=SemanticSegmentationTask(
           model=SemanticSegmentationModel(
               num_classes=21,
-              # TODO(arashwan): test changing size to 513 to match deeplab.
-              input_size=[512, 512, 3],
+              input_size=[None, None, 3],
               backbone=backbones.Backbone(
                   type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=50, output_stride=output_stride)),
+                      model_id=101, output_stride=output_stride,
+                      multigrid=multigrid, stem_type=stem_type)),
               decoder=decoders.Decoder(
                   type='aspp', aspp=decoders.ASPP(
                       level=level, dilation_rates=aspp_dilation_rates)),
@@ -139,19 +145,22 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
           losses=Losses(l2_weight_decay=1e-4),
           train_data=DataConfig(
               input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
+              # TODO(arashwan): test changing size to 513 to match deeplab.
+              output_size=[512, 512],
               is_training=True,
               global_batch_size=train_batch_size,
               aug_scale_min=0.5,
               aug_scale_max=2.0),
           validation_data=DataConfig(
               input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
+              output_size=[512, 512],
               is_training=False,
               global_batch_size=eval_batch_size,
               resize_eval_groundtruth=False,
               groundtruth_padded_size=[512, 512],
               drop_remainder=False),
-          # resnet50
-          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
+          # resnet101
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
           init_checkpoint_modules='backbone'),
       trainer=cfg.TrainerConfig(
           steps_per_loop=steps_per_epoch,
@@ -199,16 +208,19 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
   eval_batch_size = 8
   steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
   output_stride = 16
-  aspp_dilation_rates = [6, 12, 18]  # [12, 24, 36] if output_stride = 8
+  aspp_dilation_rates = [6, 12, 18]
+  multigrid = [1, 2, 4]
+  stem_type = 'v1'
   level = int(np.math.log2(output_stride))
   config = cfg.ExperimentConfig(
       task=SemanticSegmentationTask(
           model=SemanticSegmentationModel(
               num_classes=21,
-              input_size=[512, 512, 3],
+              input_size=[None, None, 3],
               backbone=backbones.Backbone(
                   type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=50, output_stride=output_stride)),
+                      model_id=101, output_stride=output_stride,
+                      stem_type=stem_type, multigrid=multigrid)),
              decoder=decoders.Decoder(
                   type='aspp',
                   aspp=decoders.ASPP(
@@ -227,19 +239,21 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
           losses=Losses(l2_weight_decay=1e-4),
           train_data=DataConfig(
               input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
+              output_size=[512, 512],
               is_training=True,
               global_batch_size=train_batch_size,
               aug_scale_min=0.5,
               aug_scale_max=2.0),
           validation_data=DataConfig(
               input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
+              output_size=[512, 512],
               is_training=False,
               global_batch_size=eval_batch_size,
               resize_eval_groundtruth=False,
               groundtruth_padded_size=[512, 512],
               drop_remainder=False),
-          # resnet50
-          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
+          # resnet101
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
           init_checkpoint_modules='backbone'),
       trainer=cfg.TrainerConfig(
           steps_per_loop=steps_per_epoch,
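Both experiment functions derive the ASPP level from the output stride with int(np.math.log2(output_stride)); np.math in these files is simply an alias for Python's math module. A worked check of that arithmetic:

# Worked check of the level arithmetic used above (np.math.log2 == math.log2).
import math

for output_stride in (8, 16):
  level = int(math.log2(output_stride))
  print(output_stride, '->', level)  # 8 -> level 3, 16 -> level 4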
official/vision/beta/dataloaders/segmentation_input.py

@@ -38,10 +38,12 @@ class Decoder(decoder.Decoder):
 class Parser(parser.Parser):
-  """Parser to parse an image and its annotations into a dictionary of tensors."""
+  """Parser to parse an image and its annotations into a dictionary of tensors.
+  """

   def __init__(self,
                output_size,
+               train_on_crops=False,
                resize_eval_groundtruth=True,
                groundtruth_padded_size=None,
                ignore_label=255,
@@ -54,6 +56,9 @@ class Parser(parser.Parser):
     Args:
       output_size: `Tensor` or `list` for [height, width] of output image. The
         output_size should be divided by the largest feature stride 2^max_level.
+      train_on_crops: `bool`, if True, a training crop of size output_size is
+        returned. This is useful for cropping original images during training
+        while evaluating on original image sizes.
       resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
         resized to output_size.
      groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
@@ -70,6 +75,7 @@ class Parser(parser.Parser):
       dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
     """
     self._output_size = output_size
+    self._train_on_crops = train_on_crops
     self._resize_eval_groundtruth = resize_eval_groundtruth
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
       raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
@@ -104,9 +110,22 @@ class Parser(parser.Parser):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    if self._train_on_crops:
+      if data['image/height'] < self._output_size[0] or data[
+          'image/width'] < self._output_size[1]:
+        raise ValueError(
+            'Image size has to be larger than crop size (output_size)')
+      label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
+      image_mask = tf.concat([image, label], axis=2)
+      image_mask_crop = tf.image.random_crop(image_mask,
+                                             self._output_size + [4])
+      image = image_mask_crop[:, :, :-1]
+      label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._output_size)
+
     # Flips image randomly during training.
     if self._aug_rand_hflip:
-      image, label = preprocess_ops.random_horizontal_flip(image, masks=label)
+      image, _, label = preprocess_ops.random_horizontal_flip(
+          image, masks=label)

     # Resizes and crops image.
     image, image_info = preprocess_ops.resize_and_crop_image(
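The train_on_crops branch uses a common trick: stack the mask onto the image as an extra channel so that a single tf.image.random_crop call crops both identically. A self-contained sketch of the same idea (shapes and values illustrative, not from the commit):

# Sketch: cropping image and mask together so they stay pixel-aligned.
import tensorflow as tf

crop_size = [256, 256]
image = tf.random.uniform([300, 400, 3])                         # float image
label = tf.cast(
    tf.random.uniform([300, 400, 1], maxval=21, dtype=tf.int32),
    tf.float32)                                                  # mask as float

image_mask = tf.concat([image, label], axis=2)                   # H x W x 4
image_mask_crop = tf.image.random_crop(image_mask, crop_size + [4])
image_crop = image_mask_crop[:, :, :-1]                          # 256 x 256 x 3
label_crop = image_mask_crop[:, :, -1]                           # 256 x 256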
official/vision/beta/losses/segmentation_losses.py

@@ -23,8 +23,9 @@ EPSILON = 1e-5
 class SegmentationLoss:
   """Semantic segmentation loss."""

-  def __init__(self, label_smoothing, class_weights, ignore_label,
-               use_groundtruth_dimension):
+  def __init__(self, label_smoothing, class_weights, ignore_label,
+               use_groundtruth_dimension, top_k_percent_pixels=1.0):
+    self._top_k_percent_pixels = top_k_percent_pixels
     self._class_weights = class_weights
     self._ignore_label = ignore_label
     self._use_groundtruth_dimension = use_groundtruth_dimension
@@ -71,5 +72,18 @@ class SegmentationLoss:
                                  tf.constant(class_weights, tf.float32))
     valid_mask *= weight_mask
     cross_entropy_loss *= tf.cast(valid_mask, tf.float32)
-    loss = tf.reduce_sum(cross_entropy_loss) / normalizer
+    if self._top_k_percent_pixels >= 1.0:
+      loss = tf.reduce_sum(cross_entropy_loss) / normalizer
+    else:
+      cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
+      top_k_pixels = tf.cast(
+          self._top_k_percent_pixels *
+          tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
+      top_k_losses, _ = tf.math.top_k(
+          cross_entropy_loss, k=top_k_pixels, sorted=True)
+      normalizer = tf.reduce_sum(
+          tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32) + EPSILON)
+      loss = tf.reduce_sum(top_k_losses) / normalizer
     return loss
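The new branch implements hard-pixel mining: when top_k_percent_pixels < 1, only the largest per-pixel losses are kept and averaged over the non-zero survivors. A toy run of the same ops (numbers made up):

# Toy illustration of the top-k branch above; numbers are made up.
import tensorflow as tf

EPSILON = 1e-5
per_pixel_loss = tf.constant([0.9, 0.1, 0.0, 0.4, 2.3, 0.0, 0.2, 1.1])
top_k_percent_pixels = 0.25

flat = tf.reshape(per_pixel_loss, [-1])
k = tf.cast(top_k_percent_pixels * tf.cast(tf.size(flat), tf.float32), tf.int32)
top_k_losses, _ = tf.math.top_k(flat, k=k, sorted=True)   # [2.3, 1.1]
normalizer = tf.reduce_sum(
    tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32) + EPSILON)
loss = tf.reduce_sum(top_k_losses) / normalizer           # ~(2.3 + 1.1) / 2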
official/vision/beta/modeling/backbones/resnet_deeplab.py

@@ -56,6 +56,9 @@ class DilatedResNet(tf.keras.Model):
                model_id,
                output_stride,
                input_specs=layers.InputSpec(shape=[None, None, None, 3]),
+               stem_type='v0',
+               multigrid=None,
+               last_stage_repeats=1,
                activation='relu',
                use_sync_bn=False,
                norm_momentum=0.99,
@@ -70,6 +73,11 @@ class DilatedResNet(tf.keras.Model):
       model_id: `int` depth of ResNet backbone model.
       output_stride: `int` output stride, ratio of input to output resolution.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
+      stem_type: `standard` or `deeplab`, deeplab replaces 7x7 conv by 3 3x3
+        convs.
+      multigrid: `Tuple` of the same length as the number of blocks in the last
+        resnet stage.
+      last_stage_repeats: `int`, how many times last stage is repeated.
       activation: `str` name of the activation function.
       use_sync_bn: if True, use synchronized batch normalization.
       norm_momentum: `float` normalization omentum for the moving average.
@@ -96,6 +104,7 @@ class DilatedResNet(tf.keras.Model):
     self._kernel_initializer = kernel_initializer
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
+    self._stem_type = stem_type
     if tf.keras.backend.image_data_format() == 'channels_last':
       bn_axis = -1
@@ -105,16 +114,67 @@ class DilatedResNet(tf.keras.Model):
     # Build ResNet.
     inputs = tf.keras.Input(shape=input_specs.shape[1:])

-    x = layers.Conv2D(
-        filters=64, kernel_size=7, strides=2, use_bias=False, padding='same',
-        kernel_initializer=self._kernel_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer)(inputs)
-    x = self._norm(
-        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
-    x = tf_utils.get_activation(activation)(x)
+    if stem_type == 'v0':
+      x = layers.Conv2D(
+          filters=64, kernel_size=7, strides=2, use_bias=False, padding='same',
+          kernel_initializer=self._kernel_initializer,
+          kernel_regularizer=self._kernel_regularizer,
+          bias_regularizer=self._bias_regularizer)(inputs)
+      x = self._norm(
+          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
+      x = tf_utils.get_activation(activation)(x)
+    elif stem_type == 'v1':
+      x = layers.Conv2D(
+          filters=64, kernel_size=3, strides=2, use_bias=False, padding='same',
+          kernel_initializer=self._kernel_initializer,
+          kernel_regularizer=self._kernel_regularizer,
+          bias_regularizer=self._bias_regularizer)(inputs)
+      x = self._norm(
+          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
+      x = tf_utils.get_activation(activation)(x)
+      x = layers.Conv2D(
+          filters=64, kernel_size=3, strides=1, use_bias=False, padding='same',
+          kernel_initializer=self._kernel_initializer,
+          kernel_regularizer=self._kernel_regularizer,
+          bias_regularizer=self._bias_regularizer)(x)
+      x = self._norm(
+          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
+      x = tf_utils.get_activation(activation)(x)
+      x = layers.Conv2D(
+          filters=128, kernel_size=3, strides=1, use_bias=False, padding='same',
+          kernel_initializer=self._kernel_initializer,
+          kernel_regularizer=self._kernel_regularizer,
+          bias_regularizer=self._bias_regularizer)(x)
+      x = self._norm(
+          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
+      x = tf_utils.get_activation(activation)(x)
+    else:
+      raise ValueError('Stem type {} not supported.'.format(stem_type))

     x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

     normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2
@@ -137,7 +197,7 @@ class DilatedResNet(tf.keras.Model):
       endpoints[str(i + 2)] = x

     dilation_rate = 2
-    for i in range(normal_resnet_stage + 1, 7):
+    for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
       spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
       if spec[0] == 'bottleneck':
         block_fn = nn_blocks.BottleneckBlock
@@ -150,6 +210,7 @@ class DilatedResNet(tf.keras.Model):
           dilation_rate=dilation_rate,
           block_fn=block_fn,
           block_repeats=spec[2],
+          multigrid=multigrid if i >= 3 else None,
           name='block_group_l{}'.format(i + 2))
       dilation_rate *= 2
@@ -167,9 +228,12 @@ class DilatedResNet(tf.keras.Model):
                    dilation_rate,
                    block_fn,
                    block_repeats=1,
+                   multigrid=None,
                    name='block_group'):
     """Creates one group of blocks for the ResNet model.

+    Deeplab applies strides at the last block.
+
     Args:
       inputs: `Tensor` of size `[batch, channels, height, width]`.
       filters: `int` number of filters for the first convolution of the layer.
@@ -178,15 +242,24 @@ class DilatedResNet(tf.keras.Model):
       dilation_rate: `int`, diluted convolution rates.
       block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
       block_repeats: `int` number of blocks contained in the layer.
+      multigrid: List of ints or None, if specified, dilation rates for each
+        block is scaled up by its corresponding factor in the multigrid.
       name: `str`name for the block.

     Returns:
       The output `Tensor` of the block layer.
     """
+    if multigrid is not None and len(multigrid) != block_repeats:
+      raise ValueError('multigrid has to match number of block_repeats')
+
+    if multigrid is None:
+      multigrid = [1] * block_repeats
+
+    # TODO(arashwan): move striding at the of the block.
     x = block_fn(
         filters=filters,
         strides=strides,
-        dilation_rate=dilation_rate,
+        dilation_rate=dilation_rate * multigrid[0],
         use_projection=True,
         kernel_initializer=self._kernel_initializer,
         kernel_regularizer=self._kernel_regularizer,
@@ -196,12 +269,11 @@ class DilatedResNet(tf.keras.Model):
         norm_momentum=self._norm_momentum,
         norm_epsilon=self._norm_epsilon)(inputs)

-    for _ in range(1, block_repeats):
+    for i in range(1, block_repeats):
       x = block_fn(
           filters=filters,
           strides=1,
-          dilation_rate=dilation_rate,
+          dilation_rate=dilation_rate * multigrid[i],
           use_projection=False,
           kernel_initializer=self._kernel_initializer,
           kernel_regularizer=self._kernel_regularizer,
@@ -254,6 +326,9 @@ def build_dilated_resnet(
       model_id=backbone_cfg.model_id,
       output_stride=backbone_cfg.output_stride,
       input_specs=input_specs,
+      multigrid=backbone_cfg.multigrid,
+      last_stage_repeats=backbone_cfg.last_stage_repeats,
+      stem_type=backbone_cfg.stem_type,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
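With multigrid, each block in the (possibly repeated) last stage sees the stage's base dilation scaled by its own factor, following the multi-grid scheme from DeepLabv3. The effective rates for the values used in this commit's configs:

# Effective dilation rates produced by the multigrid logic (pure Python).
multigrid = [1, 2, 4]   # one factor per block in the last stage
dilation_rate = 2       # base rate the stage would use without multigrid

effective = [dilation_rate * m for m in multigrid]
print(effective)        # [2, 4, 8]: later blocks get larger receptive fields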
official/vision/beta/modeling/decoders/aspp.py

@@ -28,6 +28,7 @@ class ASPP(tf.keras.layers.Layer):
                level,
                dilation_rates,
                num_filters=256,
+               pool_kernel_size=None,
                use_sync_bn=False,
                norm_momentum=0.99,
               norm_epsilon=0.001,
@@ -43,6 +44,9 @@ class ASPP(tf.keras.layers.Layer):
       level: `int` level to apply ASPP.
       dilation_rates: `list` of dilation rates.
       num_filters: `int` number of output filters in ASPP.
+      pool_kernel_size: `list` of [height, width] of pooling kernel size or
+        None. Pooling size is with respect to original image size, it will be
+        scaled down by 2**level. If None, global average pooling is used.
       use_sync_bn: if True, use synchronized batch normalization.
       norm_momentum: `float` normalization omentum for the moving average.
       norm_epsilon: `float` small float added to variance to avoid dividing by
@@ -60,6 +64,7 @@ class ASPP(tf.keras.layers.Layer):
         'level': level,
         'dilation_rates': dilation_rates,
         'num_filters': num_filters,
+        'pool_kernel_size': pool_kernel_size,
         'use_sync_bn': use_sync_bn,
         'norm_momentum': norm_momentum,
         'norm_epsilon': norm_epsilon,
@@ -71,9 +76,16 @@ class ASPP(tf.keras.layers.Layer):
     }

   def build(self, input_shape):
+    pool_kernel_size = None
+    if self._config_dict['pool_kernel_size']:
+      pool_kernel_size = [
+          int(p_size // 2**self._config_dict['level'])
+          for p_size in self._config_dict['pool_kernel_size']
+      ]
+
     self.aspp = keras_cv.layers.SpatialPyramidPooling(
         output_channels=self._config_dict['num_filters'],
         dilation_rates=self._config_dict['dilation_rates'],
+        pool_kernel_size=pool_kernel_size,
         use_sync_bn=self._config_dict['use_sync_bn'],
         batchnorm_momentum=self._config_dict['norm_momentum'],
         batchnorm_epsilon=self._config_dict['norm_epsilon'],
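pool_kernel_size is given in original-image pixels and rescaled to feature-map coordinates in build(), since the ASPP input at `level` has stride 2**level. A worked check of that rescaling:

# Worked check of the pool_kernel_size rescaling in build() above.
level = 4                      # ASPP applied at stride 2**4 = 16
pool_kernel_size = [512, 512]  # pooling size in original-image pixels

scaled = [int(p // 2**level) for p in pool_kernel_size]
print(scaled)                  # [32, 32] on the stride-16 feature map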
official/vision/beta/modeling/decoders/aspp_test.py

@@ -61,6 +61,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
         level=3,
         dilation_rates=[6, 12],
         num_filters=256,
+        pool_kernel_size=None,
         use_sync_bn=False,
         norm_momentum=0.99,
         norm_epsilon=0.001,
official/vision/beta/modeling/decoders/factory.py

@@ -70,6 +70,7 @@ def build_decoder(input_specs,
         level=decoder_cfg.level,
         dilation_rates=decoder_cfg.dilation_rates,
         num_filters=decoder_cfg.num_filters,
+        pool_kernel_size=decoder_cfg.pool_kernel_size,
         dropout_rate=decoder_cfg.dropout_rate,
         use_sync_bn=norm_activation_config.use_sync_bn,
         norm_momentum=norm_activation_config.norm_momentum,
official/vision/beta/modeling/heads/segmentation_heads.py

@@ -102,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
     conv_kwargs = {
         'kernel_size': 3,
         'padding': 'same',
-        'bias_initializer': tf.zeros_initializer(),
+        'use_bias': False,
         'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
         'kernel_regularizer': self._config_dict['kernel_regularizer'],
     }
@@ -120,7 +120,7 @@ class SegmentationHead(tf.keras.layers.Layer):
     self._dlv3p_conv = conv_op(
         kernel_size=1,
         padding='same',
-        bias_initializer=tf.zeros_initializer(),
+        use_bias=False,
         kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
         kernel_regularizer=self._config_dict['kernel_regularizer'],
         name='segmentation_head_deeplabv3p_fusion_conv',
@@ -145,7 +145,12 @@ class SegmentationHead(tf.keras.layers.Layer):
     self._classifier = conv_op(
         name='segmentation_output',
         filters=self._config_dict['num_classes'],
-        **conv_kwargs)
+        kernel_size=1,
+        padding='same',
+        bias_initializer=tf.zeros_initializer(),
+        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+        kernel_regularizer=self._config_dict['kernel_regularizer'],
+        bias_regularizer=self._config_dict['bias_regularizer'])

     super(SegmentationHead, self).build(input_shape)
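The intermediate head convs switch from a zero-initialized bias to use_bias=False, presumably because they are immediately followed by a normalization layer whose beta offset makes a conv bias redundant; the final classifier, which has no following norm, keeps its bias. A generic sketch of that pattern (the motivation is an assumption; tf.keras API only):

# Sketch: a conv followed by batch norm does not need its own bias term.
import tensorflow as tf

x = tf.keras.Input(shape=[64, 64, 256])
y = tf.keras.layers.Conv2D(256, 3, padding='same', use_bias=False)(x)
y = tf.keras.layers.BatchNormalization()(y)  # beta plays the role of the bias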
official/vision/beta/tasks/semantic_segmentation.py

@@ -81,17 +81,18 @@ class SemanticSegmentationTask(base_task.Task):
   def build_inputs(self, params, input_context=None):
     """Builds classification input."""

-    input_size = self.task_config.model.input_size
     ignore_label = self.task_config.losses.ignore_label

     decoder = segmentation_input.Decoder()
     parser = segmentation_input.Parser(
-        output_size=input_size[:2],
+        output_size=params.output_size,
+        train_on_crops=params.train_on_crops,
         ignore_label=ignore_label,
         resize_eval_groundtruth=params.resize_eval_groundtruth,
         groundtruth_padded_size=params.groundtruth_padded_size,
         aug_scale_min=params.aug_scale_min,
         aug_scale_max=params.aug_scale_max,
+        aug_rand_hflip=params.aug_rand_hflip,
         dtype=params.dtype)

     reader = input_reader.InputReader(
@@ -120,7 +121,8 @@ class SemanticSegmentationTask(base_task.Task):
         loss_params.label_smoothing,
         loss_params.class_weights,
         loss_params.ignore_label,
-        use_groundtruth_dimension=loss_params.use_groundtruth_dimension)
+        use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
+        top_k_percent_pixels=loss_params.top_k_percent_pixels)

     total_loss = segmentation_loss_fn(model_outputs, labels['masks'])
@@ -133,19 +135,18 @@ class SemanticSegmentationTask(base_task.Task):
     """Gets streaming metrics for training/validation."""
     metrics = []
     if training:
-      # TODO(arashwan): make MeanIoU tpu friendly.
-      metrics.append(segmentation_metrics.MeanIoU(
-          name='mean_iou',
-          num_classes=self.task_config.model.num_classes,
-          rescale_predictions=False,
-          dtype=tf.float32))
+      if not isinstance(tf.distribute.get_strategy(),
+                        tf.distribute.TPUStrategy):
+        metrics.append(segmentation_metrics.MeanIoU(
+            name='mean_iou',
+            num_classes=self.task_config.model.num_classes,
+            rescale_predictions=False))
     else:
       self.miou_metric = segmentation_metrics.MeanIoU(
           name='val_mean_iou',
           num_classes=self.task_config.model.num_classes,
           rescale_predictions=not self.task_config.validation_data
-          .resize_eval_groundtruth)
+          .resize_eval_groundtruth,
+          dtype=tf.float32)
     return metrics
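build_metrics now attaches the training MeanIoU only off-TPU, replacing the old TODO; the metric's confusion-matrix updates apparently do not lower well under TPUStrategy. The guard pattern, standalone (using the stock tf.keras metric for illustration):

# Sketch of the strategy guard above, with the stock tf.keras MeanIoU.
import tensorflow as tf

def build_train_metrics(num_classes):
  metrics = []
  if not isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
    metrics.append(tf.keras.metrics.MeanIoU(num_classes=num_classes,
                                            name='mean_iou'))
  return metrics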