Unverified Commit 5e5ed37a authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Expose `add_literal` in C and Python API (#1173)

Expose add_literal method in C/C++ api
parent ddbbe54b
...@@ -146,6 +146,13 @@ module ...@@ -146,6 +146,13 @@ module
:param list[module] mod_args: optional list of module arguments to the operator. :param list[module] mod_args: optional list of module arguments to the operator.
:rtype instruction :rtype instruction
.. py:method:: add_literal(data)
Adds constant or literal data of provided shape into the module from python buffer which includes numpy array.
:param py::buffer data: Python buffer or numpy array
:rtype instruction
.. py:method:: add_parameter(name, shape) .. py:method:: add_parameter(name, shape)
Adds a parameter to the module with provided name and shape. Adds a parameter to the module with provided name and shape.
......
...@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, ...@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_literal((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const char* name, const char* name,
......
...@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi ...@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
migraphx_instructions_t args, migraphx_instructions_t args,
migraphx_modules_t module_refs); migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const char* name, const char* name,
......
...@@ -762,6 +762,15 @@ struct module ...@@ -762,6 +762,15 @@ struct module
return instruction(op_ins, own{}); return instruction(op_ins, own{});
} }
template <typename T>
instruction add_literal(const migraphx::shape& s, T* buffer)
{
migraphx_instruction_t literal_ins;
const auto* buffer_ptr = reinterpret_cast<const char*>(buffer);
call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr);
return instruction(literal_ins, own{});
}
instruction add_parameter(const std::string& name, shape s) instruction add_parameter(const std::string& name, shape s)
{ {
migraphx_instruction_t param_ins; migraphx_instruction_t param_ins;
......
...@@ -212,6 +212,9 @@ def module(h): ...@@ -212,6 +212,9 @@ def module(h):
module_refs='std::vector<migraphx::module*>'), module_refs='std::vector<migraphx::module*>'),
fname='add_instruction', fname='add_instruction',
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
h.method('add_literal',
api.params(shape='const migraphx::shape&', buffer='const char*'),
returns='migraphx::instruction_ref')
h.method('add_parameter', h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'), api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
......
...@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("op"), py::arg("op"),
py::arg("args"), py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{}) py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_literal",
[](migraphx::module& mm, py::buffer data) {
py::buffer_info info = data.request();
auto literal_shape = to_shape(info);
return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
},
py::arg("data"))
.def( .def(
"add_parameter", "add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) { [](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
......
...@@ -3,23 +3,21 @@ ...@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(add_op) TEST_CASE(add_literals)
{ {
migraphx::program p; migraphx::program p;
migraphx::module m = p.get_main_module(); migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}}; migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape); std::vector<float> x_values(9, 1);
auto y = m.add_parameter("y", param_shape); auto x = m.add_literal(param_shape, x_values.data());
std::vector<float> y_values(9, -1);
auto y = m.add_literal(param_shape, y_values.data());
auto add_op = migraphx::operation("add"); auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y}); auto r = m.add_instruction(add_op, {x, y});
m.add_return({r}); m.add_return({r});
// run on ref target // run on ref target
p.compile(migraphx::target("ref")); p.compile(migraphx::target("ref"));
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
std::vector<float> expected(9, 0); std::vector<float> expected(9, 0);
......
import migraphx import migraphx, array, sys
def create_buffer(t, data, shape):
a = array.array(t, data)
m = memoryview(a.tobytes())
return m.cast(t, shape)
def test_add_op(): def test_add_op():
p = migraphx.program() p = migraphx.program()
mm = p.get_main_module() mm = p.get_main_module()
param_shape = migraphx.shape(lens=[3, 3], type="float") x = mm.add_literal(create_buffer('f', [1.0] * 9, (3, 3)))
x = mm.add_parameter("x", param_shape) y = mm.add_literal(create_buffer('f', [2.0] * 9, (3, 3)))
y = mm.add_parameter("y", param_shape)
add_op = mm.add_instruction(migraphx.op("add"), [x, y]) add_op = mm.add_instruction(migraphx.op("add"), [x, y])
mm.add_return([add_op]) mm.add_return([add_op])
p.compile(migraphx.get_target("ref")) p.compile(migraphx.get_target("ref"))
params = {} params = {}
params["x"] = migraphx.generate_argument(param_shape)
params["y"] = migraphx.generate_argument(param_shape)
output = p.run(params)[-1].tolist() output = p.run(params)[-1].tolist()
assert output == [ assert output == list([3.0] * 9)
a + b for a, b in zip(params["x"].tolist(), params["y"].tolist())
]
def test_if_then_else(): def test_if_then_else():
...@@ -60,5 +61,6 @@ def test_if_then_else(): ...@@ -60,5 +61,6 @@ def test_if_then_else():
if __name__ == "__main__": if __name__ == "__main__":
test_add_op() if sys.version_info >= (3, 0):
test_add_op()
test_if_then_else() test_if_then_else()
import migraphx, sys
try:
import numpy as np
except:
sys.exit()
def test_add_op():
p = migraphx.program()
mm = p.get_main_module()
x = mm.add_literal(np.ones((3, 3), dtype='float32'))
y = mm.add_literal(2 * np.ones((3, 3), dtype='float32'))
add_op = mm.add_instruction(migraphx.op("add"), [x, y])
mm.add_return([add_op])
p.compile(migraphx.get_target("ref"))
params = {}
output = p.run(params)[-1].tolist()
assert output == list(3 * np.ones((9), dtype='float32'))
if __name__ == "__main__":
test_add_op()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment