Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
......@@ -14,31 +14,31 @@ bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::s
not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{});
}
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m)
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm)
{
using vector_clock = std::vector<std::size_t>;
std::vector<stream_race> races;
auto nstream = m.get_nstream();
auto nstream = strmm.get_nstream();
std::vector<vector_clock> vclock(nstream, vector_clock(nstream));
std::unordered_map<instruction_ref, vector_clock> timestamp;
std::unordered_map<std::size_t, vector_clock> events;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(not m.has_stream(ins))
if(not strmm.has_stream(ins))
continue;
std::size_t s = m.get_stream(ins);
std::size_t s = strmm.get_stream(ins);
assert(s < nstream);
assert(vclock.size() == nstream);
assert(vclock[s].size() == nstream);
if(m.is_record(ins))
if(strmm.is_record(ins))
{
vclock[s][s]++;
auto event = m.get_event_id(ins);
auto event = strmm.get_event_id(ins);
events[event] = vclock[s];
}
else if(m.is_wait(ins))
else if(strmm.is_wait(ins))
{
auto event = m.get_event_id(ins);
auto event = strmm.get_event_id(ins);
if(not contains(events, event))
MIGRAPHX_THROW("Event is waited on before being recorded: " +
std::to_string(event));
......@@ -57,21 +57,21 @@ std::vector<stream_race> analyze_streams(const module& p, const stream_model& m)
}
timestamp[ins] = vclock[s];
}
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(not m.has_stream(ins))
if(not strmm.has_stream(ins))
continue;
if(ins->inputs().empty())
continue;
std::size_t s = m.get_stream(ins);
std::size_t s = strmm.get_stream(ins);
// Find inputs from different streams
std::vector<instruction_ref> inputs;
fix([&](auto self, auto start) {
for(auto input : start->inputs())
{
if(not m.has_stream(input))
if(not strmm.has_stream(input))
self(input);
else if(m.get_stream(input) != s)
else if(strmm.get_stream(input) != s)
inputs.push_back(input);
}
})(ins);
......
......@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0)
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
......@@ -4,15 +4,18 @@
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -71,28 +74,41 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
target get_target(const std::string& name) { return make_target(name); }
migraphx::compile_options to_compile_options(const migraphx_compile_options& options)
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
migraphx::compile_options result{};
result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
migraphx::file_options to_file_options(const migraphx_file_options& options)
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
migraphx::file_options result{};
result.format = options.format;
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value)
{
options.default_dim_value = value;
}
void set_default_loop_iterations(onnx_options& options, int64_t value)
{
options.max_loop_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......@@ -159,18 +175,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -185,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx
template <class T, class U, class Target = std::remove_pointer_t<T>>
......@@ -209,12 +272,60 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
extern "C" struct migraphx_shape;
struct migraphx_shape
{
template <class... Ts>
migraphx_shape(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_shape(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::shape object;
......@@ -224,7 +335,8 @@ extern "C" struct migraphx_argument;
struct migraphx_argument
{
template <class... Ts>
migraphx_argument(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_argument(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::argument object;
......@@ -234,7 +346,8 @@ extern "C" struct migraphx_target;
struct migraphx_target
{
template <class... Ts>
migraphx_target(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_target(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::target object;
......@@ -244,7 +357,8 @@ extern "C" struct migraphx_program_parameter_shapes;
struct migraphx_program_parameter_shapes
{
template <class... Ts>
migraphx_program_parameter_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_program_parameter_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::unordered_map<std::string, migraphx::shape> object;
......@@ -254,7 +368,8 @@ extern "C" struct migraphx_program_parameters;
struct migraphx_program_parameters
{
template <class... Ts>
migraphx_program_parameters(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_program_parameters(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::unordered_map<std::string, migraphx::argument> object;
......@@ -264,7 +379,8 @@ extern "C" struct migraphx_arguments;
struct migraphx_arguments
{
template <class... Ts>
migraphx_arguments(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_arguments(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::argument> object;
......@@ -274,17 +390,52 @@ extern "C" struct migraphx_shapes;
struct migraphx_shapes
{
template <class... Ts>
migraphx_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::shape> object;
};
extern "C" struct migraphx_instruction;
struct migraphx_instruction
{
template <class... Ts>
migraphx_instruction(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::instruction_ref object;
};
extern "C" struct migraphx_instructions;
struct migraphx_instructions
{
template <class... Ts>
migraphx_instructions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::instruction_ref> object;
};
extern "C" struct migraphx_modules;
struct migraphx_modules
{
template <class... Ts>
migraphx_modules(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::module*> object;
};
extern "C" struct migraphx_module;
struct migraphx_module
{
template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_module(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::module object;
......@@ -294,7 +445,8 @@ extern "C" struct migraphx_program;
struct migraphx_program
{
template <class... Ts>
migraphx_program(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_program(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::program object;
......@@ -304,7 +456,8 @@ extern "C" struct migraphx_operation;
struct migraphx_operation
{
template <class... Ts>
migraphx_operation(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_operation(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::operation object;
......@@ -314,17 +467,41 @@ extern "C" struct migraphx_onnx_options;
struct migraphx_onnx_options
{
template <class... Ts>
migraphx_onnx_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_onnx_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::onnx_options object;
};
extern "C" struct migraphx_file_options;
struct migraphx_file_options
{
template <class... Ts>
migraphx_file_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::file_options object;
};
extern "C" struct migraphx_compile_options;
struct migraphx_compile_options
{
template <class... Ts>
migraphx_compile_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::compile_options object;
};
extern "C" struct migraphx_tf_options;
struct migraphx_tf_options
{
template <class... Ts>
migraphx_tf_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_tf_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::tf_options object;
......@@ -334,7 +511,8 @@ extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names
{
template <class... Ts>
migraphx_quantize_op_names(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_quantize_op_names(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<std::string> object;
......@@ -344,15 +522,63 @@ extern "C" struct migraphx_quantize_int8_options;
struct migraphx_quantize_int8_options
{
template <class... Ts>
migraphx_quantize_int8_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
migraphx_quantize_int8_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::quantize_int8_options object;
};
extern "C" struct migraphx_context;
struct migraphx_context
{
template <class... Ts>
migraphx_context(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::context object;
};
extern "C" struct migraphx_experimental_custom_op;
struct migraphx_experimental_custom_op
{
template <class... Ts>
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr;
migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
return (&out)->object;
}
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
const_migraphx_shape_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -360,13 +586,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths,
size_t lengths_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
......@@ -376,7 +603,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides,
size_t strides_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
......@@ -386,21 +613,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -409,12 +638,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -423,127 +653,163 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
return migraphx::try_([&] { destroy((argument)); });
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{
return migraphx::try_([&] { destroy((target)); });
auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] { destroy((program_parameter_shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -551,19 +817,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr)
......@@ -572,21 +839,32 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out);
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{
return migraphx::try_([&] { destroy((program_parameters)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>());
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -594,7 +872,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer");
......@@ -602,179 +880,416 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object);
});
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{
return migraphx::try_([&] { destroy((arguments)); });
auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{
return migraphx::try_([&] { destroy((shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction)
{
auto api_error_result = migraphx::try_([&] { destroy((instruction)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions)
{
auto api_error_result = migraphx::try_([&] { destroy((instructions)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*instructions =
object_cast<migraphx_instructions_t>(allocate<std::vector<migraphx::instruction_ref>>(
migraphx::to_obj_vector<const_migraphx_instruction_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_destroy(migraphx_modules_t modules)
{
auto api_error_result = migraphx::try_([&] { destroy((modules)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*modules = object_cast<migraphx_modules_t>(allocate<std::vector<migraphx::module*>>(
migraphx::to_objptr_vector<migraphx::module*>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name)
{
auto api_error_result = migraphx::try_([&] {
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
*module = object_cast<migraphx_module_t>(allocate<migraphx::module>((std::string(name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
if(module_refs == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module_refs: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object), (module_refs->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_literal((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_parameter((name), (shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>((module->object).add_return((args->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_create(migraphx_program_t* program)
{
auto api_error_result = migraphx::try_(
[&] { *program = object_cast<migraphx_program_t>(allocate<migraphx::program>()); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_create_module(migraphx_module_t* out, migraphx_program_t program, const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).create_module((name)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options)
migraphx_compile_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
(program->object)
.compile((target->object),
(options == nullptr ? migraphx::compile_options{}
: migraphx::to_compile_options(*options)));
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_experimental_get_context(migraphx_context_t* out, const_migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_context_t>(migraphx::get_context((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{
return migraphx::try_([&] { destroy((operation)); });
auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes)
extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...)
{
return migraphx::try_([&] {
va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes))));
allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
});
va_end(vlist);
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr)
......@@ -783,47 +1298,58 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options)
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
*out = allocate<migraphx_program_t>(migraphx::load(
(name),
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options))));
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options)
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
migraphx::save(
(p->object),
(name),
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options)));
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
return migraphx::try_([&] { destroy((onnx_options)); });
auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -831,26 +1357,122 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
auto api_error_result = migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
......@@ -858,40 +1480,51 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size,
migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
return migraphx::try_([&] { destroy((tf_options)); });
auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -899,23 +1532,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
......@@ -923,96 +1558,122 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
return migraphx::try_([&] { destroy((quantize_op_names)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
return migraphx::try_([&] { destroy((quantize_int8_options)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
......@@ -1020,13 +1681,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr)
......@@ -1035,4 +1697,73 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
(context->object).finish();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
*out = (context->object).get_queue().unsafe_get();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] { destroy((experimental_custom_op)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_shape_f = (input); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] {
if(experimental_custom_op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter experimental_custom_op: Null pointer");
migraphx::register_custom_op((*experimental_custom_op));
});
return api_error_result;
}
......@@ -25,7 +25,8 @@ extern "C" {
#endif
// return code, more to be added later
typedef enum {
typedef enum
{
migraphx_status_success = 0,
migraphx_status_bad_param = 1,
migraphx_status_unknown_target = 3,
......@@ -35,32 +36,13 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
typedef enum
{
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct
{
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math;
} migraphx_compile_options;
/// Options for saving and loading files
typedef struct
{
/// Format to be used for file. It can either be json or msgpack
const char* format;
} migraphx_file_options;
typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t;
......@@ -82,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_instruction* migraphx_instruction_t;
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
typedef struct migraphx_instructions* migraphx_instructions_t;
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
typedef struct migraphx_modules* migraphx_modules_t;
typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
......@@ -94,6 +85,12 @@ typedef const struct migraphx_operation* const_migraphx_operation_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
typedef struct migraphx_file_options* migraphx_file_options_t;
typedef const struct migraphx_file_options* const_migraphx_file_options_t;
typedef struct migraphx_compile_options* migraphx_compile_options_t;
typedef const struct migraphx_compile_options* const_migraphx_compile_options_t;
typedef struct migraphx_tf_options* migraphx_tf_options_t;
typedef const struct migraphx_tf_options* const_migraphx_tf_options_t;
......@@ -103,8 +100,24 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
......@@ -135,6 +148,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
......@@ -151,11 +167,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
......@@ -170,6 +192,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
......@@ -179,6 +204,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status
......@@ -186,21 +214,81 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input);
migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program,
const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options);
migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program);
......@@ -219,22 +307,32 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out,
const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes);
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options);
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options);
migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options);
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape(
......@@ -243,6 +341,33 @@ migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value);
migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options,
bool value);
migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
......@@ -253,6 +378,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
......@@ -274,6 +402,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
......@@ -287,6 +418,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
......@@ -301,6 +436,30 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context);
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input);
migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus
}
#endif
......
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
......@@ -13,14 +15,31 @@ namespace migraphx {
inline namespace api { // NOLINT
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
T* result = nullptr;
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(&result, std::forward<Ts>(xs)...);
auto e = f(&result, std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
return result;
......@@ -29,9 +48,6 @@ T* make(F f, Ts&&... xs)
template <class F, class... Ts>
void call(F f, Ts&&... xs)
{
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
......@@ -87,34 +103,22 @@ struct iota_iterator
return it;
}
// TODO: operator->
reference operator*() const { return (*f)(index); }
};
reference operator*() const { return f(index); }
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
friend iota_iterator operator+(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index + y.index, x.f);
}
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
friend iota_iterator operator-(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index - y.index, x.f);
}
template <class F, class Iterator>
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
template <class F, class Iterator>
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
};
template <class Derived>
struct array_base
......@@ -124,8 +128,20 @@ struct array_base
template <class T>
using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>;
using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived>
value_type_t<D> front() const
......@@ -142,16 +158,45 @@ struct array_base
template <class D = Derived>
iterator_t<D> begin() const
{
return {0, {derived().get_handle_ptr()}};
return {0, {&derived()}};
}
template <class D = Derived>
iterator_t<D> end() const
{
return {derived().size(), {derived().get_handle_ptr()}};
return {derived().size(), {&derived()}};
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template <class T>
struct holder
{
// Friend injection
friend auto migraphx_adl_handle_lookup(holder<T>);
// Function left unimplemented since its only used in non-evaluated
// context
T get() const;
};
template <class C, class T>
struct handle_lookup
{
friend auto migraphx_adl_handle_lookup(holder<T>) { return holder<C>{}; }
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template <class T>
using as_handle = decltype(
migraphx_adl_handle_lookup(holder<std::remove_cv_t<std::remove_pointer_t<T>>>{}).get());
struct own
{
};
......@@ -159,9 +204,25 @@ struct borrow
{
};
template <class T, class D, D Deleter>
struct handle_base
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{
using handle_type = T;
handle_base() : m_handle(nullptr) {}
template <class F, class... Ts>
void make_handle(F f, Ts&&... xs)
......@@ -190,17 +251,178 @@ struct handle_base
m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
}
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U>
void assign_to_handle(U* x)
{
Assigner(x, this->get_handle_ptr());
}
protected:
std::shared_ptr<T> m_handle;
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base>
struct interface_base : Base
{
interface_base() : Base() {}
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
template <class F, class T, class... Ts>
void make_interface(F f, T& obj, Ts&&... xs)
{
auto copy = [](void** out, void* input) {
return try_([&] {
T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr);
// cppcheck-suppress useSmartPointer
*y = new T(*x); // NOLINT
});
};
auto del = [](void* input) {
return try_([&] {
T* x = reinterpret_cast<T*>(input);
delete x; // NOLINT
});
};
this->make_handle(f, &obj, copy, del, std::forward<Ts>(xs)...);
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
}
template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f)
{
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
});
}
struct no_out_arg
{
};
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
}
template <class T,
class F,
class R,
class X,
class... Xs,
class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{
f(*reinterpret_cast<T*>(obj), result, xs...);
}
template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs)
{
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
}
template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs)
{
f(std::forward<Ts>(xs)...);
}
template <class T, class = std::enable_if_t<std::is_fundamental<T>{} or std::is_enum<T>{}>>
T auto_convert_param(rank<0>, T x)
{
return x;
}
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
{
return as_handle<T>{x, borrow{}};
}
template <class T, class U>
void auto_assign(rank<0>, T* out, U x)
{
return *out = x;
}
template <class T, class U>
auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out))
{
x.assign_to_handle(out);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template <class Base, class T>
using require_interface =
std::enable_if_t<std::is_base_of<Base, T>{} and not std::is_same<T, Base>{} and
std::is_copy_constructible<T>{} and std::is_final<T>{}>;
#ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy>
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
......@@ -216,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); }
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
/// Construct a scalar shape
shape(migraphx_shape_datatype_t type)
......@@ -252,7 +473,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
std::vector<size_t> strides() const
......@@ -260,7 +481,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
migraphx_shape_datatype_t type() const
......@@ -297,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); }
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer)
......@@ -312,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout);
return {pout, this->share_handle()};
}
char* data() const
......@@ -325,9 +545,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
return argument(
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{});
return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{}};
}
friend bool operator==(const argument& px, const argument& py)
......@@ -345,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); }
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(target);
/// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); }
......@@ -357,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own)
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
size_t size() const
{
......@@ -378,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout);
return {pout, this->share_handle()};
}
std::vector<const char*> names() const
......@@ -395,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); }
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
......@@ -423,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); }
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
size_t size() const
{
......@@ -438,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout);
return {pout, this->share_handle()};
}
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout);
}
};
};
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); }
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
size_t size() const
{
......@@ -471,49 +663,198 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout);
return {pout, this->share_handle()};
}
};
struct iterator_read
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout);
}
};
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
};
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
template <class... Ts>
instructions(Ts... xs)
{
std::array<const_migraphx_instruction_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_instructions_create, a.data(), a.size());
}
};
struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_instruction(const migraphx::operation& op,
const migraphx::instructions& args,
const migraphx::modules& module_args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
template <typename T>
instruction add_literal(const migraphx::shape& s, T* buffer)
{
migraphx_instruction_t literal_ins;
const auto* buffer_ptr = reinterpret_cast<const char*>(buffer);
call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr);
return instruction(literal_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
};
struct context
{
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here
return reinterpret_cast<T>(out);
}
private:
std::shared_ptr<migraphx_context> ctx;
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
void print() const { call(&migraphx_module_print, mm); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
void set_offload_copy(bool value = true)
{
call(&migraphx_compile_options_set_offload_copy, this->get_handle_ptr(), value);
}
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
void set_fast_math(bool value = true)
{
call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value);
}
};
/// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() {}
program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); }
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program);
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget, migraphx_compile_options poptions) const
void compile(const target& ptarget, const compile_options& poptions) const
{
call(
&migraphx_program_compile, this->get_handle_ptr(), ptarget.get_handle_ptr(), &poptions);
call(&migraphx_program_compile,
this->get_handle_ptr(),
ptarget.get_handle_ptr(),
poptions.get_handle_ptr());
}
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget) const
{
call(&migraphx_program_compile, this->get_handle_ptr(), ptarget.get_handle_ptr(), nullptr);
call(&migraphx_program_compile,
this->get_handle_ptr(),
ptarget.get_handle_ptr(),
migraphx::compile_options{}.get_handle_ptr());
}
/// Return the shapes for the input parameters
......@@ -559,53 +900,64 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu};
return module{p_modu, this->share_handle()};
}
context experimental_get_context()
{
migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx, this->share_handle()};
}
module create_module(const std::string& name)
{
migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr)
{
this->make_handle(&migraphx_operation_create, name, attributes);
}
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); }
std::string name()
// set file format
void set_file_format(const char* format)
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data());
call(&migraphx_file_options_set_file_format, this->get_handle_ptr(), format);
}
};
/// Load a saved migraphx program from a file
inline program load(const char* filename, migraphx_file_options options)
inline program load(const char* filename, const file_options& options)
{
return program(make<migraphx_program>(&migraphx_load, filename, &options), own{});
return program(make<migraphx_program>(&migraphx_load, filename, options.get_handle_ptr()),
own{});
}
/// Load a saved migraphx program from a file
inline program load(const char* filename)
{
return program(make<migraphx_program>(&migraphx_load, filename, nullptr), own{});
return program(
make<migraphx_program>(&migraphx_load, filename, migraphx::file_options{}.get_handle_ptr()),
own{});
}
/// Save a program to a file
inline void save(const program& p, const char* filename, migraphx_file_options options)
inline void save(const program& p, const char* filename, const file_options& options)
{
call(&migraphx_save, p.get_handle_ptr(), filename, &options);
call(&migraphx_save, p.get_handle_ptr(), filename, options.get_handle_ptr());
}
/// Save a program to a file
inline void save(const program& p, const char* filename)
{
call(&migraphx_save, p.get_handle_ptr(), filename, nullptr);
call(&migraphx_save, p.get_handle_ptr(), filename, migraphx::file_options{}.get_handle_ptr());
}
/// Options for parsing onnx options
......@@ -613,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -630,6 +982,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
}
/// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
}
};
/// Parse an onnx file into a migraphx program
......@@ -689,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -742,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name)
{
......@@ -767,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); }
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
/// Add an operator that should be quantized
void add_op_name(const std::string& name)
......@@ -799,6 +1152,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options.get_handle_ptr());
}
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
{
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
};
template <class T, class = require_interface<experimental_custom_op_base, T>>
void register_experimental_custom_op(T& obj)
{
experimental_custom_op op{obj};
op.register_op();
}
#ifndef DOXYGEN
} // namespace api
#endif
......
......@@ -178,14 +178,58 @@ def shapes(h):
returns='const migraphx::shape&')
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
def instruction(h):
pass
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
def instructions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
def modules(h):
h.constructor('create',
api.params(ptr='migraphx_module_t*', size='size_t'),
fname='migraphx::to_objptr_vector<migraphx::module*>')
@auto_handle(ref=True)
def module(h):
h.constructor('create', api.params(name='std::string'))
h.method('print', invoke='migraphx::print_module($@)', const=True)
h.method('add_instruction',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_instruction_with_mod_args',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>',
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_literal',
api.params(shape='const migraphx::shape&', buffer='const char*'),
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
@auto_handle()
def program(h):
h.constructor('create')
h.method('get_main_module', returns='migraphx::module*')
h.method('create_module',
api.params(name='const char*'),
returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
......@@ -207,12 +251,18 @@ def program(h):
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('experimental_get_context',
invoke='migraphx::get_context($@)',
const=True,
returns='migraphx::context')
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*', attributes='const char*'),
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
......@@ -243,6 +293,30 @@ def onnx_options(h):
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_default_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)',
)
@auto_handle()
def file_options(h):
h.constructor('create')
h.method('set_file_format',
api.params(format='const char*'),
invoke='migraphx::set_file_format($@)')
@auto_handle()
def compile_options(h):
h.constructor('create')
h.method('set_offload_copy',
api.params(value='bool'),
invoke='migraphx::set_offload_copy($@)')
h.method('set_fast_math',
api.params(value='bool'),
invoke='migraphx::set_fast_math($@)')
api.add_function('migraphx_parse_onnx',
......@@ -327,3 +401,19 @@ api.add_function('migraphx_quantize_int8',
target='migraphx::target',
options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap')
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
h.method('register', invoke='migraphx::register_custom_op($@)')
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -8,7 +8,7 @@ inline namespace MIGRAPHX_INLINE_NS {
argument::argument(const shape& s) : m_shape(s)
{
auto buffer = make_shared_array<char>(s.bytes());
m_data = {[=]() mutable { return buffer.get(); }};
assign_buffer({[=]() mutable { return buffer.get(); }});
}
argument::argument(shape s, std::nullptr_t)
......@@ -18,14 +18,17 @@ argument::argument(shape s, std::nullptr_t)
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
argument argument::load(const shape& s, char* buffer)
void argument::assign_buffer(std::function<char*()> d)
{
const shape& s = m_shape;
if(s.type() != shape::tuple_type)
return argument{s, buffer};
{
m_data = {std::move(d)};
return;
}
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
// cppcheck-suppress variableScope
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
......@@ -56,21 +59,23 @@ argument argument::load(const shape& s, char* buffer)
}
assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0;
return fix<argument>([&](auto self, auto ss) {
m_data = fix<data_t>([&](auto self, auto ss) {
data_t result;
if(ss.sub_shapes().empty())
{
argument r{shapes[i], buffer + offsets[i]};
auto n = offsets[i];
result = {[d, n]() mutable { return d() + n; }};
i++;
return r;
return result;
}
std::vector<argument> subs;
std::vector<data_t> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
return argument{subs};
result.sub = subs;
return result;
})(s);
}
......@@ -99,7 +104,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; }
argument argument::reshape(const shape& s) const
{
assert(s.element_space() <= this->get_shape().element_space());
return {s, this->m_data};
}
argument::data_t argument::data_t::share() const
{
......@@ -148,5 +157,13 @@ std::vector<argument> argument::get_sub_objects() const
return result;
}
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -8,17 +8,44 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const
void auto_contiguous::apply(module& m) const
{
for(auto ins : iterator_for(p))
std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(m))
{
auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false)))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return m.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() == "layout")
continue;
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
}
}
}
......
......@@ -96,7 +96,7 @@ instruction_ref insert_common_op(module& m,
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", common.lens()}}), input);
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
......
......@@ -28,13 +28,20 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
params += " " + src.path.filename().string();
if(out.empty())
out = src.path.stem().string() + ".o";
out = src.path.stem().string() + out_ext;
}
}
params += " -o " + out;
td.execute(compiler, params);
if(not launcher.empty())
{
td.execute(launcher, compiler + " " + params);
}
else
{
td.execute(compiler, params);
}
auto out_path = td.path / out;
if(not fs::exists(out_path))
......
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
......@@ -26,16 +27,18 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
{
names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
continue;
}
if(ins->name() == "@return")
else if(ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front();
}
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
else
{
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
}
ss << "return " << names.at(return_ins) << ";\n";
body = ss.str();
......@@ -49,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
......@@ -61,11 +65,31 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this;
}
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl
{
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -81,38 +105,56 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args)
{
auto v = op.to_value();
return interpolate_string(op.attributes()["point_op"].to<std::string>(),
[&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
std::string code;
if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
code = attributes["point_op"].to<std::string>();
}
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
}
std::string cpp_generator::str() const { return impl->fs.str(); }
......@@ -120,7 +162,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m)
{
function f;
f.set_name(m.name()).set_types(m).set_body(
auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
......@@ -130,8 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
return this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
......@@ -139,6 +190,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
......
......@@ -9,26 +9,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const
......@@ -48,19 +28,24 @@ void dead_code_elimination::apply(module& m) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
continue;
assert(bidistance(m, i, last) > 0);
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf))
return;
if(leaf->outputs().empty())
{
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0);
assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins);
m.move_instruction(leaf, m.end());
if(leaf->name() != "@param")
m.move_instruction(leaf, m.end());
for(auto arg : args)
self(arg);
}
......
#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct alpha_beta
{
float alpha = 0.0;
float beta = 0.0;
};
alpha_beta get_alpha_beta(const operation& op)
{
auto v = op.to_value();
return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
}
struct find_dot_add
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
}
};
struct find_dot_alpha
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
}
};
} // namespace
void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}, find_dot_alpha{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -6,6 +6,7 @@ add_executable(driver
resnet50.cpp
inceptionv3.cpp
alexnet.cpp
marker_roctx.cpp
)
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility
......
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
......@@ -60,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = "max";
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
......@@ -80,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = "max";
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
......@@ -128,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = "max";
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
......@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
migraphx::op::dot dot43;
dot43.alpha = 1;
dot43.beta = 1;
auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
......@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
migraphx::op::dot dot48;
dot48.alpha = 1;
dot48.beta = 1;
auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
......@@ -170,10 +171,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
migraphx::op::dot dot52;
dot52.alpha = 1;
dot52.beta = 1;
mm->add_instruction(dot52, mx49, mx50, mx51);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
return p;
}
......
......@@ -17,6 +17,7 @@
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
namespace driver {
......@@ -106,10 +107,22 @@ struct argument_parser
return to_string_range(x);
}
template <class T>
auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x))
{
return to_string(x);
}
template <class T>
std::string as_string_value(rank<0>, const T&)
{
throw std::runtime_error("Can't convert to string");
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
return as_string_value(rank<1>{}, x);
}
template <class T, class... Fs>
......@@ -122,10 +135,11 @@ struct argument_parser
return false;
}});
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.default_value = as_string_value(x);
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x);
}
template <class... Fs>
......
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
......@@ -994,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493;
pooling493.mode = "max";
pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0};
pooling493.stride = {2, 2};
pooling493.lengths = {3, 3};
......@@ -1024,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500;
pooling500.mode = "max";
pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0};
pooling500.stride = {2, 2};
pooling500.lengths = {3, 3};
......@@ -1102,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519;
pooling519.mode = "average";
pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1};
pooling519.stride = {1, 1};
pooling519.lengths = {3, 3};
......@@ -1195,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542;
pooling542.mode = "average";
pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1};
pooling542.stride = {1, 1};
pooling542.lengths = {3, 3};
......@@ -1288,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565;
pooling565.mode = "average";
pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1};
pooling565.stride = {1, 1};
pooling565.lengths = {3, 3};
......@@ -1357,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582;
pooling582.mode = "max";
pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0};
pooling582.stride = {2, 2};
pooling582.lengths = {3, 3};
......@@ -1474,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611;
pooling611.mode = "average";
pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1};
pooling611.stride = {1, 1};
pooling611.lengths = {3, 3};
......@@ -1603,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643;
pooling643.mode = "average";
pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1};
pooling643.stride = {1, 1};
pooling643.lengths = {3, 3};
......@@ -1732,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675;
pooling675.mode = "average";
pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1};
pooling675.stride = {1, 1};
pooling675.lengths = {3, 3};
......@@ -1861,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707;
pooling707.mode = "average";
pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1};
pooling707.stride = {1, 1};
pooling707.lengths = {3, 3};
......@@ -1954,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730;
pooling730.mode = "max";
pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0};
pooling730.stride = {2, 2};
pooling730.lengths = {3, 3};
......@@ -2065,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758;
pooling758.mode = "average";
pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1};
pooling758.stride = {1, 1};
pooling758.lengths = {3, 3};
......@@ -2188,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789;
pooling789.mode = "average";
pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1};
pooling789.stride = {1, 1};
pooling789.lengths = {3, 3};
......@@ -2209,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794;
pooling794.mode = "average";
pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0};
pooling794.stride = {8, 8};
pooling794.lengths = {8, 8};
......@@ -2225,10 +2226,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::multibroadcast multibroadcast798;
multibroadcast798.output_lens = {batch, 1000};
auto mx798 = mm->add_instruction(multibroadcast798, mx0);
migraphx::op::dot dot799;
dot799.alpha = 1;
dot799.beta = 1;
mm->add_instruction(dot799, mx796, mx797, mx798);
float dot799_alpha = 1;
float dot799_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx796, mx797, mx798}, migraphx::make_op("dot"), dot799_alpha, dot799_beta);
return p;
}
......
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
#include "verify.hpp"
#include "precision.hpp"
#include "perf.hpp"
#include "models.hpp"
#include "marker_roctx.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
......@@ -287,14 +289,12 @@ struct compiler_target
struct compiler
{
static const int q_fp16 = 1;
static const int q_int8 = 2;
loader l;
program_params parameters;
compiler_target ct;
bool offload_copy = false;
bool fast_math = true;
int quantize = 0;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
std::vector<std::string> fill0;
std::vector<std::string> fill1;
......@@ -311,8 +311,8 @@ struct compiler
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
}
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); }
......@@ -324,11 +324,11 @@ struct compiler
if(p.is_compiled())
return p;
auto t = ct.get_target();
if(quantize == q_fp16)
if(quantize == precision::fp16)
{
quantize_fp16(p);
}
else if(quantize == q_int8)
else if(quantize == precision::int8)
{
quantize_int8(p, t, {params(p)});
}
......@@ -376,6 +376,7 @@ struct verify : command<verify>
bool reduce = false;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap)
{
l.parse(ap);
......@@ -395,6 +396,7 @@ struct verify : command<verify>
ap.help("Verify each instruction"),
ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
}
void run()
......@@ -411,15 +413,15 @@ struct verify : command<verify>
if(per_instruction)
{
verify_instructions(p, t, options, tolerance);
verify_instructions(p, t, options, quantize, tolerance);
}
else if(reduce)
{
verify_reduced_program(p, t, options, m, tolerance);
verify_reduced_program(p, t, options, quantize, m, tolerance);
}
else
{
verify_program(l.file, p, t, options, m, tolerance);
verify_program(l.file, p, t, options, quantize, m, tolerance);
}
}
};
......@@ -479,15 +481,34 @@ struct perf : command<perf>
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
p.perf_report(std::cout, n, m, c.l.batch);
}
};
struct roctx : command<roctx>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
auto rtx = create_marker_roctx();
p.mark(m, std::move(rtx));
}
};
struct op : command<op>
{
bool show_ops = false;
std::string op_name{};
void parse(argument_parser& ap)
{
ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
ap(show_ops,
{"--list", "-l"},
ap.help("List all the operators of MIGraphX"),
......@@ -500,6 +521,32 @@ struct op : command<op>
for(const auto& name : get_operators())
std::cout << name << std::endl;
}
else
{
auto op = load_op(op_name);
std::cout << op_name << ": " << std::endl;
std::cout << to_pretty_json_string(op.to_value()) << std::endl;
}
}
};
struct onnx : command<onnx>
{
bool show_ops = false;
void parse(argument_parser& ap)
{
ap(show_ops,
{"--list", "-l"},
ap.help("List all onnx operators supported by MIGraphX"),
ap.set_value(true));
}
void run() const
{
if(show_ops)
{
for(const auto& name : get_onnx_operators())
std::cout << name << std::endl;
}
}
};
......
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
class marker_roctx
{
std::function<void(const char*)> sym_roctx_mark;
std::function<uint64_t(const char*)> sym_roctx_range_start;
std::function<void(uint64_t)> sym_roctx_range_stop;
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id = 0;
public:
marker_roctx()
{
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
sym_roctx_mark("rocTX marker created.");
}
void mark_start(instruction_ref ins_ref)
{
std::string text = "Marker start: " + ins_ref->name();
sym_roctx_range_push(text.c_str());
}
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
};
marker create_marker_roctx() { return marker_roctx(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment