"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "7b95087b9d26f371f78cd8af3144d9f5575e890c"
Commit 7c2ff1af authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386344633
parent 7bfd14d1
...@@ -592,8 +592,9 @@ class MobileNet(tf.keras.Model): ...@@ -592,8 +592,9 @@ class MobileNet(tf.keras.Model):
x, endpoints, next_endpoint_level = self._mobilenet_base(inputs=inputs) x, endpoints, next_endpoint_level = self._mobilenet_base(inputs=inputs)
endpoints[str(next_endpoint_level)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
# Don't include the final layer in `self._output_specs` to support decoders.
endpoints[str(next_endpoint_level)] = x
super(MobileNet, self).__init__( super(MobileNet, self).__init__(
inputs=inputs, outputs=endpoints, **kwargs) inputs=inputs, outputs=endpoints, **kwargs)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling.backbones import mobilenet
from official.vision.beta.modeling.backbones import resnet from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn from official.vision.beta.modeling.decoders import fpn
...@@ -52,6 +53,33 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -52,6 +53,33 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
[1, input_size // 2**level, input_size // 2**level, 256], [1, input_size // 2**level, input_size // 2**level, 256],
feats[str(level)].shape.as_list()) feats[str(level)].shape.as_list())
@parameterized.parameters(
(256, 3, 7, False),
(256, 3, 7, True),
)
def test_network_creation_with_mobilenet(self, input_size, min_level,
max_level, use_separable_conv):
"""Test creation of FPN with mobilenet backbone."""
tf.keras.backend.set_image_data_format('channels_last')
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
backbone = mobilenet.MobileNet(model_id='MobileNetV2')
network = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
use_separable_conv=use_separable_conv)
endpoints = backbone(inputs)
feats = network(endpoints)
for level in range(min_level, max_level + 1):
self.assertIn(str(level), feats)
self.assertAllEqual(
[1, input_size // 2**level, input_size // 2**level, 256],
feats[str(level)].shape.as_list())
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
kwargs = dict( kwargs = dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment