Commit 2de518be authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 387004613
parent 4afec6ba
...@@ -130,7 +130,7 @@ class SpineNet(tf.keras.Model): ...@@ -130,7 +130,7 @@ class SpineNet(tf.keras.Model):
def __init__( def __init__(
self, self,
input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec( input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
shape=[None, 640, 640, 3]), shape=[None, None, None, 3]),
min_level: int = 3, min_level: int = 3,
max_level: int = 7, max_level: int = 7,
block_specs: List[BlockSpec] = build_block_specs(), block_specs: List[BlockSpec] = build_block_specs(),
...@@ -214,8 +214,11 @@ class SpineNet(tf.keras.Model): ...@@ -214,8 +214,11 @@ class SpineNet(tf.keras.Model):
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
net = self._build_stem(inputs=inputs) net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network( input_width = input_specs.shape[2]
net=net, input_width=input_specs.shape[2]) if input_width is None:
max_stride = max(map(lambda b: b.level, block_specs))
input_width = 2 ** max_stride
net = self._build_scale_permuted_network(net=net, input_width=input_width)
endpoints = self._build_endpoints(net=net) endpoints = self._build_endpoints(net=net)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
...@@ -135,7 +135,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -135,7 +135,7 @@ class SpineNetMobile(tf.keras.Model):
def __init__( def __init__(
self, self,
input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec( input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
shape=[None, 512, 512, 3]), shape=[None, None, None, 3]),
min_level: int = 3, min_level: int = 3,
max_level: int = 7, max_level: int = 7,
block_specs: List[BlockSpec] = build_block_specs(), block_specs: List[BlockSpec] = build_block_specs(),
...@@ -219,8 +219,11 @@ class SpineNetMobile(tf.keras.Model): ...@@ -219,8 +219,11 @@ class SpineNetMobile(tf.keras.Model):
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
net = self._build_stem(inputs=inputs) net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network( input_width = input_specs.shape[2]
net=net, input_width=input_specs.shape[2]) if input_width is None:
max_stride = max(map(lambda b: b.level, block_specs))
input_width = 2 ** max_stride
net = self._build_scale_permuted_network(net=net, input_width=input_width)
endpoints = self._build_endpoints(net=net) endpoints = self._build_endpoints(net=net)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment