Unverified Commit d3b02131 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add CIFAR-10 dataset loader (#2245)

* fix typos

* add cifar-10

* open files in binary mode

* print messages with file name only, like mnist loader

* some fixes

* add mnist.cpp to CMakeLists.txt

* fix test index

* do not use iterator in cast

* add cifar.cpp to all

* Add Davis' suggestions

* no need to use namespace std and clean up empty lines
parent d9e58d66
...@@ -315,6 +315,7 @@ if (NOT TARGET dlib) ...@@ -315,6 +315,7 @@ if (NOT TARGET dlib)
cuda/tensor_tools.cpp cuda/tensor_tools.cpp
data_io/image_dataset_metadata.cpp data_io/image_dataset_metadata.cpp
data_io/mnist.cpp data_io/mnist.cpp
data_io/cifar.cpp
global_optimization/global_function_search.cpp global_optimization/global_function_search.cpp
filtering/kalman_filter.cpp filtering/kalman_filter.cpp
svm/auto.cpp svm/auto.cpp
......
...@@ -83,6 +83,7 @@ ...@@ -83,6 +83,7 @@
#include "../cuda/tensor_tools.cpp" #include "../cuda/tensor_tools.cpp"
#include "../data_io/image_dataset_metadata.cpp" #include "../data_io/image_dataset_metadata.cpp"
#include "../data_io/mnist.cpp" #include "../data_io/mnist.cpp"
#include "../data_io/cifar.cpp"
#include "../svm/auto.cpp" #include "../svm/auto.cpp"
#include "../global_optimization/global_function_search.cpp" #include "../global_optimization/global_function_search.cpp"
#include "../filtering/kalman_filter.cpp" #include "../filtering/kalman_filter.cpp"
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "data_io/libsvm_io.h" #include "data_io/libsvm_io.h"
#include "data_io/image_dataset_metadata.h" #include "data_io/image_dataset_metadata.h"
#include "data_io/mnist.h" #include "data_io/mnist.h"
#include "data_io/cifar.h"
#ifndef DLIB_ISO_CPP_ONLY #ifndef DLIB_ISO_CPP_ONLY
#include "data_io/load_image_dataset.h" #include "data_io/load_image_dataset.h"
......
// Copyright (C) 2020 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CIFAR_CPp_
#define DLIB_CIFAR_CPp_
#include "cifar.h"

#include <array>
#include <fstream>
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace impl
{
// Reads one CIFAR-10 batch file from folder_name and stores its contents
// into images/labels starting at position first_idx.  The binary format is
// one record per image: 1 label byte followed by 3*32*32 pixel bytes laid
// out as three full 32x32 planes in R, G, B order, each plane row-major.
//
// Requires:
//   - images.size()  >= first_idx + images_per_batch
//   - labels.size()  >= first_idx + images_per_batch
// Throws:
//   - dlib::error if the file can't be opened, is truncated, or contains
//     trailing bytes beyond the expected batch contents.
void load_cifar_10_batch (
    const std::string& folder_name,
    const std::string& batch_name,
    const size_t first_idx,
    const size_t images_per_batch,
    std::vector<matrix<rgb_pixel>>& images,
    std::vector<unsigned long>& labels
)
{
    std::ifstream fin((folder_name + "/" + batch_name).c_str(), std::ios::binary);
    if (!fin) throw error("Unable to open file " + batch_name);
    const long nr = 32;
    const long nc = 32;
    const long plane_size = nr * nc;
    const long image_size = 3 * plane_size;
    for (size_t i = 0; i < images_per_batch; ++i)
    {
        // Read the label as an unsigned byte.  Going through a plain
        // (possibly signed) char would sign-extend values above 127 into
        // huge unsigned long labels.
        unsigned char l;
        fin.read(reinterpret_cast<char*>(&l), 1);
        labels[first_idx + i] = l;
        images[first_idx + i].set_size(nr, nc);
        std::array<unsigned char, image_size> buffer;
        fin.read(reinterpret_cast<char*>(buffer.data()), buffer.size());
        for (long k = 0; k < plane_size; ++k)
        {
            // Keep samples as unsigned char all the way; a round-trip
            // through signed char is unnecessary and implementation-defined
            // pre-C++20.
            const unsigned char r = buffer[0 * plane_size + k];
            const unsigned char g = buffer[1 * plane_size + k];
            const unsigned char b = buffer[2 * plane_size + k];
            // Planes are row-major, so the row index is k divided by the
            // number of COLUMNS.  (nr == nc here, so k/nr happened to work,
            // but k/nc is the correct general formula.)
            const long row = k / nc;
            const long col = k % nc;
            images[first_idx + i](row, col) = rgb_pixel(r, g, b);
        }
    }
    if (!fin) throw error("Unable to read file " + batch_name);
    // A well-formed batch file ends exactly at the last pixel byte.
    if (fin.get() != EOF) throw error("Unexpected bytes at end of " + batch_name);
}
}
// Loads the complete CIFAR-10 dataset from folder_name: five training
// batches of 10000 images each and one test batch of 10000 images.
// See cifar_abstract.h for the full specification of this routine.
void load_cifar_10_dataset (
    const std::string& folder_name,
    std::vector<matrix<rgb_pixel>>& training_images,
    std::vector<unsigned long>& training_labels,
    std::vector<matrix<rgb_pixel>>& testing_images,
    std::vector<unsigned long>& testing_labels
)
{
    const size_t images_per_batch = 10000;
    const size_t num_training_batches = 5;
    const size_t num_testing_batches = 1;

    // Size the outputs up front so the batch loader can write in place.
    training_images.resize(num_training_batches * images_per_batch);
    training_labels.resize(num_training_batches * images_per_batch);
    testing_images.resize(num_testing_batches * images_per_batch);
    testing_labels.resize(num_testing_batches * images_per_batch);

    const std::vector<std::string> training_batches_names{
        "data_batch_1.bin",
        "data_batch_2.bin",
        "data_batch_3.bin",
        "data_batch_4.bin",
        "data_batch_5.bin",
    };

    // Each training batch fills the next images_per_batch slots.
    size_t offset = 0;
    for (const auto& batch_name : training_batches_names)
    {
        impl::load_cifar_10_batch(folder_name, batch_name, offset,
                                  images_per_batch, training_images, training_labels);
        offset += images_per_batch;
    }

    impl::load_cifar_10_batch(folder_name, "test_batch.bin", 0,
                              images_per_batch, testing_images, testing_labels);
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_CIFAR_CPp_
// Copyright (C) 2020 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CIFAR_Hh_
#define DLIB_CIFAR_Hh_
#include "cifar_abstract.h"
#include <string>
#include <vector>
#include "../matrix.h"
#include "../pixel.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
// Loads the full CIFAR-10 dataset (50000 training and 10000 testing
// images) from the batch files found in folder_name.  See
// cifar_abstract.h for the complete contract of this function.
void load_cifar_10_dataset (
const std::string& folder_name,
std::vector<matrix<rgb_pixel>>& training_images,
std::vector<unsigned long>& training_labels,
std::vector<matrix<rgb_pixel>>& testing_images,
std::vector<unsigned long>& testing_labels
);
}
// ----------------------------------------------------------------------------------------
#ifdef NO_MAKEFILE
#include "cifar.cpp"
#endif
#endif // DLIB_CIFAR_Hh_
// Copyright (C) 2020 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CIFAR_ABSTRACT_Hh_
#ifdef DLIB_CIFAR_ABSTRACT_Hh_
#include <string>
#include <vector>
#include "../matrix.h"
#include "../pixel.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_cifar_10_dataset (
const std::string& folder_name,
std::vector<matrix<rgb_pixel>>& training_images,
std::vector<unsigned long>& training_labels,
std::vector<matrix<rgb_pixel>>& testing_images,
std::vector<unsigned long>& testing_labels
);
/*!
ensures
- Attempts to load the CIFAR-10 dataset from the hard drive. The CIFAR-10
dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images
per class. There are 50000 training images and 10000 test images. It is
available from https://www.cs.toronto.edu/~kriz/cifar.html. In particular,
the 6 files comprising the CIFAR-10 dataset should be present in the folder
indicated by folder_name. These six files are:
- data_batch_1.bin
- data_batch_2.bin
- data_batch_3.bin
- data_batch_4.bin
- data_batch_5.bin
- test_batch.bin
- #training_images == The 50,000 training images from the dataset.
- #training_labels == The labels for the contents of #training_images.
I.e. #training_labels[i] is the label of #training_images[i].
(Per the CIFAR-10 distribution, labels are integers in the range [0,9],
one per class.)
- #testing_images == The 10,000 testing images from the dataset.
- #testing_labels == The labels for the contents of #testing_images.
I.e. #testing_labels[i] is the label of #testing_images[i].
throws
- dlib::error if some problem prevents us from loading the data or the files
can't be found.
!*/
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_CIFAR_ABSTRACT_Hh_
...@@ -59,24 +59,24 @@ namespace dlib ...@@ -59,24 +59,24 @@ namespace dlib
fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
if (magic != 2051 || num != 60000 || nr != 28 || nc != 28) if (magic != 2051 || num != 60000 || nr != 28 || nc != 28)
throw error("mndist dat files are corrupted."); throw error("mnist dat files are corrupted.");
fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2); fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2);
if (magic != 2049 || num2 != 60000) if (magic != 2049 || num2 != 60000)
throw error("mndist dat files are corrupted."); throw error("mnist dat files are corrupted.");
fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3); fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3);
fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28) if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28)
throw error("mndist dat files are corrupted."); throw error("mnist dat files are corrupted.");
fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4); fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4);
if (magic != 2049 || num4 != 10000) if (magic != 2049 || num4 != 10000)
throw error("mndist dat files are corrupted."); throw error("mnist dat files are corrupted.");
if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment