"src/targets/vscode:/vscode.git/clone" did not exist on "6f99b12f8a293901c8bf705544e1813a537e9dbc"
Unverified Commit 36656030 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Expose get_queue method for context in API (#1161)

* expose get_queue method
parent 31906785
......@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct migraphx_instruction
{
template <class... Ts>
migraphx_instruction(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_instruction(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::instruction_ref object;
......@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct migraphx_instructions
{
template <class... Ts>
migraphx_instructions(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_instructions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::instruction_ref> object;
......@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct migraphx_modules
{
template <class... Ts>
migraphx_modules(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_modules(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::module*> object;
......@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return api_error_result;
}
extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
*out = (context->object).get_queue().unsafe_get();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{
......
......@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context);
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
......
......@@ -777,6 +777,15 @@ struct context
migraphx_context_t ctx;
void finish() const { call(&migraphx_context_finish, ctx); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx);
// TODO: check type here
return reinterpret_cast<T>(out);
}
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
......
......@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op',
......
function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
......@@ -19,7 +18,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests
if(MIGRAPHX_ENABLE_GPU)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
# GPU-based tests
target_link_libraries(test_api_gpu migraphx_gpu)
endif()
#include <numeric>
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
......@@ -38,6 +39,7 @@ TEST_CASE(load_and_run_ctx)
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto ctx = p.experimental_get_context();
EXPECT(ctx.get_queue<hipStream_t>() != nullptr);
p.eval(pp);
ctx.finish();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment