# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities for the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import logging
import tensorflow as tf

from official.legacy.image_classification.resnet import imagenet_preprocessing

HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
# The record is the image plus a one-byte label.
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1

# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}
_NUM_DATA_FILES = 5
NUM_CLASSES = 10


def parse_record(raw_record, is_training, dtype):
  """Parses a record containing a training example of an image.

  The input record is parsed into a label and image, and the image is passed
  through preprocessing steps (cropping, flipping, and so on). The label is
  returned as a scalar int32 tensor; any one-hot encoding required by the loss
  function is left to the caller.

  Args:
    raw_record: scalar tf.string Tensor containing the bytes of a single
      fixed-length CIFAR-10 record (a one-byte label followed by the image
      bytes).
    is_training: A boolean denoting whether the input is for training.
    dtype: Data type to use for input images.

  Returns:
    Tuple with processed image tensor and scalar int32 label tensor.
  """
  # Convert bytes to a vector of uint8 that is record_bytes long.
  record_vector = tf.io.decode_raw(raw_record, tf.uint8)

  # The first byte represents the label, which we convert from uint8 to int32.
  label = tf.cast(record_vector[0], tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                           [NUM_CHANNELS, HEIGHT, WIDTH])

  # Convert from [depth, height, width] to [height, width, depth], and cast as
  # float32.
  image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)

  image = preprocess_image(image, is_training)
  image = tf.cast(image, dtype)

  return image, label
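

# Example usage (a minimal sketch, not part of this module's public API; the
# file path below is illustrative only): map `parse_record` over a
# `tf.data.FixedLengthRecordDataset` built from one of the raw CIFAR-10
# binaries, mirroring what `input_fn` does internally.
#
#   raw_dataset = tf.data.FixedLengthRecordDataset(
#       ['/tmp/cifar10_data/cifar-10-batches-bin/data_batch_1.bin'],
#       _RECORD_BYTES)
#   dataset = raw_dataset.map(
#       lambda record: parse_record(record, is_training=False,
#                                   dtype=tf.float32))
#   image, label = next(iter(dataset))
#   # image: float32 tensor of shape [32, 32, 3]; label: scalar int32 tensor.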


def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth]."""
  if is_training:
    # Resize the image to add four extra pixels on each side.
    image = tf.image.resize_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)

    # Randomly crop a [HEIGHT, WIDTH] section of the image.
    image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])

    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)

  # Subtract off the mean and divide by the standard deviation of the pixels.
  image = tf.image.per_image_standardization(image)
  return image


def get_filenames(is_training, data_dir):
  """Returns a list of filenames."""
  assert tf.io.gfile.exists(data_dir), (
      'Run cifar10_download_and_extract.py first to download and extract the '
      'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  else:
    return [os.path.join(data_dir, 'test_batch.bin')]


def input_fn(is_training,
             data_dir,
             batch_size,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    dtype: Data type to use for images/features.
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicating whether to drop the remainder of the
      batches. If True, the batch dimension will be static.

  Returns:
    A dataset that can be used for iteration.
  """
  filenames = get_filenames(is_training, data_dir)
  dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

  if input_context:
    logging.info(
        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
        input_context.input_pipeline_id, input_context.num_input_pipelines)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  return imagenet_preprocessing.process_record_dataset(
      dataset=dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=NUM_IMAGES['train'],
      parse_record_fn=parse_record_fn,
      dtype=dtype,
      datasets_num_private_threads=datasets_num_private_threads,
      drop_remainder=drop_remainder)
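

# Example usage (a minimal sketch; `/tmp/cifar10_data/cifar-10-batches-bin` is
# an illustrative path and must already contain the extracted CIFAR-10
# binaries, e.g. produced by cifar10_download_and_extract.py):
#
#   train_dataset = input_fn(
#       is_training=True,
#       data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
#       batch_size=128)
#   for images, labels in train_dataset.take(1):
#     print(images.shape, labels.shape)  # (128, 32, 32, 3) (128,)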