"torchvision/vscode:/vscode.git/clone" did not exist on "8096c1b5c9775aed3710fbd81d635e06924cc094"
Commit 001a2a61 authored by pkulzc's avatar pkulzc Committed by Sergio Guadarrama
Browse files

Internal changes for object detection. (#3656)

* Force cast of num_classes to integer

PiperOrigin-RevId: 188335318

* Updating config util to allow overriding of cosine decay learning rates.

PiperOrigin-RevId: 188338852

* Make box_list_ops.py and box_list_ops_test.py work with C API enabled.

The C API has improved shape inference over the original Python
code. This causes some previously-working conds to fail. Switching to smart_cond fixes this.

Another effect of the improved shape inference is that one of the
failures tested gets caught earlier, so I modified the test to reflect
this.

PiperOrigin-RevId: 188409792
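
A hedged sketch of the `smart_cond` pattern referenced above (not the actual box_list_ops code): `tf.cond` builds both branches even when the predicate is statically known, which the C API's stricter shape inference can reject, while `smart_cond` short-circuits to a single branch for constant predicates.

```python
import tensorflow as tf

boxes = tf.placeholder(tf.float32, [None, 4])

# With a statically known Python bool, smart_cond only builds the taken
# branch, so the untaken branch can never fail shape inference.
clip_to_window = True
result = tf.contrib.framework.smart_cond(
    clip_to_window,
    true_fn=lambda: tf.clip_by_value(boxes, 0.0, 1.0),
    false_fn=lambda: boxes)
```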

* Fix parallel event file writing issue.

Without this change, the event files might get corrupted when multiple evaluations are run in parallel.

PiperOrigin-RevId: 188502560

* Deprecating the boolean flag of from_detection_checkpoint.

Replace it with a string field, fine_tune_checkpoint_type, in train_config to provide extensibility. fine_tune_checkpoint_type can currently take the value `detection` or `classification`, or other values when restore_map is overridden.

PiperOrigin-RevId: 188518685
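
For illustration, a train_config using the new field might look like the snippet below (the path and values are placeholders; `fine_tune_checkpoint` is the existing companion field):

```
train_config: {
  fine_tune_checkpoint: "/path/to/model.ckpt"
  # Replaces the deprecated boolean `from_detection_checkpoint: true`:
  fine_tune_checkpoint_type: "detection"
}
```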

* Automated g4 rollback of changelist 188502560

PiperOrigin-RevId: 188519969

* Introducing eval metrics specs for COCO mask metrics. This allows the metrics to be computed in TensorFlow using the tf.learn Estimator.

PiperOrigin-RevId: 188528485

* Minor fix to make object_detection/metrics/coco_evaluation.py Python 3 compatible.

PiperOrigin-RevId: 188550683

* Updating eval_util to handle eval_metric_ops from multiple `DetectionEvaluator`s.

PiperOrigin-RevId: 188560474

* Allow tensor input for new_height and new_width for resize_image.

PiperOrigin-RevId: 188561908
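
A minimal sketch of what this allows, assuming the preprocessor's `resize_image` keyword arguments (the exact return structure follows the library):

```python
import tensorflow as tf
from object_detection.core import preprocessor

image = tf.placeholder(tf.float32, [None, None, 3])
# new_height/new_width may now be scalar int32 tensors, not just Python ints.
new_height = tf.constant(300, dtype=tf.int32)
new_width = tf.random_uniform([], minval=200, maxval=400, dtype=tf.int32)
resized = preprocessor.resize_image(image, new_height=new_height,
                                    new_width=new_width)
```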

* Fix typo in fine_tune_checkpoint_type name in trainer.

PiperOrigin-RevId: 188799033

* Adding mobilenet feature extractor to object detection.

PiperOrigin-RevId: 188916897

* Allow label maps to optionally contain an explicit background class with id zero.

PiperOrigin-RevId: 188951089
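
An illustrative label map with the optional explicit background entry (the class names here are made up; id 0 is reserved for background):

```
item {
  id: 0
  name: "background"
}
item {
  id: 1
  name: "dog"
}
item {
  id: 2
  name: "cat"
}
```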

* Fix boundary conditions in random_pad_to_aspect_ratio to ensure that min_scale is always less than max_scale.

PiperOrigin-RevId: 189026868

* Fall back on the from_detection_checkpoint option if fine_tune_checkpoint_type isn't set.

PiperOrigin-RevId: 189052833

* Add proper names for learning rate schedules so we don't see cryptic names on TensorBoard.

PiperOrigin-RevId: 189069837

* Enforcing that all datasets are batched (and then unbatched in the model) with batch_size >= 1.

PiperOrigin-RevId: 189117178

* Adding regularization to total loss returned from DetectionModel.loss().

PiperOrigin-RevId: 189189123
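
Conceptually (a sketch, not the DetectionModel code itself), the returned loss now folds in the graph's regularization collection:

```python
import tensorflow as tf

weights = tf.get_variable(
    'w', shape=[10], regularizer=tf.contrib.layers.l2_regularizer(1e-4))
task_loss = tf.reduce_sum(tf.square(weights))  # stand-in for loc + cls losses
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_loss = task_loss + tf.add_n(reg_losses) if reg_losses else task_loss
```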

* Standardize the names of loss scalars (for SSD, Faster R-CNN, and R-FCN) in both training and eval so they can be compared on TensorBoard.

Log localization and classification losses in evaluation.

PiperOrigin-RevId: 189189940

* Remove negative test from box list ops test.

PiperOrigin-RevId: 189229327

* Add an option to warm up the learning rate in the manual stepping schedule.

PiperOrigin-RevId: 189361039
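
As a hedged illustration (field layout follows optimizer.proto; the step count and rates are placeholders), the manual stepping schedule can now request warmup:

```
learning_rate: {
  manual_step_learning_rate {
    initial_learning_rate: 0.0002
    schedule {
      step: 9000
      learning_rate: 0.00002
    }
    # New option: linearly interpolate the rate for steps before the
    # first boundary.
    warmup: true
  }
}
```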

* Replace tf.contrib.slim.tfexample_decoder.LookupTensor with object_detection.data_decoders.tf_example_decoder.LookupTensor.

PiperOrigin-RevId: 189388556

* Force regularization summary variables under specific family names.

PiperOrigin-RevId: 189393190

* Automated g4 rollback of changelist 188619139

PiperOrigin-RevId: 189396001

* Remove step 0 schedule since we do a hard check for it after cl/189361039

PiperOrigin-RevId: 189396697

* PiperOrigin-RevId: 189040463

* PiperOrigin-RevId: 189059229

* PiperOrigin-RevId: 189214402

* Make slim Python 3 compatible.

* Minor fixes.

* Add TargetAssignment summaries in a separate family.

PiperOrigin-RevId: 189407487

* 1. Setting the `family` keyword arg prefixes the summary name twice with the same family. Adding the family prefix directly to the name gets rid of this problem.
2. Make sure the eval losses have the same name.

PiperOrigin-RevId: 189434618
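
A sketch of the behavior being worked around (TF 1.x; the exact tags depend on the enclosing name scope):

```python
import tensorflow as tf

loss = tf.constant(1.0)
with tf.name_scope('eval'):
  # `family` is applied to both the op name and the display tag, so the
  # prefix can show up twice, e.g. 'Losses/eval/Losses/loss_name'.
  tf.summary.scalar('loss_name', loss, family='Losses')
  # Baking the prefix into the name yields a single 'Losses/...' grouping.
  tf.summary.scalar('Losses/loss_name', loss)
```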

* Minor fixes to make object detection TF 1.4 compatible.

PiperOrigin-RevId: 189437519

* Call the base of mobilenet_v1 feature extractor under the right arg scope and set batchnorm is_training based on the value passed in the constructor.

PiperOrigin-RevId: 189460890

* Automated g4 rollback of changelist 188409792

PiperOrigin-RevId: 189463882

* Update object detection syncing.

PiperOrigin-RevId: 189601955

* Add an option to warm up the learning rate, hold it constant for a certain number of steps, and then cosine-decay it.

PiperOrigin-RevId: 189606169
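
A hedged sketch of that schedule shape (plain Python, not the library implementation): linear warmup, constant hold, then cosine decay.

```python
import math

def warmup_hold_cosine(step, base_lr, warmup_steps, hold_steps, total_steps,
                       warmup_lr=0.0):
  """Returns the learning rate at `step` for the schedule described above."""
  if step < warmup_steps:  # linear warmup from warmup_lr to base_lr
    return warmup_lr + (base_lr - warmup_lr) * float(step) / warmup_steps
  if step < warmup_steps + hold_steps:  # hold constant
    return base_lr
  decay_steps = total_steps - warmup_steps - hold_steps
  progress = float(step - warmup_steps - hold_steps) / decay_steps
  return 0.5 * base_lr * (1 + math.cos(math.pi * progress))  # cosine decay
```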

* Let the proposal feature extractor function in faster_rcnn meta architectures return the activations (end_points).

PiperOrigin-RevId: 189619301

* Fixed a bug which caused masks to be mostly zeros (caused by detection_boxes being in absolute coordinates if scale_to_absolute=True).

PiperOrigin-RevId: 189641294

* Open sourcing MobileNetV2 + SSDLite.

PiperOrigin-RevId: 189654520

* Remove unused files.
@@ -434,7 +434,6 @@ py_library(
     srcs = glob(["nets/mobilenet/*.py"]),
     srcs_version = "PY2AND3",
     deps = [
-        "//third_party/py/contextlib2",
         # "//tensorflow",
     ],
 )
...
@@ -93,6 +93,7 @@ import sys
 import threading
 import numpy as np
+from six.moves import xrange
 import tensorflow as tf
 tf.app.flags.DEFINE_string('train_directory', '/tmp/',
...
@@ -52,6 +52,8 @@ import os
 import os.path
 import sys
+from six.moves import xrange
 if __name__ == '__main__':
   if len(sys.argv) < 3:
...
@@ -86,6 +86,8 @@ import os.path
 import sys
 import xml.etree.ElementTree as ET
+from six.moves import xrange
 class BoundingBox(object):
   pass
...
@@ -230,9 +230,10 @@ def _gather_clone_loss(clone, num_clones, regularization_losses):
   sum_loss = tf.add_n(all_losses)
   # Add the summaries out of the clone device block.
   if clone_loss is not None:
-    tf.summary.scalar(clone.scope + '/clone_loss', clone_loss)
+    tf.summary.scalar(clone.scope + '/clone_loss', clone_loss, family='Losses')
   if regularization_loss is not None:
-    tf.summary.scalar('regularization_loss', regularization_loss)
+    tf.summary.scalar('regularization_loss', regularization_loss,
+                      family='Losses')
   return sum_loss
...
@@ -18,7 +18,7 @@ from __future__ import division
 from __future__ import print_function
 import numpy as np
+from six.moves import xrange
 import tensorflow as tf
 layers = tf.contrib.layers
...
@@ -19,6 +19,8 @@ from __future__ import print_function
 from math import log
+from six.moves import xrange
 import tensorflow as tf
 slim = tf.contrib.slim
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from six.moves import xrange
 import tensorflow as tf
 from nets import dcgan
...
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """Convolution blocks for mobilenet."""
+import contextlib
 import functools
-import contextlib2
 import tensorflow as tf
@@ -75,6 +75,19 @@ def _split_divisible(num, num_ways, divisible_by=8):
   return result
+@contextlib.contextmanager
+def _v1_compatible_scope_naming(scope):
+  if scope is None:  # Create uniqified separable blocks.
+    with tf.variable_scope(None, default_name='separable') as s, \
+        tf.name_scope(s.original_name_scope):
+      yield ''
+  else:
+    # We use scope_depthwise, scope_pointwise for compatibility with V1 ckpts,
+    # which provide numbered scopes.
+    scope += '_'
+    yield scope
 @slim.add_arg_scope
 def split_separable_conv2d(input_tensor,
                            num_outputs,
@@ -110,15 +123,7 @@ def split_separable_conv2d(input_tensor,
     output tensor
   """
-  with contextlib2.ExitStack() as stack:
-    if scope is None:  # Create uniqified separable blocks.
-      s = stack.enter_context(tf.variable_scope(None, default_name='separable'))
-      stack.enter_context(tf.name_scope(s.original_name_scope))
-      scope = ''
-    else:
-      # We use scope_depthwise, scope_pointwise for compatibility with V1 ckpts.
-      scope += '_'
+  with _v1_compatible_scope_naming(scope) as scope:
     dw_scope = scope + 'depthwise'
     endpoints = endpoints if endpoints is not None else {}
     kernel_size = [3, 3]
...
@@ -22,8 +22,6 @@ import contextlib
 import copy
 import os
-import contextlib2
 import tensorflow as tf
@@ -76,17 +74,23 @@ def _set_arg_scope_defaults(defaults):
   """Sets arg scope defaults for all items present in defaults.
   Args:
-    defaults: dictionary mapping function to default_dict
+    defaults: dictionary/list of pairs, containing a mapping from
+      function to a dictionary of default args.
   Yields:
-    context manager
+    context manager where all defaults are set.
   """
-  with contextlib2.ExitStack() as stack:
-    _ = [
-        stack.enter_context(slim.arg_scope(func, **default_arg))
-        for func, default_arg in defaults.items()
-    ]
-    yield
+  if hasattr(defaults, 'items'):
+    items = defaults.items()
+  else:
+    items = defaults
+  if not items:
+    yield
+  else:
+    func, default_arg = items[0]
+    with slim.arg_scope(func, **default_arg):
+      with _set_arg_scope_defaults(items[1:]):
+        yield
 @slim.add_arg_scope
...
@@ -350,7 +350,7 @@ class MobilenetV1Test(tf.test.TestCase):
       mobilenet_v1.mobilenet_v1_base(inputs)
       total_params, _ = slim.model_analyzer.analyze_vars(
           slim.get_model_variables())
-      self.assertAlmostEqual(3217920L, total_params)
+      self.assertAlmostEqual(3217920, total_params)
   def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
     batch_size = 5
...
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import copy
 import tensorflow as tf
 from nets.nasnet import nasnet_utils
@@ -35,13 +36,12 @@ slim = tf.contrib.slim
 # cosine (single period) learning rate decay
 # auxiliary head loss weighting: 0.4
 # clip global norm of all gradients by 5
-def _cifar_config(is_training=True, use_aux_head=True):
-  drop_path_keep_prob = 1.0 if not is_training else 0.6
+def cifar_config():
   return tf.contrib.training.HParams(
       stem_multiplier=3.0,
-      drop_path_keep_prob=drop_path_keep_prob,
+      drop_path_keep_prob=0.6,
       num_cells=18,
-      use_aux_head=int(use_aux_head),
+      use_aux_head=1,
       num_conv_filters=32,
       dense_dropout_keep_prob=1.0,
       filter_scaling_rate=2.0,
@@ -65,16 +65,15 @@ def _cifar_config(is_training=True, use_aux_head=True):
 # auxiliary head loss weighting: 0.4
 # label smoothing: 0.1
 # clip global norm of all gradients by 10
-def _large_imagenet_config(is_training=True, use_aux_head=True):
-  drop_path_keep_prob = 1.0 if not is_training else 0.7
+def large_imagenet_config():
   return tf.contrib.training.HParams(
       stem_multiplier=3.0,
       dense_dropout_keep_prob=0.5,
       num_cells=18,
       filter_scaling_rate=2.0,
       num_conv_filters=168,
-      drop_path_keep_prob=drop_path_keep_prob,
+      drop_path_keep_prob=0.7,
-      use_aux_head=int(use_aux_head),
+      use_aux_head=1,
       num_reduction_layers=2,
       data_format='NHWC',
       skip_reduction_layer_input=1,
@@ -92,7 +91,7 @@ def _large_imagenet_config(is_training=True, use_aux_head=True):
 # auxiliary head weighting: 0.4
 # label smoothing: 0.1
 # clip global norm of all gradients by 10
-def _mobile_imagenet_config(use_aux_head=True):
+def mobile_imagenet_config():
   return tf.contrib.training.HParams(
       stem_multiplier=1.0,
       dense_dropout_keep_prob=0.5,
@@ -100,7 +99,7 @@ def _mobile_imagenet_config(use_aux_head=True):
       filter_scaling_rate=2.0,
       drop_path_keep_prob=1.0,
       num_conv_filters=44,
-      use_aux_head=int(use_aux_head),
+      use_aux_head=1,
       num_reduction_layers=2,
       data_format='NHWC',
       skip_reduction_layer_input=0,
@@ -108,6 +107,12 @@
   )
+def _update_hparams(hparams, is_training):
+  """Update hparams for given is_training option."""
+  if not is_training:
+    hparams.set_hparam('drop_path_keep_prob', 1.0)
 def nasnet_cifar_arg_scope(weight_decay=5e-4,
                            batch_norm_decay=0.9,
                            batch_norm_epsilon=1e-5):
@@ -279,10 +284,12 @@ def _cifar_stem(inputs, hparams):
   return net, [None, net]
-def build_nasnet_cifar(
-    images, num_classes, is_training=True, use_aux_head=True):
+def build_nasnet_cifar(images, num_classes,
+                       is_training=True,
+                       config=None):
   """Build NASNet model for the Cifar Dataset."""
-  hparams = _cifar_config(is_training=is_training, use_aux_head=use_aux_head)
+  hparams = cifar_config() if config is None else copy.deepcopy(config)
+  _update_hparams(hparams, is_training)
   if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
     tf.logging.info('A GPU is available on the machine, consider using NCHW '
@@ -326,9 +333,11 @@ build_nasnet_cifar.default_image_size = 32
 def build_nasnet_mobile(images, num_classes,
                         is_training=True,
                         final_endpoint=None,
-                        use_aux_head=True):
+                        config=None):
   """Build NASNet Mobile model for the ImageNet Dataset."""
-  hparams = _mobile_imagenet_config(use_aux_head=use_aux_head)
+  hparams = (mobile_imagenet_config() if config is None
+             else copy.deepcopy(config))
+  _update_hparams(hparams, is_training)
   if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
     tf.logging.info('A GPU is available on the machine, consider using NCHW '
@@ -375,10 +384,11 @@ build_nasnet_mobile.default_image_size = 224
 def build_nasnet_large(images, num_classes,
                        is_training=True,
                        final_endpoint=None,
-                       use_aux_head=True):
+                       config=None):
   """Build NASNet Large model for the ImageNet Dataset."""
-  hparams = _large_imagenet_config(is_training=is_training,
-                                   use_aux_head=use_aux_head)
+  hparams = (large_imagenet_config() if config is None
+             else copy.deepcopy(config))
+  _update_hparams(hparams, is_training)
   if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
     tf.logging.info('A GPU is available on the machine, consider using NCHW '
...
@@ -166,9 +166,11 @@ class NASNetTest(tf.test.TestCase):
     tf.reset_default_graph()
     inputs = tf.random_uniform((batch_size, height, width, 3))
     tf.train.create_global_step()
+    config = nasnet.cifar_config()
+    config.set_hparam('use_aux_head', int(use_aux_head))
     with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
       _, end_points = nasnet.build_nasnet_cifar(inputs, num_classes,
-                                                use_aux_head=use_aux_head)
+                                                config=config)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
   def testAllEndPointsShapesMobileModel(self):
@@ -215,9 +217,11 @@ class NASNetTest(tf.test.TestCase):
     tf.reset_default_graph()
     inputs = tf.random_uniform((batch_size, height, width, 3))
     tf.train.create_global_step()
+    config = nasnet.mobile_imagenet_config()
+    config.set_hparam('use_aux_head', int(use_aux_head))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes,
-                                                 use_aux_head=use_aux_head)
+                                                 config=config)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
   def testAllEndPointsShapesLargeModel(self):
@@ -270,9 +274,11 @@ class NASNetTest(tf.test.TestCase):
     tf.reset_default_graph()
     inputs = tf.random_uniform((batch_size, height, width, 3))
     tf.train.create_global_step()
+    config = nasnet.large_imagenet_config()
+    config.set_hparam('use_aux_head', int(use_aux_head))
     with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
       _, end_points = nasnet.build_nasnet_large(inputs, num_classes,
-                                                use_aux_head=use_aux_head)
+                                                config=config)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
   def testVariablesSetDeviceMobileModel(self):
@@ -323,6 +329,48 @@ class NASNetTest(tf.test.TestCase):
       output = sess.run(predictions)
       self.assertEquals(output.shape, (batch_size,))
+  def testOverrideHParamsCifarModel(self):
+    batch_size = 5
+    height, width = 32, 32
+    num_classes = 10
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    config = nasnet.cifar_config()
+    config.set_hparam('data_format', 'NCHW')
+    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
+      _, end_points = nasnet.build_nasnet_cifar(
+          inputs, num_classes, config=config)
+    self.assertListEqual(
+        end_points['Stem'].shape.as_list(), [batch_size, 96, 32, 32])
+  def testOverrideHParamsMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    config = nasnet.mobile_imagenet_config()
+    config.set_hparam('data_format', 'NCHW')
+    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
+      _, end_points = nasnet.build_nasnet_mobile(
+          inputs, num_classes, config=config)
+    self.assertListEqual(
+        end_points['Stem'].shape.as_list(), [batch_size, 88, 28, 28])
+  def testOverrideHParamsLargeModel(self):
+    batch_size = 5
+    height, width = 331, 331
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    config = nasnet.large_imagenet_config()
+    config.set_hparam('data_format', 'NCHW')
+    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
+      _, end_points = nasnet.build_nasnet_large(
+          inputs, num_classes, config=config)
+    self.assertListEqual(
+        end_points['Stem'].shape.as_list(), [batch_size, 336, 42, 42])
 if __name__ == '__main__':
   tf.test.main()