"src/targets/vscode:/vscode.git/clone" did not exist on "6155c7822a57a96b95ed683887859b53f2a30e73"
// mnist.cpp
//
// Command-line sample: loads an MNIST image file and label file, compiles an
// ONNX model for the rtg CPU target, evaluates the model on the leading
// images and prints the softmax of each output next to the expected label.
//
// Arguments: argv[1] = ONNX model, argv[2] = MNIST image file,
//            argv[3] = MNIST label file.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include <rtg/onnx.hpp>

#include <rtg/cpu/cpu_target.hpp>
#include <rtg/generate.hpp>

/// Loads an MNIST image file and returns its pixels scaled by 1/255.
///
/// The header is four big-endian 32-bit integers: the magic number 2051, the
/// image count, the row count and the column count. One unsigned byte per
/// pixel follows.
///
/// @param full_path         Path of the image file.
/// @param number_of_images  Out: image count declared in the header.
/// @param image_size        Out: pixels per image (rows * columns).
/// @return All pixels, image after image, each byte divided by 255.
/// @throws std::runtime_error if the file cannot be opened, the header is
///         missing or implausible, or the pixel data is shorter than declared.
std::vector<float> read_mnist_images(std::string full_path, int& number_of_images, int& image_size)
{
    std::ifstream file(full_path, std::ios::binary);
    if(!file.is_open())
        throw std::runtime_error("Cannot open file `" + full_path + "`!");

    // Assemble header fields from individual bytes: this is independent of the
    // host byte order and never shifts a signed value.
    auto read_be32 = [&file]() -> std::uint32_t {
        unsigned char b[4] = {};
        if(!file.read(reinterpret_cast<char*>(b), sizeof(b)))
            throw std::runtime_error("Invalid MNIST image file!");
        return (std::uint32_t{b[0]} << 24) | (std::uint32_t{b[1]} << 16) |
               (std::uint32_t{b[2]} << 8) | std::uint32_t{b[3]};
    };

    if(read_be32() != 2051u)
        throw std::runtime_error("Invalid MNIST image file!");

    const std::uint32_t n_images = read_be32();
    const std::uint32_t n_rows   = read_be32();
    const std::uint32_t n_cols   = read_be32();

    // Every value handed back through an int must fit in one; do the
    // multiplications in 64 bits so they cannot overflow first.
    const std::uint64_t int_max          = std::numeric_limits<int>::max();
    const std::uint64_t pixels_per_image = std::uint64_t{n_rows} * n_cols;
    if(n_images > int_max || n_rows > int_max || n_cols > int_max || pixels_per_image > int_max)
        throw std::runtime_error("Invalid MNIST image file!");

    number_of_images = static_cast<int>(n_images);
    image_size       = static_cast<int>(pixels_per_image);

    std::printf("n_rows: %d    n_cols: %d    image_size: %d\n\n",
                static_cast<int>(n_rows),
                static_cast<int>(n_cols),
                image_size);

    std::vector<float> result;
    const std::uint64_t total = pixels_per_image * n_images;
    if(total > result.max_size())
        throw std::runtime_error("Invalid MNIST image file!");
    result.resize(static_cast<std::size_t>(total));

    // Read one whole image per call instead of one byte per call, and fail
    // loudly when the file ends before the declared amount of data.
    std::vector<unsigned char> raw(static_cast<std::size_t>(pixels_per_image));
    auto out = result.begin();
    for(std::uint32_t img = 0; img < n_images && !raw.empty(); ++img)
    {
        if(!file.read(reinterpret_cast<char*>(raw.data()),
                      static_cast<std::streamsize>(raw.size())))
            throw std::runtime_error("Truncated MNIST image file!");
        for(unsigned char byte : raw)
            *out++ = static_cast<float>(byte / 255.0);
    }
    return result;
}

/// Loads an MNIST label file.
///
/// The header is two big-endian 32-bit integers: the magic number 2049 and the
/// label count. One unsigned byte per label follows.
///
/// @param full_path         Path of the label file.
/// @param number_of_labels  Out: label count declared in the header.
/// @return The labels in file order, widened to int32_t.
/// @throws std::runtime_error if the file cannot be opened, the header is
///         missing or implausible, or fewer labels are present than declared.
std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_labels)
{
    std::ifstream file(full_path, std::ios::binary);
    if(!file.is_open())
        throw std::runtime_error("Unable to open file `" + full_path + "`!");

    // Assemble header fields from individual bytes: this is independent of the
    // host byte order and never shifts a signed value.
    auto read_be32 = [&file]() -> std::uint32_t {
        unsigned char b[4] = {};
        if(!file.read(reinterpret_cast<char*>(b), sizeof(b)))
            throw std::runtime_error("Invalid MNIST label file!");
        return (std::uint32_t{b[0]} << 24) | (std::uint32_t{b[1]} << 16) |
               (std::uint32_t{b[2]} << 8) | std::uint32_t{b[3]};
    };

    if(read_be32() != 2049u)
        throw std::runtime_error("Invalid MNIST label file!");

    // The count is returned through an int, so it has to fit in one.
    const std::uint32_t declared = read_be32();
    if(declared > static_cast<std::uint32_t>(std::numeric_limits<int>::max()))
        throw std::runtime_error("Invalid MNIST label file!");
    number_of_labels = static_cast<int>(declared);

    // Fetch all labels with a single read and fail loudly on a short file.
    std::vector<unsigned char> raw(declared);
    if(!raw.empty() && !file.read(reinterpret_cast<char*>(raw.data()),
                                  static_cast<std::streamsize>(raw.size())))
        throw std::runtime_error("Truncated MNIST label file!");

    return std::vector<int32_t>(raw.begin(), raw.end());
}

// -----------------------------------------------------------------------------
/// Computes the softmax of a vector: exp(p[i]) / sum_j exp(p[j]).
///
/// The largest element is subtracted from every entry before exponentiating.
/// This leaves the result mathematically unchanged but keeps exp() from
/// overflowing to infinity (and the division from producing NaN) when the
/// inputs are large.
///
/// @param p  Input values; taken by value and reused as the output buffer.
/// @return A vector of the same length as p; empty when p is empty.
std::vector<float> softmax(std::vector<float> p)
{
    if(p.empty())
        return p;

    const float peak = *std::max_element(p.begin(), p.end());

    float total = 0.0f;
    for(float& v : p)
    {
        v = std::exp(v - peak);
        total += v;
    }
    // total >= 1 here because the peak element contributes exp(0) == 1.
    for(float& v : p)
        v /= total;
    return p;
}

// -----------------------------------------------------------------------------
/// Evaluates an ONNX model on the leading images of an MNIST data set.
///
/// Usage: <program> <model.onnx> <mnist-images> <mnist-labels>
///
/// For each evaluated image the expected label is printed, followed by the
/// softmax of the model output.
///
/// @return 0 on success; 1 when an argument is missing or the images do not
///         match the 1x1x28x28 input shape used below.
int main(int argc, char const* argv[])
{
    // All three paths are required: argv[2] and argv[3] must not be read
    // unless argc says they exist.
    if(argc < 4)
    {
        std::fprintf(stderr,
                     "Usage: %s <model.onnx> <mnist-images> <mnist-labels>\n",
                     argc > 0 ? argv[0] : "mnist");
        return 1;
    }

    const std::string file      = argv[1];
    const std::string datafile  = argv[2];
    const std::string labelfile = argv[3];

    int nimages                 = -1;
    int image_size              = -1;
    int nlabels                 = -1;
    std::vector<float> input    = read_mnist_images(datafile, nimages, image_size);
    std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);

    auto prog = rtg::parse_onnx(file);
    prog.compile(rtg::cpu::cpu_target{});

    // auto s = prog.get_parameter_shape("Input3");
    auto s = rtg::shape{rtg::shape::float_type, {1, 1, 28, 28}};
    std::cout << s << std::endl;

    // Each evaluation consumes one 1x1x28x28 block of floats, so the images
    // on disk must have exactly that many pixels.
    constexpr int pixels_per_input = 1 * 1 * 28 * 28;
    if(image_size != pixels_per_input)
    {
        std::fprintf(stderr,
                     "Expected %d pixels per image but the image file has %d\n",
                     pixels_per_input,
                     image_size);
        return 1;
    }

    // Show at most 20 samples, and never index past the loaded data.
    const int count = std::min({20, nimages, nlabels});
    for(int i = 0; i < count; i++)
    {
        std::printf("label: %d  ---->  ", labels[i]);
        auto input3 = rtg::argument{s, input.data() + pixels_per_input * i};
        auto result = prog.eval({{"Input3", input3}});
        std::vector<float> logits;
        result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
        std::vector<float> probs = softmax(logits);
        for(auto x : probs)
            std::printf("%8.4f    ", x);
        std::printf("\n");
    }
    std::printf("\n");
    return 0;
}