Commit ea3542c1 authored by Adrian Boguszewski's avatar Adrian Boguszewski Committed by aquariusjay
Browse files

Fixed improper HNASNet architecture (#6419)

parent 47d6c66e
......@@ -156,8 +156,8 @@ def _build_nas_base(images,
stride = 2
filter_scaling *= hparams.filter_scaling_rate
elif backbone[cell_num] == backbone[cell_num - 1] - 1:
scaled_height = scale_dimension(tf.shape(net)[1], 2)
scaled_width = scale_dimension(tf.shape(net)[2], 2)
scaled_height = scale_dimension(net.shape[1].value, 2)
scaled_width = scale_dimension(net.shape[2].value, 2)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
filter_scaling /= hparams.filter_scaling_rate
net = cell(
......
......@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
......@@ -47,8 +46,9 @@ def create_test_input(batch, height, width, channels):
class NASNetworkTest(tf.test.TestCase):
"""Tests with complete small NAS networks."""
def _pnasnet_small(self,
def _pnasnet(self,
images,
backbone,
num_classes,
is_training=True,
output_stride=16,
......@@ -63,7 +63,6 @@ class NASNetworkTest(tf.test.TestCase):
if not is_training:
hparams.set_hparam('drop_path_keep_prob', 1.0)
backbone = [1, 2, 2]
cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob,
len(backbone),
......@@ -81,16 +80,26 @@ class NASNetworkTest(tf.test.TestCase):
def testFullyConvolutionalEndpointShapes(self):
num_classes = 10
inputs = create_test_input(2, 321, 321, 3)
backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
inputs = create_test_input(None, 321, 321, 3)
with slim.arg_scope(nas_network.nas_arg_scope()):
_, end_points = self._pnasnet_small(inputs,
num_classes)
_, end_points = self._pnasnet(inputs, backbone, num_classes)
endpoint_to_shape = {
'Stem': [2, 81, 81, 128],
'Cell_0': [2, 41, 41, 100],
'Cell_1': [2, 21, 21, 200],
'Cell_2': [2, 21, 21, 200]}
for endpoint, shape in endpoint_to_shape.iteritems():
'Stem': [None, 81, 81, 128],
'Cell_0': [None, 81, 81, 50],
'Cell_1': [None, 81, 81, 50],
'Cell_2': [None, 81, 81, 50],
'Cell_3': [None, 41, 41, 100],
'Cell_4': [None, 21, 21, 200],
'Cell_5': [None, 41, 41, 100],
'Cell_6': [None, 21, 21, 200],
'Cell_7': [None, 21, 21, 200],
'Cell_8': [None, 11, 11, 400],
'Cell_9': [None, 11, 11, 400],
'Cell_10': [None, 21, 21, 200],
'Cell_11': [None, 41, 41, 100]
}
for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment