Unverified Commit 557618ba authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add more constructor to C api for shape (#520)



* Add more constructor to C api for shape

* Formatting

* Fix template file

* Add support for passing null pointers for empy vectors

* Move variable
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 6207ea00
...@@ -94,10 +94,9 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -94,10 +94,9 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
const size_t* dims, std::vector<std::size_t> dims)
const size_t dim_num)
{ {
options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num); options.map_input_dims[std::string(name)] = std::move(dims);
} }
template <class Value> template <class Value>
...@@ -250,7 +249,7 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -250,7 +249,7 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t lengths_size) size_t lengths_size)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
if(lengths == nullptr) if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>( *shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), allocate<migraphx::shape>((migraphx::to_shape_type(type)),
...@@ -258,6 +257,25 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -258,6 +257,25 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
}); });
} }
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size)
{
return migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter strides: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type) migraphx_shape_datatype_t type)
{ {
...@@ -611,16 +629,16 @@ extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* ...@@ -611,16 +629,16 @@ extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t*
}); });
} }
extern "C" migraphx_status extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
const char* name,
const size_t* dims,
const size_t dim_num)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_input_parameter_shape((onnx_options->object), (name), (dims), (dim_num)); if(dims == nullptr and dims_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
} }
......
...@@ -77,6 +77,13 @@ migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -77,6 +77,13 @@ migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths, size_t* lengths,
size_t lengths_size); size_t lengths_size);
migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type); migraphx_shape_datatype_t type);
...@@ -176,11 +183,8 @@ migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_optio ...@@ -176,11 +183,8 @@ migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_optio
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
const char* name,
const size_t* dims,
const size_t dim_num);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value); size_t value);
......
...@@ -211,14 +211,26 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -211,14 +211,26 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); } shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
shape(migraphx_shape_datatype_t type)
{
this->make_handle(&migraphx_shape_create_scalar, type);
}
shape(migraphx_shape_datatype_t type, std::vector<size_t> plengths) shape(migraphx_shape_datatype_t type, std::vector<size_t> plengths)
{ {
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
} }
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type,
std::vector<size_t> plengths,
std::vector<size_t> pstrides)
{ {
this->make_handle(&migraphx_shape_create_scalar, type); this->make_handle(&migraphx_shape_create_with_strides,
type,
plengths.data(),
plengths.size(),
pstrides.data(),
pstrides.size());
} }
std::vector<size_t> lengths() const std::vector<size_t> lengths() const
......
...@@ -54,6 +54,11 @@ def shape(h): ...@@ -54,6 +54,11 @@ def shape(h):
'create', 'create',
api.params(type='migraphx::shape::type_t', api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>')) lengths='std::vector<size_t>'))
h.constructor(
'create_with_strides',
api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>',
strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t')) h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.method('lengths', h.method('lengths',
fname='lens', fname='lens',
...@@ -176,9 +181,7 @@ def onnx_options(h): ...@@ -176,9 +181,7 @@ def onnx_options(h):
h.constructor('create') h.constructor('create')
h.method( h.method(
'set_input_parameter_shape', 'set_input_parameter_shape',
api.params(name='const char*', api.params(name='const char*', dims='std::vector<size_t>'),
dims='const size_t *',
dim_num='const size_t'),
invoke='migraphx::set_input_parameter_shape($@)', invoke='migraphx::set_input_parameter_shape($@)',
) )
h.method( h.method(
......
...@@ -81,4 +81,22 @@ TEST_CASE(set_scalar_parameter) ...@@ -81,4 +81,22 @@ TEST_CASE(set_scalar_parameter)
CHECK(bool{s_scalar == s_scalar_after}); CHECK(bool{s_scalar == s_scalar_after});
} }
TEST_CASE(scalar_shape)
{
auto s = migraphx::shape(migraphx_shape_float_type);
EXPECT(s.lengths().size() == 1);
EXPECT(s.strides().size() == 1);
EXPECT(s.lengths().front() == 1);
EXPECT(s.strides().front() == 0);
}
TEST_CASE(strided_shape)
{
std::vector<std::size_t> lens = {2, 2};
std::vector<std::size_t> strides = {1, 2};
auto s = migraphx::shape(migraphx_shape_float_type, lens, strides);
EXPECT(s.lengths() == lens);
EXPECT(s.strides() == strides);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -721,7 +721,7 @@ def vector_c_wrap(p): ...@@ -721,7 +721,7 @@ def vector_c_wrap(p):
else: else:
p.add_param(t) p.add_param(t)
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer')
p.read = '${type}(${name}, ${name}+${size})' p.read = '${type}(${name}, ${name}+${size})'
......
...@@ -94,10 +94,9 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -94,10 +94,9 @@ void set_default_dim_value(onnx_options& options, size_t value)
void set_input_parameter_shape(onnx_options& options, void set_input_parameter_shape(onnx_options& options,
const char* name, const char* name,
const size_t* dims, std::vector<std::size_t> dims)
const size_t dim_num)
{ {
options.map_input_dims[std::string(name)] = std::vector<std::size_t>(dims, dims + dim_num); options.map_input_dims[std::string(name)] = std::move(dims);
} }
template <class Value> template <class Value>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment