Unverified Commit c722117d authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Improve error reporting in the API (#1274)

C++ API is not printing thrown exception string. this improves on it.
parent 6e6cb994
......@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
try
if(disable_exception_catch)
{
f();
}
catch(const migraphx::exception& ex)
else
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
try
{
f();
}
catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
return migraphx_status_success;
}
......@@ -305,6 +318,7 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
......@@ -313,30 +327,35 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
: data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
: data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
......@@ -348,9 +367,10 @@ struct manage_generic_ptr
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr;
D deleter = nullptr;
};
extern "C" struct migraphx_shape;
......@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
: object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
......@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr)
throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute.");
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
......@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
};
......@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
......@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
});
return api_error_result;
}
......
......@@ -132,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx,
migraphx_shape_t output,
migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
......@@ -176,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
......@@ -486,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name);
migraphx_status
......
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <cstring>
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
......@@ -58,6 +59,42 @@ struct rank<0>
{
};
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
......@@ -310,13 +347,22 @@ struct interface_base : Base
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
......@@ -349,9 +395,13 @@ struct interface_base : Base
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
call(setter,
this->get_handle_ptr(),
[](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
});
}
template <class T, class Setter, class F>
......@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py)
{
bool pout;
......@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
}
......
......@@ -121,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('standard', returns='bool', const=True)
@auto_handle()
......@@ -439,7 +440,8 @@ def context(h):
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute',
api.params(ctx='migraphx::context',
output='migraphx::shape',
......
......@@ -23,8 +23,10 @@
*/
#include <algorithm>
#include <cmath>
#include <exception>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp"
struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
......@@ -43,10 +45,22 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
CHECK(inputs.size() == 2);
CHECK(inputs[0].lengths().size() == 1);
CHECK(inputs[0].type() == migraphx_shape_float_type);
CHECK(bool{inputs[0] == inputs[1]});
if(inputs.size() != 2)
{
throw std::runtime_error("op must have two inputs");
}
if(inputs[0].lengths().size() != 1)
{
throw std::runtime_error("input arg must be a vector or scalar");
}
if(inputs[0].type() != migraphx_shape_float_type)
{
throw std::runtime_error("input arg must be of type float");
}
if(inputs[0] != inputs[1])
{
throw std::runtime_error("input arg and buffer allocation must be of same shape");
}
return inputs.back();
}
};
......@@ -83,4 +97,18 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
}
extern "C" void migraphx_test_private_disable_exception_catch(bool b);
TEST_CASE(run_sigmoid_with_incorrect_shape)
{
migraphx::program p;
migraphx::shape s{migraphx_shape_float_type, {12}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
migraphx_test_private_disable_exception_catch(true);
EXPECT(test::throws<std::exception>(
[&] { m.add_instruction(migraphx::operation("sigmoid_custom_op"), {x}); },
"Error in compute_shape of: sigmoid_custom_op: op must have two inputs"));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -24,6 +24,7 @@
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <stdexcept>
#include "test.hpp"
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
......@@ -54,6 +55,14 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
if(!inputs[0].standard())
{
throw std::runtime_error("first arg must be standard shaped");
}
if(inputs.size() != 2)
{
throw std::runtime_error("number of inputs must be 2");
}
return inputs.back();
}
};
......@@ -64,12 +73,17 @@ TEST_CASE(run_simple_custom_op)
migraphx::register_experimental_custom_op(simple_op);
migraphx::program p;
migraphx::shape s{migraphx_shape_int32_type, {4, 3}};
migraphx::shape trans_shape{migraphx_shape_int32_type, {3, 4}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
auto neg = m.add_instruction(migraphx::operation("neg"), x);
auto alloc = m.add_allocation(s);
auto custom_kernel = m.add_instruction(migraphx::operation("simple_custom_op"), {neg, alloc});
auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
auto alloc = m.add_allocation(trans_shape);
auto neg_trans =
m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg});
auto neg_cont = m.add_instruction(migraphx::operation("contiguous"), {neg_trans});
auto custom_kernel =
m.add_instruction(migraphx::operation("simple_custom_op"), {neg_cont, alloc});
auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
m.add_return({relu});
migraphx::compile_options options;
options.set_offload_copy();
......@@ -82,7 +96,7 @@ TEST_CASE(run_simple_custom_op)
auto result_vec = result.as_vector<int>();
std::vector<int> expected_result(12, 0);
std::fill(expected_result.begin() + 6, expected_result.end(), 3);
EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
EXPECT(bool{result == migraphx::argument(trans_shape, expected_result.data())});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -197,7 +197,8 @@ class Parameter:
optional: bool = False,
returns: bool = False,
virtual: bool = False,
this: bool = False) -> None:
this: bool = False,
hidden: bool = False) -> None:
self.name = name
self.type = Type(type)
self.optional = optional
......@@ -211,6 +212,7 @@ class Parameter:
self.returns = returns
self.virtual = virtual
self.this = this
self.hidden = hidden
self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None
......@@ -744,6 +746,8 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
......@@ -754,23 +758,24 @@ struct manage_generic_ptr
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
: data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
: data(other.data), obj_typename(other.obj_typename), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr;
other.deleter = nullptr;
}
......@@ -778,6 +783,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
......@@ -790,6 +796,7 @@ struct manage_generic_ptr
}
void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr;
D deleter = nullptr;
};
......@@ -1042,8 +1049,8 @@ interface_handle_definition = Template('''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
${ctype}(void* p, ${copier} c, ${deleter} d, const char* obj_typename, Ts&&... xs)
: object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
......@@ -1057,9 +1064,13 @@ ${return_type} ${name}(${params}) const
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
if (api_error_result != ${success}) {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in ${name} of: " + std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return ${output};
}
''')
......@@ -1079,7 +1090,9 @@ def generate_virtual_impl(f: Function, fname: str) -> str:
largs += f.returns.virtual_output_args()
output = f.returns.virtual_output()
largs += [arg for p in f.params for arg in p.virtual_arg()]
lparams += [p.virtual_param() for p in f.params if not p.this]
lparams += [
p.virtual_param() for p in f.params if not (p.this or p.hidden)
]
args = ', '.join(largs)
params = ', '.join(lparams)
return c_api_virtual_impl.substitute(locals())
......@@ -1126,8 +1139,15 @@ class Interface(Handle):
# Add this parameter to the function
this = Parameter('obj', 'void*', this=True)
this.virtual_read = ['object_ptr.data']
exception_msg = Parameter('exception_msg', 'char*', hidden=True)
exception_msg.virtual_read = ['${name}.data()']
exception_msg_size = Parameter('exception_msg_size',
'size_t',
hidden=True)
exception_msg_size.virtual_read = ['exception_msg.size()']
f = Function(name,
params=[this] + (params or []),
params=[this, exception_msg, exception_msg_size] +
(params or []),
virtual=True,
**kwargs)
self.ifunctions.append(f)
......
......@@ -39,34 +39,47 @@
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
try
if(disable_exception_catch)
{
f();
}
catch(const migraphx::exception& ex)
else
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
try
{
f();
}
catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
return migraphx_status_success;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment