cifar10.cpp 3.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <migraph/onnx.hpp>

#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>

wsttiger's avatar
wsttiger committed
14
#include "softmax.hpp"
15

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/// Reads the first `nimages` records of a CIFAR-10 binary batch file.
///
/// Each record on disk is one label byte followed by 3072 pixel bytes.
/// Pixel bytes are converted to floats scaled by 1/255.
///
/// @param full_path path of the binary batch file to read
/// @param nimages   number of consecutive records to read from the start of
///                  the file (defaults to 10, the original fixed count)
/// @return pair of (labels, pixels): `labels[i]` is the label byte of record i,
///         and pixels of record i occupy `data[i * 3072, (i + 1) * 3072)`
/// @throws std::runtime_error if the file cannot be opened or holds fewer
///         than `nimages` complete records
auto read_cifar10_images(const std::string& full_path, std::size_t nimages = 10)
    -> std::pair<std::vector<std::uint8_t>, std::vector<float>>
{
    constexpr std::size_t nbytes_per_image  = 3072;
    constexpr std::size_t nbytes_per_record = nbytes_per_image + 1; // label byte + pixels

    std::ifstream file(full_path, std::ios::binary);
    if(!file.is_open())
    {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }

    std::vector<std::uint8_t> raw_data(nimages * nbytes_per_record);
    file.read(reinterpret_cast<char*>(raw_data.data()),
              static_cast<std::streamsize>(raw_data.size()));
    // A short read would otherwise leave the tail of raw_data zero-filled and
    // silently produce bogus labels/images.
    if(static_cast<std::size_t>(file.gcount()) != raw_data.size())
    {
        throw std::runtime_error("File `" + full_path + "` does not contain " +
                                 std::to_string(nimages) + " complete CIFAR-10 records!");
    }

    std::vector<std::uint8_t> labels(nimages);
    std::vector<float> data(nimages * nbytes_per_image);
    for(std::size_t i = 0; i < nimages; i++)
    {
        const std::uint8_t* record = raw_data.data() + i * nbytes_per_record;
        labels[i]                  = record[0];
        const std::uint8_t* pixels = record + 1;
        std::transform(pixels,
                       pixels + nbytes_per_image,
                       data.data() + i * nbytes_per_image,
                       [](std::uint8_t p) { return p / 255.0f; });
    }
    return std::make_pair(std::move(labels), std::move(data));
}

int main(int argc, char const* argv[])
{
wsttiger's avatar
wsttiger committed
49
    if(argc < 4)
50
51
52
53
54
55
56
57
58
59
    {
        throw std::runtime_error("Usage:  cifar10 [gpu | cpu] <onnx file> <cifar10 data file>");
    }
    std::string gpu_cpu  = argv[1];
    std::string file     = argv[2];
    std::string datafile = argv[3];
    auto prog            = migraph::parse_onnx(file);
    std::cout << prog << std::endl;
    auto imageset = read_cifar10_images(datafile);

wsttiger's avatar
wsttiger committed
60
    if(gpu_cpu == "gpu")
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    {
        // GPU target
        prog.compile(migraph::gpu::target{});
        migraph::program::parameter_map m;
        auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
        for(auto&& x : prog.get_parameter_shapes())
        {
            m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
        }
        auto labels = imageset.first;
        auto input  = imageset.second;
        auto ptr    = input.data();
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << "  ---->  ";
            m["0"]      = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]});
            auto result = migraph::gpu::from_gpu(prog.eval(m));
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
80
            std::vector<float> probs = softmax<float>(logits);
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
            for(auto x : probs)
                std::cout << x << "    ";
            std::cout << std::endl << std::endl;
        }
    }
    else
    {
        // CPU target
        prog.compile(migraph::cpu::cpu_target{});
        auto s      = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
        auto labels = imageset.first;
        auto input  = imageset.second;
        auto ptr    = input.data();
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << "  ---->  ";
            auto input3 = migraph::argument{s, &ptr[3072 * i]};
            auto result = prog.eval({{"0", input3}});
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
101
            std::vector<float> probs = softmax<float>(logits);
102
103
104
105
106
107
            for(auto x : probs)
                std::cout << x << "    ";
            std::cout << std::endl;
        }
    }
}