Commit b7a1823c authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rocblas_api_opt

parents 6af36ea4 40c087bd
...@@ -235,6 +235,8 @@ struct context ...@@ -235,6 +235,8 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
any_ptr get_queue() { return get_stream().get(); }
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
......
...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads) ...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_ref TARGETS migraphx_ref
INCLUDE INCLUDE
......
...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w) ...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB TF_SRCS *.cpp) file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_tf ${TF_SRCS}) add_library(migraphx_tf ${TF_SRCS})
target_include_directories(migraphx_tf PRIVATE include) target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf) set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
......
...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME) ...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
file(GLOB TESTS *.cpp) file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp)
foreach(TEST ${TESTS}) foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -100,7 +100,7 @@ endforeach() ...@@ -100,7 +100,7 @@ endforeach()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
# gpu tests # gpu tests
file(GLOB GPU_TESTS gpu/*.cpp) file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp)
foreach(TEST ${GPU_TESTS}) foreach(TEST ${GPU_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp) ...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
foreach(ONNX_TEST ${ONNX_TESTS}) foreach(ONNX_TEST ${ONNX_TESTS})
get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
set(TEST_NAME test_${BASE_NAME}) set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST}) add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
...@@ -160,7 +160,7 @@ function(test_header NAME HEADER) ...@@ -160,7 +160,7 @@ function(test_header NAME HEADER)
endfunction() endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS ${ARGN}) file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN})
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
......
#include <migraphx/any_ptr.hpp>
#include <test.hpp>
TEST_CASE(test_int_id)
{
int i = 1;
migraphx::any_ptr p = &i;
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
}
TEST_CASE(test_int_name)
{
int i = 1;
void* vp = &i;
migraphx::any_ptr p{vp, migraphx::get_type_name(i)};
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(float{})); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -10,7 +10,9 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -10,7 +10,9 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(shape_assign)
{
auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}};
std::vector<size_t> lens{2, 3};
// handle ptr is const, workaround to construct shape using C API
migraphx_shape_t s2;
migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size());
auto s2_cpp = migraphx::shape(s2, migraphx::own{});
CHECK(bool{s1_cpp != s2_cpp});
// use C++ API for assignment
s1_cpp.assign_to_handle(s2);
CHECK(bool{s1_cpp == s2_cpp});
auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens};
// use C API for assignment
migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr());
CHECK(bool{s2_cpp == s3_cpp});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp"
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather) ...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(standard_reshape)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(dead_instruction)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(gpu_context) TEST_CASE(gpu_context_serialize)
{ {
migraphx::context ctx = migraphx::gpu::context{0, 3}; migraphx::context ctx = migraphx::gpu::context{0, 3};
...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context) ...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context)
EXPECT(v == v1); EXPECT(v == v1);
} }
TEST_CASE(context_queue)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
EXPECT(ctx.get_queue().get<hipStream_t>() != nullptr);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1513,6 +1513,55 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
} }
TEST_CASE(test_unsqueeze_transpose)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_slice)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_axis_last)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, -1}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_2)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 3, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
...@@ -57,14 +57,15 @@ TEST_CASE(squeeze_transpose_test) ...@@ -57,14 +57,15 @@ TEST_CASE(squeeze_transpose_test)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_trans); mm->add_instruction(migraphx::make_op("squeeze"), l0_trans);
auto p_uncompiled = p; auto p_uncompiled = p;
// contiguous is required to read the values in standard shaped order
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
// contiguous is required to read the values in standard shaped order
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
} }
TEST_CASE(squeeze_multibroadcast_test) TEST_CASE(squeeze_multibroadcast_test)
...@@ -77,13 +78,14 @@ TEST_CASE(squeeze_multibroadcast_test) ...@@ -77,13 +78,14 @@ TEST_CASE(squeeze_multibroadcast_test)
migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0); migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst); mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst);
auto p_uncompiled = p; auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
} }
TEST_CASE(squeeze_slice_test) TEST_CASE(squeeze_slice_test)
...@@ -96,13 +98,74 @@ TEST_CASE(squeeze_slice_test) ...@@ -96,13 +98,74 @@ TEST_CASE(squeeze_slice_test)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("squeeze"), l0_slice); mm->add_instruction(migraphx::make_op("squeeze"), l0_slice);
auto p_uncompiled = p; auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
auto tr_op = migraphx::make_op("contiguous");
auto std_expected_result = tr_op.compute(result.get_shape(), {expected_result});
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
EXPECT(result == std_expected_result); EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_brcst =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
EXPECT(result == expected_result);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
file(GLOB VERIFY_TESTS *.cpp) file(GLOB VERIFY_TESTS ${CONFIGURE_DEPENDS} *.cpp)
add_executable(test_verify ${VERIFY_TESTS}) add_executable(test_verify ${VERIFY_TESTS})
add_dependencies(tests test_verify) add_dependencies(tests test_verify)
......
...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = [] ...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = [] cpp_header_preamble: List[str] = []
def bad_param_error(msg): def bad_param_error(msg: str):
return 'throw std::runtime_error("{}")'.format(msg) return 'throw std::runtime_error("{}")'.format(msg)
...@@ -89,7 +89,7 @@ class Type: ...@@ -89,7 +89,7 @@ class Type:
else: else:
return t.remove_const() return t.remove_const()
def const_compatible(self, t): def const_compatible(self, t: 'Type'):
if t.is_const(): if t.is_const():
return self.add_const() return self.add_const()
return self return self
...@@ -720,6 +720,7 @@ def add_handle(name: str, ...@@ -720,6 +720,7 @@ def add_handle(name: str,
destroy: Optional[str] = None, destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None: ref: Optional[bool] = None) -> None:
opaque_type = ctype + '_t' opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type
def handle_wrap(p): def handle_wrap(p):
t = Type(opaque_type) t = Type(opaque_type)
...@@ -747,6 +748,9 @@ def add_handle(name: str, ...@@ -747,6 +748,9 @@ def add_handle(name: str,
add_function(destroy or ctype + '_' + 'destroy', add_function(destroy or ctype + '_' + 'destroy',
params({name: opaque_type}), params({name: opaque_type}),
fname='destroy') fname='destroy')
add_function(ctype + '_' + 'assign_to',
params(output=opaque_type, input=const_opaque_type),
invoke='*output = *input')
add_handle_preamble() add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals())) c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals())) c_api_body_preamble.append(handle_definition.substitute(locals()))
......
...@@ -14,4 +14,6 @@ function api { ...@@ -14,4 +14,6 @@ function api {
} }
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
echo "Finished generating header migraphx.h"
api $DIR/api/api.cpp $SRC_DIR/api/api.cpp api $DIR/api/api.cpp $SRC_DIR/api/api.cpp
echo "Finished generating source api.cpp "
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -33,12 +34,21 @@ value to_value_context(const T&) ...@@ -33,12 +34,21 @@ value to_value_context(const T&)
} }
template <class T> template <class T>
void from_value_context(T&, const value&){} void from_value_context(T&, const value&)
{
}
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
<% <%
interface('context', interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'), virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('finish', returns = 'void', const = True)) %> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx) inline void migraphx_to_value(value& v, const context& ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment