Unverified Commit 40c087bd authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add reverse lookup of c++ class to c class (#1099)

Needed for custom_op so we can generically convert the C type back to the C++ type in the function pointer.
parent e5242676
...@@ -152,6 +152,35 @@ struct array_base ...@@ -152,6 +152,35 @@ struct array_base
} }
}; };
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template <class T>
struct holder
{
// Friend injection
friend auto migraphx_adl_handle_lookup(holder<T>);
// Function left unimplemented since its only used in non-evaluated
// context
T get() const;
};
template <class C, class T>
struct handle_lookup
{
friend auto migraphx_adl_handle_lookup(holder<T>) { return holder<C>{}; }
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template <class T>
using as_handle = decltype(
migraphx_adl_handle_lookup(holder<std::remove_cv_t<std::remove_pointer_t<T>>>{}).get());
struct own struct own
{ {
}; };
...@@ -159,8 +188,8 @@ struct borrow ...@@ -159,8 +188,8 @@ struct borrow
{ {
}; };
template <class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
...@@ -204,7 +233,8 @@ struct handle_base ...@@ -204,7 +233,8 @@ struct handle_base
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \ handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \ decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \ migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \ decltype(&migraphx_##name##_assign_to), \
......
...@@ -12,6 +12,7 @@ endfunction() ...@@ -12,6 +12,7 @@ endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp"
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment