"tests/vscode:/vscode.git/clone" did not exist on "131312caba0af97da98fc498dfdca335c9692f8c"
Unverified Commit ee9e3636 authored by Tian Lin's avatar Tian Lin Committed by GitHub
Browse files

Add CIFAR-10 support for keras application models. (#5892)

* Add CIFAR-10 support for keras application models.
parent 91a59c78
...@@ -62,7 +62,6 @@ def run_keras_model_benchmark(_): ...@@ -62,7 +62,6 @@ def run_keras_model_benchmark(_):
# Load the model # Load the model
tf.logging.info("Benchmark on {} model...".format(FLAGS.model)) tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
keras_model = MODELS[FLAGS.model] keras_model = MODELS[FLAGS.model]
model = keras_model(weights=None)
# Get dataset # Get dataset
dataset_name = "ImageNet" dataset_name = "ImageNet"
...@@ -73,8 +72,15 @@ def run_keras_model_benchmark(_): ...@@ -73,8 +72,15 @@ def run_keras_model_benchmark(_):
FLAGS.model, FLAGS.batch_size) FLAGS.model, FLAGS.batch_size)
val_dataset = dataset.generate_synthetic_input_dataset( val_dataset = dataset.generate_synthetic_input_dataset(
FLAGS.model, FLAGS.batch_size) FLAGS.model, FLAGS.batch_size)
model = keras_model(weights=None)
else: else:
raise ValueError("Only synthetic dataset is supported!") tf.logging.info("Using CIFAR-10 dataset...")
dataset_name = "CIFAR-10"
ds = dataset.Cifar10Dataset(FLAGS.batch_size)
train_dataset = ds.train_dataset
val_dataset = ds.test_dataset
model = keras_model(
weights=None, input_shape=ds.input_shape, classes=ds.num_classes)
num_gpus = flags_core.get_num_gpus(FLAGS) num_gpus = flags_core.get_num_gpus(FLAGS)
......
...@@ -17,6 +17,7 @@ from __future__ import absolute_import ...@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import tensorflow as tf import tensorflow as tf
from official.utils.misc import model_helpers # pylint: disable=g-bad-import-order from official.utils.misc import model_helpers # pylint: disable=g-bad-import-order
...@@ -46,3 +47,28 @@ def generate_synthetic_input_dataset(model, batch_size): ...@@ -46,3 +47,28 @@ def generate_synthetic_input_dataset(model, batch_size):
label_shape=tf.TensorShape(label_shape), label_shape=tf.TensorShape(label_shape),
) )
return dataset return dataset
class Cifar10Dataset(object):
"""CIFAR10 dataset, including train and test set.
Each sample consists of a 32x32 color image, and label is from 10 classes.
"""
def __init__(self, batch_size):
"""Initializes train/test datasets.
Args:
batch_size: int, the number of batch size.
"""
self.input_shape = (32, 32, 3)
self.num_classes = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)
y_train = tf.keras.utils.to_categorical(y_train, self.num_classes)
y_test = tf.keras.utils.to_categorical(y_test, self.num_classes)
self.train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(2000).batch(batch_size)
self.test_dataset = tf.data.Dataset.from_tensor_slices(
(x_test, y_test)).shuffle(2000).batch(batch_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment