Unverified Commit 77164f3c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Allow constructing an operation with a format string (#976)

Designed to allow a user to format the values needed for the json_string: migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", axes[0], axes[1], axes[2], axes[3]) instead of needing to use string concat or stringstream
parent a05113aa
...@@ -190,6 +190,7 @@ rocm_enable_cppcheck( ...@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable shadowVariable
unsafeClassDivZero unsafeClassDivZero
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp constParameter:*src/targets/gpu/*.cpp
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg>
namespace migraphx { namespace migraphx {
...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o ...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names); migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
} }
operation create_op(const char* name, const char* attributes) #ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{ {
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{}; value v = value::object{};
if(attributes != nullptr) if(attributes != nullptr)
{ {
v = from_json_string(convert_to_json(std::string(attributes))); v = from_json_string(convert_to_json(std::string(buffer.data())));
} }
auto op = make_op(name, v); auto op = make_op(name, v);
return op; return op;
} }
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T> template <class T>
bool equal(const T& x, const T& y) bool equal(const T& x, const T& y)
{ {
...@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options ...@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
return migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
...@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, ...@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths, size_t* lengths,
size_t lengths_size) size_t lengths_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0) if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>( *shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size)))); (std::vector<size_t>(lengths, lengths + lengths_size))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
...@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* ...@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides, size_t* strides,
size_t strides_size) size_t strides_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0) if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0) if(strides == nullptr and strides_size != 0)
...@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* ...@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)), (std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size)))); (std::vector<size_t>(strides, strides + strides_size))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type) migraphx_shape_datatype_t type)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>( *shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)))); allocate<migraphx::shape>((migraphx::to_shape_type(type))));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr) if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
...@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data(); *out = api_result.data();
*out_size = api_result.size(); *out_size = api_result.size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr) if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
...@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data(); *out = api_result.data();
*out_size = api_result.size(); *out_size = api_result.size();
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape) const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type()); *out = migraphx::to_shape_type((shape->object).type());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes(); *out = (shape->object).bytes();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object)); *out = migraphx::equal((shape->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
return migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shape == nullptr) if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>( *argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer))); allocate<migraphx::argument>((shape->object), (buffer)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape())); *out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data(); *out = (argument->object).data();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x) migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(argument == nullptr) if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object)); *out = migraphx::equal((argument->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed) migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(s == nullptr) if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed))); *out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{ {
return migraphx::try_([&] { destroy((target)); }); auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>( *target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name)))); allocate<migraphx::target>(migraphx::get_target((name))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { destroy((program_parameter_shapes)); }); auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out, migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes) migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer"); "Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size(); *out = (program_parameter_shapes->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
...@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, ...@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes, migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name) const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer"); "Bad parameter program_parameter_shapes: Null pointer");
*out = *out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name)))); object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_parameter_shapes_names( extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes) const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr) if(program_parameter_shapes == nullptr)
...@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names( ...@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object)); auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out); std::copy(api_result.begin(), api_result.end(), out);
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters) migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{ {
return migraphx::try_([&] { destroy((program_parameters)); }); auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>( *program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>()); allocate<std::unordered_map<std::string, migraphx::argument>>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
...@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters ...@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name, const char* name,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr) if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer"); "Bad parameter program_parameters: Null pointer");
...@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters ...@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object); (program_parameters->object)[(name)] = (argument->object);
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{ {
return migraphx::try_([&] { destroy((arguments)); }); auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr) if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size(); *out = (arguments->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx) migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr) if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx)))); *out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{ {
return migraphx::try_([&] { destroy((shapes)); }); auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr) if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size(); *out = (shapes->object).size();
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx) migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr) if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx)))); *out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module) extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(module == nullptr) if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object)); migraphx::print_module((module->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
return migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module()); *out = object_cast<migraphx_module_t>((program->object).get_main_module());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options) migraphx_compile_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr) if(target == nullptr)
...@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, ...@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object)); (program->object).compile((target->object), (options->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = *out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes()); allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object))); *out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program) extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object)); migraphx::print_program((program->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program) extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort(); (program->object).sort();
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params) migraphx_program_parameters_t params)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr) if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object))); *out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x) migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object)); *out = migraphx::equal((program->object), (x->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation) extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{ {
return migraphx::try_([&] { destroy((operation)); }); auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes) const char* name,
const char* attributes,
...)
{ {
return migraphx::try_([&] { va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>( *operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes)))); allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
}); });
va_end(vlist);
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation) migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(out == nullptr) if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr) if(operation == nullptr)
...@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati ...@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out); auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0'; *it = '\0';
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options) migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options) migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(p == nullptr) if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object)); migraphx::save((p->object), (name), (options->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options) extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{ {
return migraphx::try_([&] { destroy((onnx_options)); }); auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>()); *onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size) migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0) if(dims == nullptr and dims_size != 0)
...@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( ...@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape( migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size))); (onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value)); migraphx::set_default_dim_value((onnx_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value) int64_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr) if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value)); migraphx::set_default_loop_iterations((onnx_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options) extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{ {
return migraphx::try_([&] { destroy((file_options)); }); auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>()); *file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format) migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr) if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format)); migraphx::set_file_format((file_options->object), (format));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{ {
return migraphx::try_([&] { destroy((compile_options)); }); auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options) migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*compile_options = *compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>()); object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value) migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr) if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer"); "Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value)); migraphx::set_offload_copy((compile_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value) migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr) if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer"); "Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value)); migraphx::set_fast_math((compile_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options) migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
...@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, ...@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size, size_t size,
migraphx_onnx_options_t options) migraphx_onnx_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>( *out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object))); migraphx::parse_onnx_buffer((data), (size), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options) extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{ {
return migraphx::try_([&] { destroy((tf_options)); }); auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>()); *tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc) bool is_nhwc)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc)); migraphx::set_nhwc((tf_options->object), (is_nhwc));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape( extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size) migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0) if(dims == nullptr and dims_size != 0)
...@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape( ...@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape( migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size))); (tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value) migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value)); migraphx::set_default_dim_value((tf_options->object), (value));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names, const char** names,
size_t names_size) size_t names_size)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr) if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0) if(names == nullptr and names_size != 0)
...@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti ...@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object), migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size))); (std::vector<const char*>(names, names + names_size)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options) migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(options == nullptr) if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object))); *out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names) migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{ {
return migraphx::try_([&] { destroy((quantize_op_names)); }); auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*quantize_op_names = *quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>()); object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name) migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr) if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer"); "Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name)); (quantize_op_names->object).push_back((name));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name) migraphx_quantize_op_names_t name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr) if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object)); migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog) extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object)); migraphx::quantize_fp16((prog->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options) migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{ {
return migraphx::try_([&] { destroy((quantize_int8_options)); }); auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>( *quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>()); allocate<migraphx::quantize_int8_options>());
}); });
return api_error_result;
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name) const char* name)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr) if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer"); "Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name)); migraphx::add_op_name((quantize_int8_options->object), (name));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data( extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data) migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr) if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer"); "Bad parameter quantize_int8_options: Null pointer");
...@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data( ...@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object)); migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
}); });
return api_error_result;
} }
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options) migraphx_quantize_int8_options_t options)
{ {
return migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(prog == nullptr) if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr) if(target == nullptr)
...@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object)); migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
}); });
return api_error_result;
} }
...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); ...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes); const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
...@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); } operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr) template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes); this->make_handle(&migraphx_operation_create, name, attributes, xs...);
} }
std::string name() std::string name()
......
...@@ -212,7 +212,9 @@ def program(h): ...@@ -212,7 +212,9 @@ def program(h):
@auto_handle() @auto_handle()
def operation(h): def operation(h):
h.constructor('create', h.constructor('create',
api.params(name='const char*', attributes='const char*'), api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op') fname='migraphx::create_op')
h.method('name', returns='std::string') h.method('name', returns='std::string')
......
...@@ -8,16 +8,22 @@ TEST_CASE(add_op) ...@@ -8,16 +8,22 @@ TEST_CASE(add_op)
EXPECT(add_op.name() == "add"); EXPECT(add_op.name() == "add");
} }
TEST_CASE(reduce_mean) TEST_CASE(reduce_mean_without_quotes)
{ {
auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}"); auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean"); EXPECT(rm.name() == "reduce_mean");
} }
TEST_CASE(reduce_mean1) TEST_CASE(reduce_mean)
{ {
auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}"); auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}");
EXPECT(rm.name() == "reduce_mean"); EXPECT(rm.name() == "reduce_mean");
} }
TEST_CASE(reduce_mean_with_format)
{
auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4);
EXPECT(rm.name() == "reduce_mean");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -35,6 +35,9 @@ class Type: ...@@ -35,6 +35,9 @@ class Type:
def is_const(self): def is_const(self):
return self.name.startswith('const ') return self.name.startswith('const ')
def is_variadic(self):
return self.name.startswith('...')
def add_pointer(self): def add_pointer(self):
return Type(self.name + '*') return Type(self.name + '*')
...@@ -101,9 +104,10 @@ ${error_type} ${name}(${params}); ...@@ -101,9 +104,10 @@ ${error_type} ${name}(${params});
c_api_impl = Template(''' c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params}) extern "C" ${error_type} ${name}(${params})
{ {
return ${try_wrap}([&] { ${va_start}auto api_error_result = ${try_wrap}([&] {
${body}; ${body};
}); });
${va_end}return api_error_result;
} }
''') ''')
...@@ -113,6 +117,8 @@ class CFunction: ...@@ -113,6 +117,8 @@ class CFunction:
self.name = name self.name = name
self.params = [] self.params = []
self.body = [] self.body = []
self.va_start = []
self.va_end = []
def add_param(self, type, pname): def add_param(self, type, pname):
self.params.append('{} {}'.format(type, pname)) self.params.append('{} {}'.format(type, pname))
...@@ -120,12 +126,23 @@ class CFunction: ...@@ -120,12 +126,23 @@ class CFunction:
def add_statement(self, stmt): def add_statement(self, stmt):
self.body.append(stmt) self.body.append(stmt)
def add_vlist(self, name):
last_param = self.params[-1].split()[-1]
self.va_start = [
'va_list {};'.format(name),
'va_start({}, {});'.format(name, last_param)
]
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form): def substitute(self, form):
return form.substitute(error_type=error_type, return form.substitute(error_type=error_type,
try_wrap=try_wrap, try_wrap=try_wrap,
name=self.name, name=self.name,
params=', '.join(self.params), params=', '.join(self.params),
body=";\n ".join(self.body)) body=";\n ".join(self.body),
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end))
def generate_header(self): def generate_header(self):
return self.substitute(header_function) return self.substitute(header_function)
...@@ -256,7 +273,10 @@ class Parameter: ...@@ -256,7 +273,10 @@ class Parameter:
def add_to_cfunction(self, cfunction): def add_to_cfunction(self, cfunction):
for t, name in self.cparams: for t, name in self.cparams:
cfunction.add_param(self.substitute(t), self.substitute(name)) if t.startswith('...'):
cfunction.add_vlist(name)
else:
cfunction.add_param(self.substitute(t), self.substitute(name))
if self.bad_param_check: if self.bad_param_check:
msg = 'Bad parameter {name}: {msg}'.format( msg = 'Bad parameter {name}: {msg}'.format(
name=self.name, msg=self.bad_param_check.msg) name=self.name, msg=self.bad_param_check.msg)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg>
namespace migraphx { namespace migraphx {
...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o ...@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names); migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
} }
operation create_op(const char* name, const char* attributes) #ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{ {
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{}; value v = value::object{};
if(attributes != nullptr) if(attributes != nullptr)
{ {
v = from_json_string(convert_to_json(std::string(attributes))); v = from_json_string(convert_to_json(std::string(buffer.data())));
} }
auto op = make_op(name, v); auto op = make_op(name, v);
return op; return op;
} }
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T> template <class T>
bool equal(const T& x, const T& y) bool equal(const T& x, const T& y)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment