"pytorch/vscode:/vscode.git/clone" did not exist on "2565f2fabb00706ef3c181cbb27a1f293d135131"
Commit 4323d37c authored by Soroosh Yazdani's avatar Soroosh Yazdani Committed by TF Object Detection Team
Browse files

Add the option of using "num_additional_channels" to ssd_mobilenet_v2_fpn_keras.

PiperOrigin-RevId: 461877203
parent 2575940c
......@@ -131,7 +131,8 @@ class SsdFeatureExtractorTestBase(test_case.TestCase):
use_explicit_padding=False,
num_layers=6,
use_keras=False,
use_depthwise=False):
use_depthwise=False,
num_channels=3):
with test_utils.GraphContextOrNone() as g:
feature_extractor = self._create_features(
depth_multiplier,
......@@ -148,7 +149,7 @@ class SsdFeatureExtractorTestBase(test_case.TestCase):
use_keras=use_keras)
image_tensor = np.random.rand(batch_size, image_height, image_width,
3).astype(np.float32)
num_channels).astype(np.float32)
feature_maps = self.execute(graph_fn, [image_tensor], graph=g)
for feature_map, expected_shape in zip(
feature_maps, expected_feature_map_shapes):
......
......@@ -135,6 +135,40 @@ class SsdMobilenetV2FpnFeatureExtractorTest(
use_keras=use_keras,
use_depthwise=use_depthwise)
def test_extract_features_returns_correct_shapes_4_channels(self,
use_depthwise):
use_keras = False
image_height = 320
image_width = 320
num_channels = 4
depth_multiplier = 1.0
pad_to_multiple = 1
expected_feature_map_shape = [(2, 40, 40, 256), (2, 20, 20, 256),
(2, 10, 10, 256), (2, 5, 5, 256),
(2, 3, 3, 256)]
self.check_extract_features_returns_correct_shape(
2,
image_height,
image_width,
depth_multiplier,
pad_to_multiple,
expected_feature_map_shape,
use_explicit_padding=False,
use_keras=use_keras,
use_depthwise=use_depthwise,
num_channels=num_channels)
self.check_extract_features_returns_correct_shape(
2,
image_height,
image_width,
depth_multiplier,
pad_to_multiple,
expected_feature_map_shape,
use_explicit_padding=True,
use_keras=use_keras,
use_depthwise=use_depthwise,
num_channels=num_channels)
def test_extract_features_with_dynamic_image_shape(self,
use_depthwise):
use_keras = False
......
......@@ -141,6 +141,40 @@ class SsdMobilenetV2FpnFeatureExtractorTest(
use_keras=use_keras,
use_depthwise=use_depthwise)
def test_extract_features_returns_correct_shapes_4_channels(self,
use_depthwise):
use_keras = True
image_height = 320
image_width = 320
num_channels = 4
depth_multiplier = 1.0
pad_to_multiple = 1
expected_feature_map_shape = [(2, 40, 40, 256), (2, 20, 20, 256),
(2, 10, 10, 256), (2, 5, 5, 256),
(2, 3, 3, 256)]
self.check_extract_features_returns_correct_shape(
2,
image_height,
image_width,
depth_multiplier,
pad_to_multiple,
expected_feature_map_shape,
use_explicit_padding=False,
use_keras=use_keras,
use_depthwise=use_depthwise,
num_channels=num_channels)
self.check_extract_features_returns_correct_shape(
2,
image_height,
image_width,
depth_multiplier,
pad_to_multiple,
expected_feature_map_shape,
use_explicit_padding=True,
use_keras=use_keras,
use_depthwise=use_depthwise,
num_channels=num_channels)
def test_extract_features_with_dynamic_image_shape(self,
use_depthwise):
use_keras = True
......
......@@ -136,7 +136,8 @@ class SSDMobileNetV2FpnKerasFeatureExtractor(
use_explicit_padding=self._use_explicit_padding,
alpha=self._depth_multiplier,
min_depth=self._min_depth,
include_top=False)
include_top=False,
input_shape=(None, None, input_shape[-1]))
layer_names = [layer.name for layer in full_mobilenet_v2.layers]
outputs = []
for layer_idx in [4, 7, 14]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment