Commit ea3542c1 authored by Adrian Boguszewski's avatar Adrian Boguszewski Committed by aquariusjay
Browse files

Fixed improper HNASNet architecture (#6419)

parent 47d6c66e
...@@ -156,8 +156,8 @@ def _build_nas_base(images, ...@@ -156,8 +156,8 @@ def _build_nas_base(images,
stride = 2 stride = 2
filter_scaling *= hparams.filter_scaling_rate filter_scaling *= hparams.filter_scaling_rate
elif backbone[cell_num] == backbone[cell_num - 1] - 1: elif backbone[cell_num] == backbone[cell_num - 1] - 1:
scaled_height = scale_dimension(tf.shape(net)[1], 2) scaled_height = scale_dimension(net.shape[1].value, 2)
scaled_width = scale_dimension(tf.shape(net)[2], 2) scaled_width = scale_dimension(net.shape[2].value, 2)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype) net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
filter_scaling /= hparams.filter_scaling_rate filter_scaling /= hparams.filter_scaling_rate
net = cell( net = cell(
......
...@@ -19,7 +19,6 @@ from __future__ import absolute_import ...@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -47,12 +46,13 @@ def create_test_input(batch, height, width, channels): ...@@ -47,12 +46,13 @@ def create_test_input(batch, height, width, channels):
class NASNetworkTest(tf.test.TestCase): class NASNetworkTest(tf.test.TestCase):
"""Tests with complete small NAS networks.""" """Tests with complete small NAS networks."""
def _pnasnet_small(self, def _pnasnet(self,
images, images,
num_classes, backbone,
is_training=True, num_classes,
output_stride=16, is_training=True,
final_endpoint=None): output_stride=16,
final_endpoint=None):
"""Build PNASNet model backbone.""" """Build PNASNet model backbone."""
hparams = tf.contrib.training.HParams( hparams = tf.contrib.training.HParams(
filter_scaling_rate=2.0, filter_scaling_rate=2.0,
...@@ -63,7 +63,6 @@ class NASNetworkTest(tf.test.TestCase): ...@@ -63,7 +63,6 @@ class NASNetworkTest(tf.test.TestCase):
if not is_training: if not is_training:
hparams.set_hparam('drop_path_keep_prob', 1.0) hparams.set_hparam('drop_path_keep_prob', 1.0)
backbone = [1, 2, 2]
cell = nas_genotypes.PNASCell(hparams.num_conv_filters, cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob, hparams.drop_path_keep_prob,
len(backbone), len(backbone),
...@@ -81,16 +80,26 @@ class NASNetworkTest(tf.test.TestCase): ...@@ -81,16 +80,26 @@ class NASNetworkTest(tf.test.TestCase):
def testFullyConvolutionalEndpointShapes(self): def testFullyConvolutionalEndpointShapes(self):
num_classes = 10 num_classes = 10
inputs = create_test_input(2, 321, 321, 3) backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
inputs = create_test_input(None, 321, 321, 3)
with slim.arg_scope(nas_network.nas_arg_scope()): with slim.arg_scope(nas_network.nas_arg_scope()):
_, end_points = self._pnasnet_small(inputs, _, end_points = self._pnasnet(inputs, backbone, num_classes)
num_classes)
endpoint_to_shape = { endpoint_to_shape = {
'Stem': [2, 81, 81, 128], 'Stem': [None, 81, 81, 128],
'Cell_0': [2, 41, 41, 100], 'Cell_0': [None, 81, 81, 50],
'Cell_1': [2, 21, 21, 200], 'Cell_1': [None, 81, 81, 50],
'Cell_2': [2, 21, 21, 200]} 'Cell_2': [None, 81, 81, 50],
for endpoint, shape in endpoint_to_shape.iteritems(): 'Cell_3': [None, 41, 41, 100],
'Cell_4': [None, 21, 21, 200],
'Cell_5': [None, 41, 41, 100],
'Cell_6': [None, 21, 21, 200],
'Cell_7': [None, 21, 21, 200],
'Cell_8': [None, 11, 11, 400],
'Cell_9': [None, 11, 11, 400],
'Cell_10': [None, 21, 21, 200],
'Cell_11': [None, 41, 41, 100]
}
for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape) self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment