Unverified commit 233d4303, authored by Shucai Xiao, committed by GitHub
Browse files

Python_scalar_argument_support (#463)



* python api support scalar argument

* clang format

* add unit test for the scalar type

* clang format

* update to include numpy

* fix package issue

* fix package issue

* fixed review comments

* resolve review comments

* remove unnecessary changes

* refine a unit test for better coverage

* clang format

* refine unit tests to cover code change

* clang format

* change unit test

* refine tests

* refine unit test

* clang format

* refine unit test

* fix a dockerfile error

* fix bug

* scalar shape support in c++ api

* clang format

* fix cppcheck error

* fix possible errors

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9879574a
......@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options,
const char* name,
const std::vector<std::size_t>& dims)
const size_t* dims,
const size_t dim_num)
{
options.map_input_dims[std::string(name)] = dims;
options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num);
}
template <class Value>
......@@ -257,6 +258,15 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
});
}
// C-API entry point: create a scalar shape (a shape with only an element
// type and no dimensions) of the given datatype.
// On success *shape receives ownership of a newly allocated migraphx::shape;
// the returned status reports any error raised inside try_ (e.g. an
// unrecognized datatype in to_shape_type).
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
// A shape constructed from a type alone denotes a scalar (see the
// Python binding's to_shape, which uses migraphx::shape{t} for
// zero-dimensional buffers).
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
......@@ -601,16 +611,16 @@ extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t*
});
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
extern "C" migraphx_status
migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options,
const char* name,
const size_t* dims,
const size_t dim_num)
{
return migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
migraphx::set_input_parameter_shape((onnx_options->object), (name), (dims), (dim_num));
});
}
......
......@@ -77,6 +77,9 @@ migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths,
size_t lengths_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type);
migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
......@@ -173,8 +176,11 @@ migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_optio
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status
migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options,
const char* name,
const size_t* dims,
const size_t dim_num);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value);
......
......@@ -216,6 +216,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
}
// Construct a scalar shape (element type only, no dimensions) through the
// C API's migraphx_shape_create_scalar.
// NOTE(review): this single-argument constructor is not `explicit`, so it
// permits implicit conversion from a datatype enum to a shape — confirm
// with the API owners whether that is intended.
shape(migraphx_shape_datatype_t type)
{
this->make_handle(&migraphx_shape_create_scalar, type);
}
std::vector<size_t> lengths() const
{
const size_t* pout;
......
......@@ -54,6 +54,7 @@ def shape(h):
'create',
api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.method('lengths',
fname='lens',
returns='const std::vector<size_t>&',
......@@ -175,7 +176,9 @@ def onnx_options(h):
h.constructor('create')
h.method(
'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'),
api.params(name='const char*',
dims='const size_t *',
dim_num='const size_t'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
......
......@@ -2160,7 +2160,7 @@ struct onnx_parser
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
case onnx::TensorProto::UINT8:
case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
case onnx::TensorProto::STRING:
case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED:
......
......@@ -113,7 +113,16 @@ migraphx::shape to_shape(const py::buffer_info& info)
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
});
return migraphx::shape{t, info.shape, strides};
// scalar support
if(info.shape.empty())
{
return migraphx::shape{t};
}
else
{
return migraphx::shape{t, info.shape, strides};
}
}
PYBIND11_MODULE(migraphx, m)
......
......@@ -64,4 +64,21 @@ TEST_CASE(zero_parameter)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
// Verify that passing an empty dims list to set_input_parameter_shape turns
// a tensor input parameter into a scalar parameter when parsing ONNX.
TEST_CASE(set_scalar_parameter)
{
    // Without options, parameter "1" keeps its original {3, 4} float shape.
    auto default_prog = migraphx::parse_onnx("add_bcast_test.onnx");
    migraphx::shape expected_orig(migraphx_shape_float_type, {3, 4});
    auto default_shapes = default_prog.get_parameter_shapes();
    CHECK(bool{expected_orig == default_shapes["1"]});

    // Requesting an empty shape for "1" should produce a scalar parameter.
    migraphx::onnx_options opts;
    opts.set_input_parameter_shape("1", {});
    auto scalar_prog = migraphx::parse_onnx("add_bcast_test.onnx", opts);
    migraphx::shape expected_scalar(migraphx_shape_float_type);
    auto scalar_shapes = scalar_prog.get_parameter_shapes();
    CHECK(bool{expected_scalar == scalar_shapes["1"]});
}
// Test-harness entry point: runs all registered TEST_CASEs.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -92,14 +92,13 @@ def add_fp16_test():
@onnx_test
def add_scalar_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
x = helper.make_tensor_value_info('0', TensorProto.UINT8, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.UINT8, [])
z = helper.make_tensor_value_info('2', TensorProto.UINT8, [2, 3, 4, 5])
node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])
return ([node], [x, y], [z],
[helper.make_tensor('1', TensorProto.FLOAT, [], [1])])
return ([node], [x, y], [z])
@onnx_test
......@@ -2087,7 +2086,7 @@ def sub_scalar_test():
values_tensor = helper.make_tensor(name='const',
data_type=TensorProto.FLOAT,
dims=values.shape,
dims=values.reshape(()).shape,
vals=values.flatten().astype(float))
arg_const = onnx.helper.make_node(
......
......@@ -78,11 +78,12 @@ TEST_CASE(add_fp16_test)
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l0, m1);
auto prog = optimize_onnx("add_scalar_test.onnx");
auto r = p.add_instruction(migraphx::op::add{}, l0, m1);
p.add_return({r});
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
}
......@@ -1585,8 +1586,7 @@ TEST_CASE(sub_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx");
......
# CPU smoke test: parse an ONNX model, compile it for the CPU target, and run
# it with generated arguments, checking that the final output shape is
# unchanged by compilation.
import migraphx

prog = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(prog)
shape_before = prog.get_output_shapes()[-1]
print("Compiling ...")
prog.compile(migraphx.get_target("cpu"))
print(prog)
shape_after = prog.get_output_shapes()[-1]
assert shape_before == shape_after

# Build one generated argument per model parameter, then execute.
params = {}
for name, pshape in prog.get_parameter_shapes().items():
    print("Parameter {} -> {}".format(name, pshape))
    params[name] = migraphx.generate_argument(pshape)
result = prog.run(params)[-1]
print(result)
import migraphx, array, sys
def test_conv_relu():
    """Parse, compile (CPU target), and run the conv/relu/maxpool model.

    Asserts that the final output shape reported before compilation matches
    the one reported after, then executes with generated arguments.
    """
    prog = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
    print(prog)
    shape_before = prog.get_output_shapes()[-1]
    print("Compiling ...")
    prog.compile(migraphx.get_target("cpu"))
    print(prog)
    shape_after = prog.get_output_shapes()[-1]
    assert shape_before == shape_after
    # One generated argument per model parameter.
    args = {}
    for name, pshape in prog.get_parameter_shapes().items():
        print("Parameter {} -> {}".format(name, pshape))
        args[name] = migraphx.generate_argument(pshape)
    result = prog.run(args)[-1]
    print(result)
def create_buffer(t, data, shape):
    """Pack `data` into a read-only buffer with array typecode `t`.

    On Python 3 the bytes are reshaped to `shape` via memoryview.cast (an
    empty shape such as () yields a zero-dimensional, scalar view).  On
    Python 2, where memoryview.cast does not exist, a flat memoryview of the
    packed bytes is returned instead.
    """
    packed = array.array(t, data)
    if sys.version_info < (3, 0):
        return memoryview(packed.tostring())
    return memoryview(packed.tobytes()).cast(t, shape)
def test_add_scalar():
    """Run the add-with-scalar model on CPU with raw-buffer arguments.

    Feeds one 4-D uint8 tensor argument and one scalar (zero-dimensional)
    argument built via create_buffer, verifying first that compilation does
    not change the final output shape.
    """
    prog = migraphx.parse_onnx("add_scalar_test.onnx")
    print(prog)
    shape_before = prog.get_output_shapes()[-1]
    print("Compiling ...")
    prog.compile(migraphx.get_target("cpu"))
    print(prog)
    shape_after = prog.get_output_shapes()[-1]
    assert shape_before == shape_after
    tensor_buf = create_buffer("B", list(range(120)), [2, 3, 4, 5])
    # An empty shape () produces a scalar buffer for parameter "1".
    scalar_buf = create_buffer("B", [1], ())
    args = {}
    args["0"] = migraphx.argument(tensor_buf)
    args["1"] = migraphx.argument(scalar_buf)
    result = prog.run(args)[-1]
    print(result)
# Always run the conv/relu smoke test; the scalar test depends on
# memoryview.cast, which exists only on Python 3.
test_conv_relu()
if sys.version_info >= (3, 0):
    test_add_scalar()
......@@ -94,9 +94,10 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options,
const char* name,
const std::vector<std::size_t>& dims)
const size_t* dims,
const size_t dim_num)
{
options.map_input_dims[std::string(name)] = dims;
options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num);
}
template <class Value>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.