// cifar10.cpp: classify CIFAR-10 images with an ONNX model using MIGraphX.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <migraphx/onnx.hpp>

#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>

#include "softmax.hpp"
/// Read the first `nimages` records of a CIFAR-10 binary batch file.
///
/// Each record is decoded as one label byte followed by 3072 pixel bytes;
/// every pixel byte is scaled into [0, 1] by dividing by 255.
///
/// @param full_path Path of the binary batch file.
/// @param nimages   Number of leading records to read (defaults to 10).
/// @return Pair of (labels, pixels). `labels` has `nimages` entries and
///         `pixels` has `nimages * 3072` floats, image i starting at
///         offset `i * 3072`.
/// @throws std::runtime_error if the file cannot be opened or contains
///         fewer than `nimages` complete records.
auto read_cifar10_images(const std::string& full_path, std::size_t nimages = 10)
{
    constexpr std::size_t nbytes_per_image  = 3072;
    // One leading label byte precedes the pixel bytes of every record.
    constexpr std::size_t nbytes_per_record = nbytes_per_image + 1;

    std::ifstream file(full_path, std::ios::binary);
    if(!file.is_open())
    {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }

    std::vector<uint8_t> raw_data(nimages * nbytes_per_record);
    file.read(reinterpret_cast<char*>(raw_data.data()),
              static_cast<std::streamsize>(raw_data.size()));
    // A short read would otherwise leave the tail zero-filled and silently
    // produce bogus labels/pixels.
    if(static_cast<std::size_t>(file.gcount()) != raw_data.size())
    {
        throw std::runtime_error("File `" + full_path + "` is too short: expected " +
                                 std::to_string(nimages) + " CIFAR-10 records");
    }

    std::vector<uint8_t> labels(nimages);
    std::vector<float> data(nimages * nbytes_per_image);
    for(std::size_t i = 0; i < nimages; ++i)
    {
        const uint8_t* record = raw_data.data() + i * nbytes_per_record;
        labels[i]             = record[0];
        std::transform(record + 1,
                       record + nbytes_per_record,
                       data.data() + i * nbytes_per_image,
                       [](uint8_t byte) { return byte / 255.0f; });
    }
    return std::make_pair(std::move(labels), std::move(data));
}

int main(int argc, char const* argv[])
{
wsttiger's avatar
wsttiger committed
49
    if(argc < 4)
50
51
52
53
54
55
    {
        throw std::runtime_error("Usage:  cifar10 [gpu | cpu] <onnx file> <cifar10 data file>");
    }
    std::string gpu_cpu  = argv[1];
    std::string file     = argv[2];
    std::string datafile = argv[3];
Paul's avatar
Paul committed
56
    auto prog            = migraphx::parse_onnx(file);
57
58
59
    std::cout << prog << std::endl;
    auto imageset = read_cifar10_images(datafile);

wsttiger's avatar
wsttiger committed
60
    if(gpu_cpu == "gpu")
61
62
    {
        // GPU target
Paul's avatar
Paul committed
63
64
65
        prog.compile(migraphx::gpu::target{});
        migraphx::program::parameter_map m;
        auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
66
67
        for(auto&& x : prog.get_parameter_shapes())
        {
Paul's avatar
Paul committed
68
            m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
69
70
71
72
73
74
75
        }
        auto labels = imageset.first;
        auto input  = imageset.second;
        auto ptr    = input.data();
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << "  ---->  ";
Paul's avatar
Paul committed
76
77
            m["0"]      = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
            auto result = migraphx::gpu::from_gpu(prog.eval(m));
78
79
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
80
            std::vector<float> probs = softmax<float>(logits);
81
82
83
84
85
86
87
88
            for(auto x : probs)
                std::cout << x << "    ";
            std::cout << std::endl << std::endl;
        }
    }
    else
    {
        // CPU target
Paul's avatar
Paul committed
89
90
        prog.compile(migraphx::cpu::target{});
        auto s      = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
91
92
93
94
95
96
        auto labels = imageset.first;
        auto input  = imageset.second;
        auto ptr    = input.data();
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << "  ---->  ";
Paul's avatar
Paul committed
97
            auto input3 = migraphx::argument{s, &ptr[3072 * i]};
98
99
100
            auto result = prog.eval({{"0", input3}});
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
101
            std::vector<float> probs = softmax<float>(logits);
102
103
104
105
106
107
            for(auto x : probs)
                std::cout << x << "    ";
            std::cout << std::endl;
        }
    }
}