// test_save_load.cpp: round-trip tests for migraphx::save / migraphx::load
#include <cstdio>
#include <string>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"

TEST_CASE(load_save_default)
{
    // Parse an ONNX model, save it with the default file options, load it back,
    // and check that the reloaded program matches the original.
    const std::string filename = "migraphx_api_load_save.mxr";
    auto p1                    = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
    auto s1                    = p1.get_output_shapes();

    migraphx::save(p1, filename.c_str());
    auto p2 = migraphx::load(filename.c_str());
    auto s2 = p2.get_output_shapes();

    // Require at least one output before comparing front() elements below.
    EXPECT(s1.size() > 0u);
    EXPECT(s1.size() == s2.size());
    EXPECT(bool{s1.front() == s2.front()});
    EXPECT(bool{p1.sort() == p2.sort()});

    // Delete the file written by migraphx::save above.
    std::remove(filename.c_str());
}

TEST_CASE(load_save_json)
{
    // Parse an ONNX model, save it with the file format set to "json", load it
    // back with the same options, and check that the reloaded program matches.
    const std::string filename = "migraphx_api_load_save.json";
    auto p1                    = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
    auto s1                    = p1.get_output_shapes();

    // The same options object is passed to both save and load so the format agrees.
    migraphx::file_options options;
    options.set_file_format("json");

    migraphx::save(p1, filename.c_str(), options);
    auto p2 = migraphx::load(filename.c_str(), options);
    auto s2 = p2.get_output_shapes();

    // Require at least one output before comparing front() elements below.
    EXPECT(s1.size() > 0u);
    EXPECT(s1.size() == s2.size());
    EXPECT(bool{s1.front() == s2.front()});
    EXPECT(bool{p1.sort() == p2.sort()});

    // Delete the file written by migraphx::save above.
    std::remove(filename.c_str());
}

int main(int argc, const char* argv[])
{
    // Forward the command line to the driver from "test.hpp"; presumably it
    // runs the TEST_CASEs defined in this file (TODO confirm argv filtering).
    test::run(argc, argv);
    return 0;
}