Unverified Commit 2472278c authored by Yanhui Liang's avatar Yanhui Liang Committed by GitHub
Browse files

Fix typos of model name (#5063)

parent 83a9a239
...@@ -18,7 +18,6 @@ from __future__ import division ...@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.utils.misc import model_helpers # pylint: disable=g-bad-import-order from official.utils.misc import model_helpers # pylint: disable=g-bad-import-order
# Default values for dataset. # Default values for dataset.
...@@ -29,7 +28,7 @@ _NUM_CLASSES = 1000 ...@@ -29,7 +28,7 @@ _NUM_CLASSES = 1000
def _get_default_image_size(model): def _get_default_image_size(model):
"""Provide default image size for each model.""" """Provide default image size for each model."""
image_size = (224, 224) image_size = (224, 224)
if model in ["inception", "xception", "inceptionresnet"]: if model in ["inceptionv3", "xception", "inceptionresnetv2"]:
image_size = (299, 299) image_size = (299, 299)
elif model in ["nasnetlarge"]: elif model in ["nasnetlarge"]:
image_size = (331, 331) image_size = (331, 331)
...@@ -42,8 +41,8 @@ def generate_synthetic_input_dataset(model, batch_size): ...@@ -42,8 +41,8 @@ def generate_synthetic_input_dataset(model, batch_size):
image_shape = (batch_size,) + image_size + (_NUM_CHANNELS,) image_shape = (batch_size,) + image_size + (_NUM_CHANNELS,)
label_shape = (batch_size, _NUM_CLASSES) label_shape = (batch_size, _NUM_CLASSES)
return model_helpers.generate_synthetic_data( dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape(image_shape), input_shape=tf.TensorShape(image_shape),
input_dtype=tf.float32,
label_shape=tf.TensorShape(label_shape), label_shape=tf.TensorShape(label_shape),
label_dtype=tf.float32) )
return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment