Unverified Commit 233d4303 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Python_scalar_argument_support (#463)



* python api support scalar argument

* clang format

* add unit test for the scalar type

* clang format

* update to include numpy

* fix package issue

* fix package issue

* fixed review comments

* resolve review comments

* remove unnecessary changes

* refine a unit test for beter coverage

* clang format

* refine unit tests to cover code change

* clang format

* change unit test

* refine tests

* refine unit test

* clang format

* refine unit test

* fix a dockerfile error

* fix bug

* scalar shape support in c++ api

* clang format

* fix cppcheck error

* fix possible errors

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9879574a
...@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
const std::vector<std::size_t>& dims) const size_t* dims,
const size_t dim_num)
{ {
options.map_input_dims[std::string(name)] = dims; options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num);
} }
template <class Value> template <class Value>
...@@ -257,6 +258,15 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -257,6 +258,15 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
}); });
} }
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
...@@ -601,16 +611,16 @@ extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* ...@@ -601,16 +611,16 @@ extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t*
}); });
} }
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( extern "C" migraphx_status
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size) migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options,
const char* name,
const size_t* dims,
const size_t dim_num)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr) migraphx::set_input_parameter_shape((onnx_options->object), (name), (dims), (dim_num));
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
} }
......
...@@ -77,6 +77,9 @@ migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -77,6 +77,9 @@ migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths, size_t* lengths,
size_t lengths_size); size_t lengths_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type);
migraphx_status migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
...@@ -173,8 +176,11 @@ migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_optio ...@@ -173,8 +176,11 @@ migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_optio
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( migraphx_status
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options,
const char* name,
const size_t* dims,
const size_t dim_num);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value); size_t value);
......
...@@ -216,6 +216,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -216,6 +216,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
} }
shape(migraphx_shape_datatype_t type)
{
this->make_handle(&migraphx_shape_create_scalar, type);
}
std::vector<size_t> lengths() const std::vector<size_t> lengths() const
{ {
const size_t* pout; const size_t* pout;
......
...@@ -54,6 +54,7 @@ def shape(h): ...@@ -54,6 +54,7 @@ def shape(h):
'create', 'create',
api.params(type='migraphx::shape::type_t', api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>')) lengths='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.method('lengths', h.method('lengths',
fname='lens', fname='lens',
returns='const std::vector<size_t>&', returns='const std::vector<size_t>&',
...@@ -175,7 +176,9 @@ def onnx_options(h): ...@@ -175,7 +176,9 @@ def onnx_options(h):
h.constructor('create') h.constructor('create')
h.method( h.method(
'set_input_parameter_shape', 'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'), api.params(name='const char*',
dims='const size_t *',
dim_num='const size_t'),
invoke='migraphx::set_input_parameter_shape($@)', invoke='migraphx::set_input_parameter_shape($@)',
) )
h.method( h.method(
......
...@@ -2160,7 +2160,7 @@ struct onnx_parser ...@@ -2160,7 +2160,7 @@ struct onnx_parser
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break; case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break; case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break; case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
case onnx::TensorProto::UINT8: case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
......
...@@ -113,7 +113,16 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -113,7 +113,16 @@ migraphx::shape to_shape(const py::buffer_info& info)
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0; return n > 0 ? i / n : 0;
}); });
// scalar support
if(info.shape.empty())
{
return migraphx::shape{t};
}
else
{
return migraphx::shape{t, info.shape, strides}; return migraphx::shape{t, info.shape, strides};
}
} }
PYBIND11_MODULE(migraphx, m) PYBIND11_MODULE(migraphx, m)
......
...@@ -64,4 +64,21 @@ TEST_CASE(zero_parameter) ...@@ -64,4 +64,21 @@ TEST_CASE(zero_parameter)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
} }
TEST_CASE(set_scalar_parameter)
{
auto p1 = migraphx::parse_onnx("add_bcast_test.onnx");
migraphx::shape s1(migraphx_shape_float_type, {3, 4});
auto param_shapes = p1.get_parameter_shapes();
auto s1_orig = param_shapes["1"];
CHECK(bool{s1 == s1_orig});
migraphx::onnx_options option;
option.set_input_parameter_shape("1", {});
auto p2 = migraphx::parse_onnx("add_bcast_test.onnx", option);
migraphx::shape s_scalar(migraphx_shape_float_type);
auto param_shapes_1 = p2.get_parameter_shapes();
auto s_scalar_after = param_shapes_1["1"];
CHECK(bool{s_scalar == s_scalar_after});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -92,14 +92,13 @@ def add_fp16_test(): ...@@ -92,14 +92,13 @@ def add_fp16_test():
@onnx_test @onnx_test
def add_scalar_test(): def add_scalar_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.UINT8, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, []) y = helper.make_tensor_value_info('1', TensorProto.UINT8, [])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) z = helper.make_tensor_value_info('2', TensorProto.UINT8, [2, 3, 4, 5])
node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2']) node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])
return ([node], [x, y], [z], return ([node], [x, y], [z])
[helper.make_tensor('1', TensorProto.FLOAT, [], [1])])
@onnx_test @onnx_test
...@@ -2087,7 +2086,7 @@ def sub_scalar_test(): ...@@ -2087,7 +2086,7 @@ def sub_scalar_test():
values_tensor = helper.make_tensor(name='const', values_tensor = helper.make_tensor(name='const',
data_type=TensorProto.FLOAT, data_type=TensorProto.FLOAT,
dims=values.shape, dims=values.reshape(()).shape,
vals=values.flatten().astype(float)) vals=values.flatten().astype(float))
arg_const = onnx.helper.make_node( arg_const = onnx.helper.make_node(
......
...@@ -78,11 +78,12 @@ TEST_CASE(add_fp16_test) ...@@ -78,11 +78,12 @@ TEST_CASE(add_fp16_test)
TEST_CASE(add_scalar_test) TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l0, m1); auto r = p.add_instruction(migraphx::op::add{}, l0, m1);
auto prog = optimize_onnx("add_scalar_test.onnx"); p.add_return({r});
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1585,8 +1586,7 @@ TEST_CASE(sub_scalar_test) ...@@ -1585,8 +1586,7 @@ TEST_CASE(sub_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, m1); p.add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx"); auto prog = optimize_onnx("sub_scalar_test.onnx");
......
import migraphx import migraphx, array, sys
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p) def test_conv_relu():
s1 = p.get_output_shapes()[-1] p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print("Compiling ...") print(p)
p.compile(migraphx.get_target("cpu")) s1 = p.get_output_shapes()[-1]
print(p) print("Compiling ...")
s2 = p.get_output_shapes()[-1] p.compile(migraphx.get_target("cpu"))
assert s1 == s2 print(p)
params = {} s2 = p.get_output_shapes()[-1]
assert s1 == s2
for key, value in p.get_parameter_shapes().items(): params = {}
for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value)) print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.generate_argument(value) params[key] = migraphx.generate_argument(value)
r = p.run(params)[-1] r = p.run(params)[-1]
print(r) print(r)
def create_buffer(t, data, shape):
a = array.array(t, data)
if sys.version_info >= (3, 0):
m = memoryview(a.tobytes())
return m.cast(t, shape)
else:
m = memoryview(a.tostring())
return m
def test_add_scalar():
p = migraphx.parse_onnx("add_scalar_test.onnx")
print(p)
s1 = p.get_output_shapes()[-1]
print("Compiling ...")
p.compile(migraphx.get_target("cpu"))
print(p)
s2 = p.get_output_shapes()[-1]
assert s1 == s2
d0 = list(range(120))
arg0 = create_buffer("B", d0, [2, 3, 4, 5])
d1 = [1]
arg1 = create_buffer("B", d1, ())
params = {}
params["0"] = migraphx.argument(arg0)
params["1"] = migraphx.argument(arg1)
r = p.run(params)[-1]
print(r)
test_conv_relu()
if sys.version_info >= (3, 0):
test_add_scalar()
...@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
const std::vector<std::size_t>& dims) const size_t* dims,
const size_t dim_num)
{ {
options.map_input_dims[std::string(name)] = dims; options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num);
} }
template <class Value> template <class Value>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment