Commit bff0223b authored by Scott Thornton's avatar Scott Thornton
Browse files

Added test for MNIST

parent 79fe7d41
...@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg) ...@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg)
add_executable(read_onnx read_onnx.cpp) add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx rtg_onnx rtg_cpu) target_link_libraries(read_onnx rtg_onnx rtg_cpu)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist rtg_onnx rtg_cpu)
#include <cstdio>
#include <string>
#include <fstream>
#include <stdexcept>
#include <rtg/onnx.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/generate.hpp>
std::vector<float> read_mnist_images(std::string full_path, int& number_of_images, int& image_size)
{
auto reverseInt = [](int i) {
unsigned char c1, c2, c3, c4;
c1 = i & 255;
c2 = (i >> 8) & 255;
c3 = (i >> 16) & 255;
c4 = (i >> 24) & 255;
return (static_cast<int>(c1) << 24) + (static_cast<int>(c2) << 16) +
(static_cast<int>(c3) << 8) + c4;
};
typedef unsigned char uchar;
std::ifstream file(full_path, std::ios::binary);
if(file.is_open())
{
int magic_number = 0, n_rows = 0, n_cols = 0;
file.read((char*)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
if(magic_number != 2051)
throw std::runtime_error("Invalid MNIST image file!");
file.read((char*)&number_of_images, sizeof(number_of_images)),
number_of_images = reverseInt(number_of_images);
file.read((char*)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
file.read((char*)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);
image_size = n_rows * n_cols;
printf("n_rows: %d n_cols: %d image_size: %d\n\n", n_rows, n_cols, image_size);
// uchar** _dataset = new uchar*[number_of_images];
// for(int i = 0; i < number_of_images; i++) {
// _dataset[i] = new uchar[image_size];
// file.read((char *)_dataset[i], image_size);
// }
std::vector<float> result(number_of_images * image_size);
for(int i = 0; i < number_of_images; i++)
{
for(int j = 0; j < image_size; j++)
{
uchar tmp;
file.read((char*)&tmp, 1);
result[i * image_size + j] = tmp / 255.0;
}
}
return result;
}
else
{
throw std::runtime_error("Cannot open file `" + full_path + "`!");
}
}
std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_labels)
{
auto reverseInt = [](int i) {
unsigned char c1, c2, c3, c4;
c1 = i & 255;
c2 = (i >> 8) & 255;
c3 = (i >> 16) & 255;
c4 = (i >> 24) & 255;
return (static_cast<int>(c1) << 24) + (static_cast<int>(c2) << 16) +
(static_cast<int>(c3) << 8) + c4;
};
typedef unsigned char uchar;
std::ifstream file(full_path, std::ios::binary);
if(file.is_open())
{
int magic_number = 0;
file.read((char*)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
if(magic_number != 2049)
throw std::runtime_error("Invalid MNIST label file!");
file.read((char*)&number_of_labels, sizeof(number_of_labels)),
number_of_labels = reverseInt(number_of_labels);
std::vector<int32_t> result(number_of_labels);
for(int i = 0; i < number_of_labels; i++)
{
uchar tmp;
file.read((char*)&tmp, 1);
result[i] = tmp;
}
return result;
}
else
{
throw std::runtime_error("Unable to open file `" + full_path + "`!");
}
}
int main(int argc, char const* argv[])
{
if(argc > 1)
{
std::string datafile = argv[2];
std::string labelfile = argv[3];
int nimages = -1;
int image_size = -1;
int nlabels = -1;
std::vector<float> input = read_mnist_images(datafile, nimages, image_size);
std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);
printf("label: %d\n\n", labels[0]);
for(int i = 7; i < 9; i++)
{
for(int j = 0; j < 28; j++)
{
printf("%8.5f ", input[i * 28 + j]);
}
printf("\n");
}
std::string file = argv[1];
auto prog = rtg::parse_onnx(file);
prog.compile(rtg::cpu::cpu_target{});
auto s = prog.get_parameter_shape("Input3");
std::cout << s << std::endl;
auto input3 = rtg::argument{s, input.data()};
auto out = prog.eval({{"Input3", input3}});
std::cout << out << std::endl;
std::cout << prog << std::endl;
}
}
...@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(ins.result, values); result = ins.op.compute(ins.result, values);
std::cout << "Debug: " << ins.op.name() << "\n";
if(result.get_shape().elements() > 0 and result.get_shape().packed() and
std::isnan(result.at<float>()))
std::cout << "Nan: " << ins.op.name() << std::endl;
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
......
...@@ -60,7 +60,11 @@ struct max_pool ...@@ -60,7 +60,11 @@ struct max_pool
static std::string name() { return "max"; } static std::string name() { return "max"; }
static double start() { return std::numeric_limits<double>::lowest(); } static double start() { return std::numeric_limits<double>::lowest(); }
static double apply(double x, double y) { return x + y; } static double apply(double x, double y)
{
double m = std::max(x, y);
return (m);
}
static double final(double x, double) { return (x); } static double final(double x, double) { return (x); }
}; };
...@@ -70,11 +74,7 @@ struct avg_pool ...@@ -70,11 +74,7 @@ struct avg_pool
static std::string name() { return "average"; } static std::string name() { return "average"; }
static double start() { return 0.0; } static double start() { return 0.0; }
static double apply(double x, double y) static double apply(double x, double y) { return x + y; }
{
double m = std::max(x, y);
return (m);
}
static double final(double x, double y) { return x / y; } static double final(double x, double y) { return x / y; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment