"docs/en/vscode:/vscode.git/clone" did not exist on "2d23d70e7b5482afcdb5c3b4f0f7f249f608f242"
Commit 0f7580bd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 343970720
parent 730548d4
......@@ -54,7 +54,7 @@ class SegmentationHead(hyperparams.Config):
num_convs: int = 2
num_filters: int = 256
upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, or deeplabv3plus
feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion
# deeplabv3plus feature fusion params
low_level: int = 2
low_level_num_filters: int = 48
......@@ -292,3 +292,77 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
])
return config
@exp_factory.register_config_factory('seg_resnetfpn_pascal')
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on PASCAL VOC with resnet-fpn.

  Returns:
    A `cfg.ExperimentConfig` for training a ResNet-50 backbone with an FPN
    decoder on the PASCAL VOC `train_aug*` split and evaluating on `val*`.
  """
  train_batch_size = 256
  eval_batch_size = 32
  # One "epoch" of training steps at the chosen global batch size.
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # PASCAL VOC: 20 object classes + background.
              num_classes=21,
              input_size=[512, 512, 3],
              min_level=3,
              max_level=7,
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
              head=SegmentationHead(level=3, num_convs=3),
              norm_activation=common.NormActivation(
                  activation='swish',
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              is_training=True,
              global_batch_size=train_batch_size,
              # Random scale-jittering range used for train-time augmentation.
              aug_scale_min=0.2,
              aug_scale_max=1.5),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              # Evaluate against padded, unresized groundtruth masks.
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          # Train for 450 epochs; LR decays polynomially to zero over the
          # same horizon (see `decay_steps` below).
          train_steps=450 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 450 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  # Linear LR ramp from 0 over the first 5 epochs.
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          # `is_training` must be explicitly set on both data configs.
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
......@@ -17,6 +17,7 @@
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.ops import spatial_transform_ops
......@@ -52,9 +53,10 @@ class SegmentationHead(tf.keras.layers.Layer):
Default is 256.
upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, or None. If `deeplabv3plus`,
features from decoder_features[level] will be fused with
low level feature maps from backbone.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If
`deeplabv3plus`, features from decoder_features[level] will be fused
with low level feature maps from backbone. If `pyramid_fusion`,
multiscale features will be resized and fused at the target level.
low_level: `int`, backbone level to be used for feature fusion. This arg
is used when feature_fusion is set to deeplabv3plus.
low_level_num_filters: `int`, reduced number of filters for the low
......@@ -170,11 +172,9 @@ class SegmentationHead(tf.keras.layers.Layer):
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
"""
x = decoder_output[str(self._config_dict['level'])]
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
# deeplabv3+ feature fusion
x = decoder_output[str(self._config_dict['level'])]
y = backbone_output[str(
self._config_dict['low_level'])]
y = self._dlv3p_norm(self._dlv3p_conv(y))
......@@ -183,6 +183,11 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.concat([x, y], axis=self._bn_axis)
elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
x = nn_layers.pyramid_feature_fusion(decoder_output,
self._config_dict['level'])
else:
x = decoder_output[str(self._config_dict['level'])]
for conv, norm in zip(self._convs, self._norms):
x = conv(x)
......@@ -198,4 +203,3 @@ class SegmentationHead(tf.keras.layers.Layer):
@classmethod
def from_config(cls, config):
  """Creates a layer instance from its config dictionary.

  Args:
    config: a dict of constructor arguments, as produced by `get_config`.

  Returns:
    A new instance built by forwarding `config` as keyword arguments.
  """
  return cls(**config)
......@@ -26,10 +26,12 @@ from official.vision.beta.modeling.heads import segmentation_heads
class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(3), (4),
(2, 'pyramid_fusion'),
(3, 'pyramid_fusion'),
)
def test_forward(self, level):
head = segmentation_heads.SegmentationHead(num_classes=10, level=level)
def test_forward(self, level, feature_fusion):
head = segmentation_heads.SegmentationHead(
num_classes=10, level=level, feature_fusion=feature_fusion)
backbone_features = {
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
......@@ -39,10 +41,12 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
'4': np.random.rand(2, 64, 64, 16),
}
logits = head(backbone_features, decoder_features)
self.assertAllEqual(
logits.numpy().shape,
[2, decoder_features[str(level)].shape[1],
decoder_features[str(level)].shape[2], 10])
if level in decoder_features:
self.assertAllEqual(logits.numpy().shape, [
2, decoder_features[str(level)].shape[1],
decoder_features[str(level)].shape[2], 10
])
def test_serialize_deserialize(self):
head = segmentation_heads.SegmentationHead(num_classes=10, level=3)
......
......@@ -223,3 +223,41 @@ class StochasticDepth(tf.keras.layers.Layer):
binary_tensor = tf.floor(random_tensor)
output = tf.math.divide(inputs, keep_prob) * binary_tensor
return output
@tf.keras.utils.register_keras_serializable(package='Vision')
def pyramid_feature_fusion(inputs, target_level):
  """Fuses all feature maps in the feature pyramid at the target level.

  Every level in the pyramid is bilinearly resized to the spatial size of
  `target_level` and the resized maps are summed element-wise.

  Args:
    inputs: a dictionary containing the feature pyramid, keyed by level
      (string or int). The spatial size of each input tensor needs to be
      fixed (statically known).
    target_level: `int` the target feature level for feature fusion.

  Returns:
    A float Tensor of shape [batch_size, feature_height, feature_width,
      feature_channel] at the `target_level` resolution.
  """
  # Convert keys to int so levels can be iterated in numeric order.
  pyramid_feats = {int(k): v for k, v in inputs.items()}
  min_level = min(pyramid_feats.keys())
  max_level = max(pyramid_feats.keys())

  resampled_feats = []
  for level in range(min_level, max_level + 1):
    if level == target_level:
      resampled_feats.append(pyramid_feats[level])
    else:
      feat = pyramid_feats[level]
      # Compute the target size with an explicit int() cast: for levels
      # finer than the target (level < target_level) the scale factor
      # 2**(level - target_level) is a Python float, and a float-valued
      # size is rejected by tf.image.resize (it requires int32 sizes).
      target_size = [
          int(dim * 2**(level - target_level)) for dim in feat.shape[1:3]
      ]
      # Casts feat to float32 so the resize op can be run on TPU.
      feat = tf.cast(feat, tf.float32)
      feat = tf.image.resize(
          feat, size=target_size, method=tf.image.ResizeMethod.BILINEAR)
      # Casts it back to be compatible with the rest of the operations.
      feat = tf.cast(feat, pyramid_feats[level].dtype)
      resampled_feats.append(feat)

  return tf.math.add_n(resampled_feats)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment