Commit d71cbd0c authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #1867 from tensorflow/fix-reversions

Fix reversions plus a few improvements
parents 8cedd479 00b801b0
...@@ -213,7 +213,7 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy | ...@@ -213,7 +213,7 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
^ ResNet V2 models use Inception pre-processing and input image size of 299 (use ^ ResNet V2 models use Inception pre-processing and input image size of 299 (use
`--preprocessing_name inception --eval_image_size 299` when using `--preprocessing_name inception --eval_image_size 299` when using
`eval_image_classifier.py`). Performance numbers for ResNet V2 models are `eval_image_classifier.py`). Performance numbers for ResNet V2 models are
reported on ImageNet valdiation set. reported on the ImageNet validation set.
All 16 MobileNet Models reported in the [MobileNet Paper](https://arxiv.org/abs/1704.04861) can be found [here](https://github.com/tensorflow/models/tree/master/slim/nets/mobilenet_v1.md). All 16 MobileNet Models reported in the [MobileNet Paper](https://arxiv.org/abs/1704.04861) can be found [here](https://github.com/tensorflow/models/tree/master/slim/nets/mobilenet_v1.md).
......
...@@ -100,10 +100,10 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): ...@@ -100,10 +100,10 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
""" """
if name not in networks_map: if name not in networks_map:
raise ValueError('Name of network unknown %s' % name) raise ValueError('Name of network unknown %s' % name)
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
func = networks_map[name] func = networks_map[name]
@functools.wraps(func) @functools.wraps(func)
def network_fn(images): def network_fn(images):
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
with slim.arg_scope(arg_scope): with slim.arg_scope(arg_scope):
return func(images, num_classes, is_training=is_training) return func(images, num_classes, is_training=is_training)
if hasattr(func, 'default_image_size'): if hasattr(func, 'default_image_size'):
......
...@@ -57,6 +57,11 @@ def get_preprocessing(name, is_training=False): ...@@ -57,6 +57,11 @@ def get_preprocessing(name, is_training=False):
'resnet_v1_50': vgg_preprocessing, 'resnet_v1_50': vgg_preprocessing,
'resnet_v1_101': vgg_preprocessing, 'resnet_v1_101': vgg_preprocessing,
'resnet_v1_152': vgg_preprocessing, 'resnet_v1_152': vgg_preprocessing,
'resnet_v1_200': vgg_preprocessing,
'resnet_v2_50': vgg_preprocessing,
'resnet_v2_101': vgg_preprocessing,
'resnet_v2_152': vgg_preprocessing,
'resnet_v2_200': vgg_preprocessing,
'vgg': vgg_preprocessing, 'vgg': vgg_preprocessing,
'vgg_a': vgg_preprocessing, 'vgg_a': vgg_preprocessing,
'vgg_16': vgg_preprocessing, 'vgg_16': vgg_preprocessing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment