"vscode:/vscode.git/clone" did not exist on "a73eb8cd200d3f3808e16a727c06b9fef4c09828"
Unverified Commit 64e79a94 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Expose APIs for the MIGraphX program (#1093)

API includes following
create_module,
get_main_module
add_instruction without module args
add_instruction with module args
add_parameter
add_return
parent 31e63991
......@@ -4,6 +4,7 @@
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
......@@ -72,6 +73,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
......@@ -291,6 +309,36 @@ struct migraphx_shapes
std::vector<migraphx::shape> object;
};
extern "C" struct migraphx_instruction;
struct migraphx_instruction
{
template <class... Ts>
migraphx_instruction(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::instruction_ref object;
};
extern "C" struct migraphx_instructions;
struct migraphx_instructions
{
template <class... Ts>
migraphx_instructions(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
std::vector<migraphx::instruction_ref> object;
};
extern "C" struct migraphx_modules;
struct migraphx_modules
{
template <class... Ts>
migraphx_modules(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
std::vector<migraphx::module*> object;
};
extern "C" struct migraphx_module;
struct migraphx_module
{
......@@ -774,6 +822,77 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction)
{
auto api_error_result = migraphx::try_([&] { destroy((instruction)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions)
{
auto api_error_result = migraphx::try_([&] { destroy((instructions)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*instructions =
object_cast<migraphx_instructions_t>(allocate<std::vector<migraphx::instruction_ref>>(
migraphx::to_obj_vector<const_migraphx_instruction_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_destroy(migraphx_modules_t modules)
{
auto api_error_result = migraphx::try_([&] { destroy((modules)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*modules = object_cast<migraphx_modules_t>(allocate<std::vector<migraphx::module*>>(
migraphx::to_objptr_vector<migraphx::module*>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name)
{
auto api_error_result = migraphx::try_([&] {
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
*module = object_cast<migraphx_module_t>(allocate<migraphx::module>((std::string(name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
auto api_error_result = migraphx::try_([&] {
......@@ -784,6 +903,76 @@ extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
if(module_refs == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module_refs: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object), (module_refs->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_parameter((name), (shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>((module->object).add_return((args->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] { destroy((program)); });
......@@ -797,6 +986,13 @@ extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
return api_error_result;
}
extern "C" migraphx_status migraphx_program_create(migraphx_program_t* program)
{
auto api_error_result = migraphx::try_(
[&] { *program = object_cast<migraphx_program_t>(allocate<migraphx::program>()); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
......@@ -808,6 +1004,17 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_create_module(migraphx_module_t* out, migraphx_program_t program, const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).create_module((name)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options)
......
......@@ -64,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_instruction* migraphx_instruction_t;
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
typedef struct migraphx_instructions* migraphx_instructions_t;
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
typedef struct migraphx_modules* migraphx_modules_t;
typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
......@@ -201,16 +210,66 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input);
migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program,
const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options);
......
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
......@@ -523,12 +525,109 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
};
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); }
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
instructions(Ts... xs)
{
std::array<const_migraphx_instruction_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_instructions_create, a.data(), a.size());
}
};
struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_instruction(const migraphx::operation& op,
const migraphx::instructions& args,
const migraphx::modules& module_args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr());
return instruction(ret_ins, own{});
}
};
struct context
......@@ -564,7 +663,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() {}
program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); }
......@@ -641,27 +740,14 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return context{ctx};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
module create_module(const std::string& name)
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu};
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
// options for migraphx file format options
......
......@@ -178,14 +178,55 @@ def shapes(h):
returns='const migraphx::shape&')
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
def instruction(h):
pass
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
def instructions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
def modules(h):
h.constructor('create',
api.params(ptr='migraphx_module_t*', size='size_t'),
fname='migraphx::to_objptr_vector<migraphx::module*>')
@auto_handle(ref=True)
def module(h):
h.constructor('create', api.params(name='std::string'))
h.method('print', invoke='migraphx::print_module($@)', const=True)
h.method('add_instruction',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_instruction_with_mod_args',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>',
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
@auto_handle()
def program(h):
h.constructor('create')
h.method('get_main_module', returns='migraphx::module*')
h.method('create_module',
api.params(name='const char*'),
returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
......
......@@ -13,6 +13,7 @@ endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <numeric>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(add_op)
{
migraphx::program p;
migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape);
auto y = m.add_parameter("y", param_shape);
auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y});
m.add_return({r});
// run on ref target
p.compile(migraphx::target("ref"));
migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<float> expected(9, 0);
CHECK(bool(output == migraphx::argument(param_shape, expected.data())));
}
TEST_CASE(if_then_else_op)
{
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
migraphx::shape cond_s{migraphx_shape_bool_type};
auto create_program = [&]() {
migraphx::program p;
auto mm = p.get_main_module();
auto cond = mm.add_parameter("cond", cond_s);
auto x = mm.add_parameter("x", param_shape);
auto y = mm.add_parameter("y", param_shape);
auto then_mod = p.create_module("If_0_if");
auto x_identity = then_mod.add_instruction(migraphx::operation("identity"), {x});
then_mod.add_return({x_identity});
auto else_mod = p.create_module("If_0_else");
auto y_identity = else_mod.add_instruction(migraphx::operation("identity"), {y});
else_mod.add_return({y_identity});
auto if_ins = mm.add_instruction(migraphx::operation("if"), {cond}, {then_mod, else_mod});
auto get_tuple_op = migraphx::operation("get_tuple_elem", "{index: 0}");
auto ret = mm.add_instruction(get_tuple_op, {if_ins});
mm.add_return({ret});
return p;
};
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
auto x_arg = migraphx::argument(param_shape, x_data.data());
auto y_arg = migraphx::argument(param_shape, y_data.data());
auto run_prog = [&](bool cond) {
auto p = create_program();
p.compile(migraphx::target("ref"));
auto outputs =
p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
return outputs;
};
// then branch
auto then_res = run_prog(true);
CHECK(bool{then_res[0] == x_arg});
// else branch
auto else_res = run_prog(false);
CHECK(bool{else_res[0] == y_arg});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4,6 +4,7 @@
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
......@@ -72,6 +73,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment