test_cpu.cpp 2.51 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"

// Parse an ONNX model, compile it for the "cpu" target, and evaluate it once.
// Checks that the model reports exactly one output, that compilation leaves the
// output shapes unchanged, and that evaluation yields outputs of those shapes.
TEST_CASE(load_and_run)
{
    auto p             = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
    auto shapes_before = p.get_output_shapes();
    p.compile(migraphx::target("cpu"));
    auto shapes_after = p.get_output_shapes();
    // Compilation must not alter the number or the shape of the outputs.
    CHECK(shapes_before.size() == 1);
    CHECK(shapes_before.size() == shapes_after.size());
    CHECK(bool{shapes_before.front() == shapes_after.front()});
    // Bind every program parameter to an argument produced by
    // migraphx::argument::generate for that parameter's declared shape.
    migraphx::program_parameters pp;
    auto param_shapes = p.get_parameter_shapes();
    for(auto&& name : param_shapes.names())
    {
        pp.add(name, migraphx::argument::generate(param_shapes[name]));
    }
    // The evaluated results must agree with the shapes reported before compiling.
    auto outputs = p.eval(pp);
    CHECK(shapes_before.size() == outputs.size());
    CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}

// Same compile-and-run flow as load_and_run, but the shape of the ONNX input
// named "0" is overridden to {2, 3, 64, 64} through onnx_options before parsing.
TEST_CASE(load_and_run_user_input_shape)
{
    migraphx::onnx_options options;
    options.set_input_parameter_shape("0", {2, 3, 64, 64});
    auto p             = migraphx::parse_onnx("conv_relu_maxpool_test.onnx", options);
    auto shapes_before = p.get_output_shapes();
    p.compile(migraphx::target("cpu"));
    auto shapes_after = p.get_output_shapes();
    // Compilation must not alter the number or the shape of the outputs.
    CHECK(shapes_before.size() == 1);
    CHECK(shapes_before.size() == shapes_after.size());
    CHECK(bool{shapes_before.front() == shapes_after.front()});
    // Bind every program parameter (including the overridden "0" input) to an
    // argument produced by migraphx::argument::generate for its shape.
    migraphx::program_parameters pp;
    auto param_shapes = p.get_parameter_shapes();
    for(auto&& name : param_shapes.names())
    {
        pp.add(name, migraphx::argument::generate(param_shapes[name]));
    }
    // The evaluated results must agree with the shapes reported before compiling.
    auto outputs = p.eval(pp);
    CHECK(shapes_before.size() == outputs.size());
    CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}

Shucai Xiao's avatar
Shucai Xiao committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Compile-and-run check for constant_fill_test.onnx. Judging by the test name,
// this model presumably has no parameters, so the binding loop may run zero
// times. NOTE(review): confirm against the model file.
TEST_CASE(zero_parameter)
{
    auto prog            = migraphx::parse_onnx("constant_fill_test.onnx");
    auto expected_shapes = prog.get_output_shapes();
    prog.compile(migraphx::target("cpu"));
    auto compiled_shapes = prog.get_output_shapes();

    // Exactly one output, unchanged by compilation.
    CHECK(expected_shapes.size() == 1);
    CHECK(expected_shapes.size() == compiled_shapes.size());
    CHECK(bool{expected_shapes.front() == compiled_shapes.front()});

    // Generate an argument for whatever parameters the program declares.
    migraphx::program_parameters inputs;
    auto input_shapes = prog.get_parameter_shapes();
    for(auto&& input_name : input_shapes.names())
    {
        auto generated = migraphx::argument::generate(input_shapes[input_name]);
        inputs.add(input_name, generated);
    }

    // Evaluation must produce outputs matching the pre-compilation shapes.
    auto results = prog.eval(inputs);
    CHECK(expected_shapes.size() == results.size());
    CHECK(bool{expected_shapes.front() == results.front().get_shape()});
}

Paul Fultz II's avatar
Paul Fultz II committed
67
// Entry point: forward the command-line arguments to the test harness's
// test::run (from test.hpp).
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}