Unverified Commit c722117d authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Improve error reporting in the API (#1274)

C++ API is not printing thrown exception string. this improves on it.
parent 6e6cb994
...@@ -39,34 +39,47 @@ ...@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
try if(disable_exception_catch)
{ {
f(); f();
} }
catch(const migraphx::exception& ex) else
{ {
if(output) try
std::cerr << "MIGraphX Error: " << ex.what() << std::endl; {
if(ex.error > 0) f();
return migraphx_status(ex.error); }
else catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
catch(const std::exception& ex) catch(...)
{ {
if(output) return migraphx_status_unknown_error;
std::cerr << "MIGraphX Error: " << ex.what() << std::endl; }
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
} }
return migraphx_status_success; return migraphx_status_success;
} }
...@@ -305,6 +318,7 @@ void destroy(T* x) ...@@ -305,6 +318,7 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble // TODO: Move to interface preamble
template <class C, class D> template <class C, class D>
struct manage_generic_ptr struct manage_generic_ptr
...@@ -313,30 +327,35 @@ struct manage_generic_ptr ...@@ -313,30 +327,35 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {} manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter) manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter) : data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{ {
copier(&data, pdata); copier(&data, pdata);
} }
manage_generic_ptr(const manage_generic_ptr& rhs) manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter) : data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{ {
if(copier) if(copier)
copier(&data, rhs.data); copier(&data, rhs.data);
} }
manage_generic_ptr(manage_generic_ptr&& other) noexcept manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter) : data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{ {
other.data = nullptr; other.data = nullptr;
other.copier = nullptr; other.obj_typename = "";
other.deleter = nullptr; other.copier = nullptr;
other.deleter = nullptr;
} }
manage_generic_ptr& operator=(manage_generic_ptr rhs) manage_generic_ptr& operator=(manage_generic_ptr rhs)
{ {
std::swap(data, rhs.data); std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier); std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter); std::swap(deleter, rhs.deleter);
return *this; return *this;
...@@ -348,9 +367,10 @@ struct manage_generic_ptr ...@@ -348,9 +367,10 @@ struct manage_generic_ptr
deleter(data); deleter(data);
} }
void* data = nullptr; void* data = nullptr;
C copier = nullptr; const char* obj_typename = "";
D deleter = nullptr; C copier = nullptr;
D deleter = nullptr;
}; };
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
...@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op ...@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p, migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs) Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...) : object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{ {
} }
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete> manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
...@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op ...@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_argument_t> out; std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr) if(compute_f == nullptr)
throw std::runtime_error("compute function is missing."); throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out, auto api_error_result = compute_f(&out,
object_ptr.data, object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)), object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)), object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs))); object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success) if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute."); {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object; return (&out)->object;
} }
...@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op ...@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_shape_t> out; std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr) if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing."); throw std::runtime_error("compute_shape function is missing.");
auto api_error_result = std::array<char, 256> exception_msg;
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs))); exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success) if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape."); {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object; return (&out)->object;
} }
}; };
...@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
auto api_error_result = migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
...@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name) const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*experimental_custom_op = *experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name)); allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
}); });
return api_error_result; return api_error_result;
} }
......
...@@ -132,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta ...@@ -132,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj, void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx, migraphx_context_t ctx,
migraphx_shape_t output, migraphx_shape_t output,
migraphx_arguments_t inputs); migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj, void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
...@@ -176,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); ...@@ -176,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
...@@ -486,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -486,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name); const char* name);
migraphx_status migraphx_status
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h" #include "migraphx.h"
#include <cstring>
#include <initializer_list> #include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
...@@ -58,6 +59,42 @@ struct rank<0> ...@@ -58,6 +59,42 @@ struct rank<0>
{ {
}; };
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts> template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
...@@ -310,13 +347,22 @@ struct interface_base : Base ...@@ -310,13 +347,22 @@ struct interface_base : Base
protected: protected:
template <class F> template <class F>
static migraphx_status try_(F f) // NOLINT static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{ {
try try
{ {
f(); f();
return migraphx_status_success; return migraphx_status_success;
} }
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...) catch(...)
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
...@@ -349,9 +395,13 @@ struct interface_base : Base ...@@ -349,9 +395,13 @@ struct interface_base : Base
{ {
static F f = pf; static F f = pf;
(void)f; // avoid warning on gcc (void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status { call(setter,
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); }); this->get_handle_ptr(),
}); [](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
});
} }
template <class T, class Setter, class F> template <class T, class Setter, class F>
...@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py) friend bool operator==(const shape& px, const shape& py)
{ {
bool pout; bool pout;
...@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T> template <class T>
experimental_custom_op(T& obj) experimental_custom_op(T& obj)
{ {
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str()); this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
} }
......
...@@ -121,6 +121,7 @@ def shape(h): ...@@ -121,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True)
@auto_handle() @auto_handle()
...@@ -439,7 +440,8 @@ def context(h): ...@@ -439,7 +440,8 @@ def context(h):
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op') 'migraphx::experimental_custom_op')
def experimental_custom_op(h): def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*')) h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute', h.virtual('compute',
api.params(ctx='migraphx::context', api.params(ctx='migraphx::context',
output='migraphx::shape', output='migraphx::shape',
......
...@@ -23,8 +23,10 @@ ...@@ -23,8 +23,10 @@
*/ */
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <exception>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp" #include "test.hpp"
struct sigmoid_custom_op final : migraphx::experimental_custom_op_base struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
...@@ -43,10 +45,22 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base ...@@ -43,10 +45,22 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{ {
CHECK(inputs.size() == 2); if(inputs.size() != 2)
CHECK(inputs[0].lengths().size() == 1); {
CHECK(inputs[0].type() == migraphx_shape_float_type); throw std::runtime_error("op must have two inputs");
CHECK(bool{inputs[0] == inputs[1]}); }
if(inputs[0].lengths().size() != 1)
{
throw std::runtime_error("input arg must be a vector or scalar");
}
if(inputs[0].type() != migraphx_shape_float_type)
{
throw std::runtime_error("input arg must be of type float");
}
if(inputs[0] != inputs[1])
{
throw std::runtime_error("input arg and buffer allocation must be of same shape");
}
return inputs.back(); return inputs.back();
} }
}; };
...@@ -83,4 +97,18 @@ TEST_CASE(run_sigmoid_custom_op) ...@@ -83,4 +97,18 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT(bool{result == migraphx::argument(s, expected_result.data())}); EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
} }
extern "C" void migraphx_test_private_disable_exception_catch(bool b);
TEST_CASE(run_sigmoid_with_incorrect_shape)
{
migraphx::program p;
migraphx::shape s{migraphx_shape_float_type, {12}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
migraphx_test_private_disable_exception_catch(true);
EXPECT(test::throws<std::exception>(
[&] { m.add_instruction(migraphx::operation("sigmoid_custom_op"), {x}); },
"Error in compute_shape of: sigmoid_custom_op: op must have two inputs"));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp" #include "test.hpp"
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess)) #define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
...@@ -54,6 +55,14 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base ...@@ -54,6 +55,14 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{ {
if(!inputs[0].standard())
{
throw std::runtime_error("first arg must be standard shaped");
}
if(inputs.size() != 2)
{
throw std::runtime_error("number of inputs must be 2");
}
return inputs.back(); return inputs.back();
} }
}; };
...@@ -64,12 +73,17 @@ TEST_CASE(run_simple_custom_op) ...@@ -64,12 +73,17 @@ TEST_CASE(run_simple_custom_op)
migraphx::register_experimental_custom_op(simple_op); migraphx::register_experimental_custom_op(simple_op);
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx_shape_int32_type, {4, 3}}; migraphx::shape s{migraphx_shape_int32_type, {4, 3}};
migraphx::shape trans_shape{migraphx_shape_int32_type, {3, 4}};
migraphx::module m = p.get_main_module(); migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto neg = m.add_instruction(migraphx::operation("neg"), x); auto neg = m.add_instruction(migraphx::operation("neg"), x);
auto alloc = m.add_allocation(s); auto alloc = m.add_allocation(trans_shape);
auto custom_kernel = m.add_instruction(migraphx::operation("simple_custom_op"), {neg, alloc}); auto neg_trans =
auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel); m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg});
auto neg_cont = m.add_instruction(migraphx::operation("contiguous"), {neg_trans});
auto custom_kernel =
m.add_instruction(migraphx::operation("simple_custom_op"), {neg_cont, alloc});
auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
m.add_return({relu}); m.add_return({relu});
migraphx::compile_options options; migraphx::compile_options options;
options.set_offload_copy(); options.set_offload_copy();
...@@ -82,7 +96,7 @@ TEST_CASE(run_simple_custom_op) ...@@ -82,7 +96,7 @@ TEST_CASE(run_simple_custom_op)
auto result_vec = result.as_vector<int>(); auto result_vec = result.as_vector<int>();
std::vector<int> expected_result(12, 0); std::vector<int> expected_result(12, 0);
std::fill(expected_result.begin() + 6, expected_result.end(), 3); std::fill(expected_result.begin() + 6, expected_result.end(), 3);
EXPECT(bool{result == migraphx::argument(s, expected_result.data())}); EXPECT(bool{result == migraphx::argument(trans_shape, expected_result.data())});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -197,7 +197,8 @@ class Parameter: ...@@ -197,7 +197,8 @@ class Parameter:
optional: bool = False, optional: bool = False,
returns: bool = False, returns: bool = False,
virtual: bool = False, virtual: bool = False,
this: bool = False) -> None: this: bool = False,
hidden: bool = False) -> None:
self.name = name self.name = name
self.type = Type(type) self.type = Type(type)
self.optional = optional self.optional = optional
...@@ -211,6 +212,7 @@ class Parameter: ...@@ -211,6 +212,7 @@ class Parameter:
self.returns = returns self.returns = returns
self.virtual = virtual self.virtual = virtual
self.this = this self.this = this
self.hidden = hidden
self.bad_param_check: Optional[BadParam] = None self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None self.virtual_write: Optional[str] = None
...@@ -744,6 +746,8 @@ void destroy(T* x) ...@@ -744,6 +746,8 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble // TODO: Move to interface preamble
template <class C, class D> template <class C, class D>
struct manage_generic_ptr struct manage_generic_ptr
...@@ -754,23 +758,24 @@ struct manage_generic_ptr ...@@ -754,23 +758,24 @@ struct manage_generic_ptr
{ {
} }
manage_generic_ptr(void* pdata, C pcopier, D pdeleter) manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter) : data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{ {
copier(&data, pdata); copier(&data, pdata);
} }
manage_generic_ptr(const manage_generic_ptr& rhs) manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter) : data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{ {
if(copier) if(copier)
copier(&data, rhs.data); copier(&data, rhs.data);
} }
manage_generic_ptr(manage_generic_ptr&& other) noexcept manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter) : data(other.data), obj_typename(other.obj_typename), copier(other.copier), deleter(other.deleter)
{ {
other.data = nullptr; other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr; other.copier = nullptr;
other.deleter = nullptr; other.deleter = nullptr;
} }
...@@ -778,6 +783,7 @@ struct manage_generic_ptr ...@@ -778,6 +783,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs) manage_generic_ptr& operator=(manage_generic_ptr rhs)
{ {
std::swap(data, rhs.data); std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier); std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter); std::swap(deleter, rhs.deleter);
return *this; return *this;
...@@ -790,6 +796,7 @@ struct manage_generic_ptr ...@@ -790,6 +796,7 @@ struct manage_generic_ptr
} }
void* data = nullptr; void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr; C copier = nullptr;
D deleter = nullptr; D deleter = nullptr;
}; };
...@@ -1042,8 +1049,8 @@ interface_handle_definition = Template(''' ...@@ -1042,8 +1049,8 @@ interface_handle_definition = Template('''
extern "C" struct ${ctype}; extern "C" struct ${ctype};
struct ${ctype} { struct ${ctype} {
template<class... Ts> template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs) ${ctype}(void* p, ${copier} c, ${deleter} d, const char* obj_typename, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...) : object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{} {}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr; manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject; ${cpptype} xobject;
...@@ -1057,9 +1064,13 @@ ${return_type} ${name}(${params}) const ...@@ -1057,9 +1064,13 @@ ${return_type} ${name}(${params}) const
${output_decls} ${output_decls}
if (${fname} == nullptr) if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing."); throw std::runtime_error("${name} function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args}); auto api_error_result = ${fname}(${args});
if (api_error_result != ${success}) if (api_error_result != ${success}) {
throw std::runtime_error("Error in ${name}."); const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in ${name} of: " + std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return ${output}; return ${output};
} }
''') ''')
...@@ -1079,7 +1090,9 @@ def generate_virtual_impl(f: Function, fname: str) -> str: ...@@ -1079,7 +1090,9 @@ def generate_virtual_impl(f: Function, fname: str) -> str:
largs += f.returns.virtual_output_args() largs += f.returns.virtual_output_args()
output = f.returns.virtual_output() output = f.returns.virtual_output()
largs += [arg for p in f.params for arg in p.virtual_arg()] largs += [arg for p in f.params for arg in p.virtual_arg()]
lparams += [p.virtual_param() for p in f.params if not p.this] lparams += [
p.virtual_param() for p in f.params if not (p.this or p.hidden)
]
args = ', '.join(largs) args = ', '.join(largs)
params = ', '.join(lparams) params = ', '.join(lparams)
return c_api_virtual_impl.substitute(locals()) return c_api_virtual_impl.substitute(locals())
...@@ -1126,8 +1139,15 @@ class Interface(Handle): ...@@ -1126,8 +1139,15 @@ class Interface(Handle):
# Add this parameter to the function # Add this parameter to the function
this = Parameter('obj', 'void*', this=True) this = Parameter('obj', 'void*', this=True)
this.virtual_read = ['object_ptr.data'] this.virtual_read = ['object_ptr.data']
exception_msg = Parameter('exception_msg', 'char*', hidden=True)
exception_msg.virtual_read = ['${name}.data()']
exception_msg_size = Parameter('exception_msg_size',
'size_t',
hidden=True)
exception_msg_size.virtual_read = ['exception_msg.size()']
f = Function(name, f = Function(name,
params=[this] + (params or []), params=[this, exception_msg, exception_msg_size] +
(params or []),
virtual=True, virtual=True,
**kwargs) **kwargs)
self.ifunctions.append(f) self.ifunctions.append(f)
......
...@@ -39,34 +39,47 @@ ...@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
try if(disable_exception_catch)
{ {
f(); f();
} }
catch(const migraphx::exception& ex) else
{ {
if(output) try
std::cerr << "MIGraphX Error: " << ex.what() << std::endl; {
if(ex.error > 0) f();
return migraphx_status(ex.error); }
else catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
catch(const std::exception& ex) catch(...)
{ {
if(output) return migraphx_status_unknown_error;
std::cerr << "MIGraphX Error: " << ex.what() << std::endl; }
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
} }
return migraphx_status_success; return migraphx_status_success;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment