Unverified Commit b83cd632 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improve API for program parameters (#635)



* Take numpy array directly in python API

* Formatting

* Intialize program parameters from initializer list

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 16a03b39
...@@ -379,6 +379,13 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -379,6 +379,13 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
program_parameters() { this->make_handle(&migraphx_program_parameters_create); } program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
program_parameters(std::initializer_list<std::pair<std::string, argument>> l)
{
this->make_handle(&migraphx_program_parameters_create);
for(auto&& p : l)
this->add(p.first.c_str(), p.second);
}
void add(const char* pname, const argument& pargument) const void add(const char* pname, const argument& pargument) const
{ {
call(&migraphx_program_parameters_add, call(&migraphx_program_parameters_add,
......
...@@ -180,7 +180,18 @@ PYBIND11_MODULE(migraphx, m) ...@@ -180,7 +180,18 @@ PYBIND11_MODULE(migraphx, m)
}, },
py::arg("t"), py::arg("t"),
py::arg("offload_copy") = true) py::arg("offload_copy") = true)
.def("run", &migraphx::program::eval) .def("run",
[](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
py::buffer b = x.second.cast<py::buffer>();
py::buffer_info info = b.request();
pm[key] = migraphx::argument(to_shape(info), info.ptr);
}
return p.eval(pm);
})
.def("sort", &migraphx::program::sort) .def("sort", &migraphx::program::sort)
.def("__eq__", std::equal_to<migraphx::program>{}) .def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{}) .def("__ne__", std::not_equal_to<migraphx::program>{})
......
...@@ -22,6 +22,25 @@ TEST_CASE(load_and_run) ...@@ -22,6 +22,25 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
} }
TEST_CASE(load_and_run_init_list)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto shapes_before = p.get_output_shapes();
p.compile(migraphx::target("cpu"));
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1);
CHECK(shapes_before.size() == shapes_after.size());
CHECK(bool{shapes_before.front() == shapes_after.front()});
auto param_shapes = p.get_parameter_shapes();
EXPECT(param_shapes.size() == 3);
auto names = param_shapes.names();
auto outputs = p.eval({{names[0], migraphx::argument::generate(param_shapes[names[0]])},
{names[1], migraphx::argument::generate(param_shapes[names[1]])},
{names[2], migraphx::argument::generate(param_shapes[names[2]])}});
CHECK(shapes_before.size() == outputs.size());
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
TEST_CASE(quantize_fp16) TEST_CASE(quantize_fp16)
{ {
auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx"); auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
......
...@@ -31,11 +31,9 @@ def test_sub_uint64(): ...@@ -31,11 +31,9 @@ def test_sub_uint64():
params = {} params = {}
shapes = p.get_parameter_shapes() shapes = p.get_parameter_shapes()
arg0 = np.arange(120).reshape(shapes["0"].lens()).astype(np.uint64) params["0"] = np.arange(120).reshape(shapes["0"].lens()).astype(np.uint64)
arg1 = np.arange(20).reshape(shapes["1"].lens()).astype(np.uint64) params["1"] = np.arange(20).reshape(shapes["1"].lens()).astype(np.uint64)
params["0"] = migraphx.argument(arg0)
params["1"] = migraphx.argument(arg1)
r = p.run(params) r = p.run(params)
print(r) print(r)
...@@ -49,9 +47,8 @@ def test_neg_int64(): ...@@ -49,9 +47,8 @@ def test_neg_int64():
params = {} params = {}
shapes = p.get_parameter_shapes() shapes = p.get_parameter_shapes()
arg0 = np.arange(6).reshape(shapes["0"].lens()).astype(np.int64) params["0"] = np.arange(6).reshape(shapes["0"].lens()).astype(np.int64)
params["0"] = migraphx.argument(arg0)
r = p.run(params) r = p.run(params)
print(r) print(r)
...@@ -68,8 +65,8 @@ def test_fp16_imagescaler(): ...@@ -68,8 +65,8 @@ def test_fp16_imagescaler():
params = {} params = {}
shapes = p.get_parameter_shapes() shapes = p.get_parameter_shapes()
arg0 = np.random.randn(768).reshape(shapes["0"].lens()).astype(np.float16) params["0"] = np.random.randn(768).reshape(shapes["0"].lens()).astype(
params["0"] = migraphx.argument(arg0) np.float16)
r = p.run(params)[-1] r = p.run(params)[-1]
print(r) print(r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment