Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void adjust_allocation::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
// skip instruction with no input
if(ins->inputs().empty())
continue;
// Skip target-independent operators
if(ins->get_operator().is_context_free())
continue;
auto alias_ins = instruction::get_output_alias(ins, true);
if(alias_ins->name() != model.name() and alias_ins->name() != "@param")
continue;
// shape allocated is different from actual shape
// of the instruction, reallocate and replace the previous one
if(alias_ins->get_shape() == ins->get_shape())
continue;
auto alloc_ins = m.insert_instruction(ins, model.allocate(ins->get_shape()));
m.replace_instruction(alias_ins, alloc_ins);
// If the memory is an output parameter then copy the memory to the parameter
if(alias_ins->name() == "@param")
{
auto copy = m.insert_instruction(std::next(ins), make_op(model.copy()), ins, alias_ins);
auto tail = range(std::next(copy), m.end());
for(auto i : iterator_for(tail))
{
if(contains(i->inputs(), ins))
instruction::replace_argument(i, ins, copy);
}
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/analyze_streams.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::size_t>& e2)
{
return std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::less_equal<>{}) and
not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{});
}
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm)
{
using vector_clock = std::vector<std::size_t>;
std::vector<stream_race> races;
auto nstream = strmm.get_nstream();
std::vector<vector_clock> vclock(nstream, vector_clock(nstream));
std::unordered_map<instruction_ref, vector_clock> timestamp;
std::unordered_map<std::size_t, vector_clock> events;
for(auto ins : iterator_for(m))
{
if(not strmm.has_stream(ins))
continue;
std::size_t s = strmm.get_stream(ins);
assert(s < nstream);
assert(vclock.size() == nstream);
assert(vclock[s].size() == nstream);
if(strmm.is_record(ins))
{
vclock[s][s]++;
auto event = strmm.get_event_id(ins);
events[event] = vclock[s];
}
else if(strmm.is_wait(ins))
{
auto event = strmm.get_event_id(ins);
if(not contains(events, event))
MIGRAPHX_THROW("Event is waited on before being recorded: " +
std::to_string(event));
auto payload = events.at(event);
assert(vclock[s].size() == payload.size());
std::transform(vclock[s].begin(),
vclock[s].end(),
payload.begin(),
vclock[s].begin(),
[&](auto x, auto y) { return std::max(x, y); });
vclock[s][s]++;
}
else
{
vclock[s][s]++;
}
timestamp[ins] = vclock[s];
}
for(auto ins : iterator_for(m))
{
if(not strmm.has_stream(ins))
continue;
if(ins->inputs().empty())
continue;
std::size_t s = strmm.get_stream(ins);
// Find inputs from different streams
std::vector<instruction_ref> inputs;
fix([&](auto self, auto start) {
for(auto input : start->inputs())
{
if(not strmm.has_stream(input))
self(input);
else if(strmm.get_stream(input) != s)
inputs.push_back(input);
}
})(ins);
auto it = std::find_if(inputs.begin(), inputs.end(), [&](auto input) {
return not happens_before(timestamp.at(input), timestamp.at(ins));
});
if(it != inputs.end())
{
races.push_back({ins, *it});
}
}
return races;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
add_library(migraphx_c
api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
rocm_install_targets(
TARGETS migraphx_c
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
)
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
try
{
f();
}
catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
return migraphx_status_success;
}
shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{
switch(t)
{
case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT
}
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
migraphx_shape_datatype_t to_shape_type(shape::type_t t)
{
switch(t)
{
case shape::tuple_type: return migraphx_shape_tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case shape::x: return migraphx_shape_##x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT
}
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value)
{
options.default_dim_value = value;
}
void set_default_loop_iterations(onnx_options& options, int64_t value)
{
options.max_loop_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
void set_input_parameter_shape(onnx_options& options,
const char* name,
std::vector<std::size_t> dims)
{
options.map_input_dims[std::string(name)] = std::move(dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{
options.map_input_dims[std::string(name)] = std::move(dims);
}
void set_output_names(tf_options& options, std::vector<const char*> names)
{
options.output_node_names = std::vector<std::string>(names.begin(), names.end());
}
template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{
std::vector<const char*> result;
std::transform(
m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.first.c_str(); });
return result;
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{
if(names.empty())
{
names = {"all"};
}
migraphx::quantize_fp16(prog, names);
}
struct quantize_int8_options
{
std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {};
};
void add_op_name(quantize_int8_options& options, const char* name)
{
options.op_names.push_back(name);
}
void add_calibration_data(quantize_int8_options& options, parameter_map& data)
{
options.calibration.push_back(data);
}
void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& options)
{
if(options.op_names.empty())
{
options.op_names = {"dot", "convolution"};
}
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
return x == y;
}
std::vector<argument> run(program& p, const parameter_map& params) { return p.eval(params); }
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx
template <class T, class U, class Target = std::remove_pointer_t<T>>
Target* object_cast(U* x)
{
return reinterpret_cast<Target*>(x);
}
template <class T, class U, class Target = std::remove_pointer_t<T>>
const Target* object_cast(const U* x)
{
return reinterpret_cast<const Target*>(x);
}
template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs)
{
return new Target(std::forward<Ts>(xs)...); // NOLINT
}
template <class T>
void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
extern "C" struct migraphx_shape;
struct migraphx_shape
{
template <class... Ts>
migraphx_shape(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::shape object;
};
extern "C" struct migraphx_argument;
struct migraphx_argument
{
template <class... Ts>
migraphx_argument(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::argument object;
};
extern "C" struct migraphx_target;
struct migraphx_target
{
template <class... Ts>
migraphx_target(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::target object;
};
extern "C" struct migraphx_program_parameter_shapes;
struct migraphx_program_parameter_shapes
{
template <class... Ts>
migraphx_program_parameter_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::unordered_map<std::string, migraphx::shape> object;
};
extern "C" struct migraphx_program_parameters;
struct migraphx_program_parameters
{
template <class... Ts>
migraphx_program_parameters(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::unordered_map<std::string, migraphx::argument> object;
};
extern "C" struct migraphx_arguments;
struct migraphx_arguments
{
template <class... Ts>
migraphx_arguments(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::argument> object;
};
extern "C" struct migraphx_shapes;
struct migraphx_shapes
{
template <class... Ts>
migraphx_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::shape> object;
};
extern "C" struct migraphx_instruction;
struct migraphx_instruction
{
template <class... Ts>
migraphx_instruction(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::instruction_ref object;
};
extern "C" struct migraphx_instructions;
struct migraphx_instructions
{
template <class... Ts>
migraphx_instructions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::instruction_ref> object;
};
extern "C" struct migraphx_modules;
struct migraphx_modules
{
template <class... Ts>
migraphx_modules(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::module*> object;
};
extern "C" struct migraphx_module;
struct migraphx_module
{
template <class... Ts>
migraphx_module(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::module object;
};
extern "C" struct migraphx_program;
struct migraphx_program
{
template <class... Ts>
migraphx_program(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::program object;
};
extern "C" struct migraphx_operation;
struct migraphx_operation
{
template <class... Ts>
migraphx_operation(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::operation object;
};
extern "C" struct migraphx_onnx_options;
struct migraphx_onnx_options
{
template <class... Ts>
migraphx_onnx_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::onnx_options object;
};
extern "C" struct migraphx_file_options;
struct migraphx_file_options
{
template <class... Ts>
migraphx_file_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::file_options object;
};
extern "C" struct migraphx_compile_options;
struct migraphx_compile_options
{
template <class... Ts>
migraphx_compile_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::compile_options object;
};
extern "C" struct migraphx_tf_options;
struct migraphx_tf_options
{
template <class... Ts>
migraphx_tf_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::tf_options object;
};
extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names
{
template <class... Ts>
migraphx_quantize_op_names(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<std::string> object;
};
extern "C" struct migraphx_quantize_int8_options;
struct migraphx_quantize_int8_options
{
template <class... Ts>
migraphx_quantize_int8_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::quantize_int8_options object;
};
extern "C" struct migraphx_context;
struct migraphx_context
{
template <class... Ts>
migraphx_context(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::context object;
};
extern "C" struct migraphx_experimental_custom_op;
struct migraphx_experimental_custom_op
{
template <class... Ts>
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr;
migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
return (&out)->object;
}
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
const_migraphx_shape_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size)
{
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size)
{
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter strides: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
auto&& api_result = (shape->object).lens();
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
auto&& api_result = (shape->object).strides();
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{
auto api_error_result = migraphx::try_([&] {
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{
auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{
auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out);
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{
auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{
auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
const char* name,
const_migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer");
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object);
});
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{
auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{
auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction)
{
auto api_error_result = migraphx::try_([&] { destroy((instruction)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions)
{
auto api_error_result = migraphx::try_([&] { destroy((instructions)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*instructions =
object_cast<migraphx_instructions_t>(allocate<std::vector<migraphx::instruction_ref>>(
migraphx::to_obj_vector<const_migraphx_instruction_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_destroy(migraphx_modules_t modules)
{
auto api_error_result = migraphx::try_([&] { destroy((modules)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*modules = object_cast<migraphx_modules_t>(allocate<std::vector<migraphx::module*>>(
migraphx::to_objptr_vector<migraphx::module*>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name)
{
auto api_error_result = migraphx::try_([&] {
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
*module = object_cast<migraphx_module_t>(allocate<migraphx::module>((std::string(name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
if(module_refs == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module_refs: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object), (module_refs->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_literal((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_parameter((name), (shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>((module->object).add_return((args->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_create(migraphx_program_t* program)
{
auto api_error_result = migraphx::try_(
[&] { *program = object_cast<migraphx_program_t>(allocate<migraphx::program>()); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_create_module(migraphx_module_t* out, migraphx_program_t program, const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).create_module((name)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_experimental_get_context(migraphx_context_t* out, const_migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_context_t>(migraphx::get_context((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{
auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...)
{
va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
});
va_end(vlist);
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter operation: Null pointer");
auto&& api_result = (operation->object).name();
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
auto api_error_result = migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data,
size_t size,
migraphx_onnx_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter names: Null pointer");
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
auto api_error_result = migraphx::try_([&] {
*quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name)
{
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
if(data == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options)
{
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
(context->object).finish();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
*out = (context->object).get_queue().unsafe_get();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] { destroy((experimental_custom_op)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_shape_f = (input); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] {
if(experimental_custom_op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter experimental_custom_op: Null pointer");
migraphx::register_custom_op((*experimental_custom_op));
});
return api_error_result;
}
#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h>
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(int16_type, int16_t) \
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
#ifdef __cplusplus
extern "C" {
#endif
// return code, more to be added later
typedef enum
{
migraphx_status_success = 0,
migraphx_status_bad_param = 1,
migraphx_status_unknown_target = 3,
migraphx_status_unknown_error = 4,
} migraphx_status;
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum
{
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t;
typedef struct migraphx_argument* migraphx_argument_t;
typedef const struct migraphx_argument* const_migraphx_argument_t;
typedef struct migraphx_target* migraphx_target_t;
typedef const struct migraphx_target* const_migraphx_target_t;
typedef struct migraphx_program_parameter_shapes* migraphx_program_parameter_shapes_t;
typedef const struct migraphx_program_parameter_shapes* const_migraphx_program_parameter_shapes_t;
typedef struct migraphx_program_parameters* migraphx_program_parameters_t;
typedef const struct migraphx_program_parameters* const_migraphx_program_parameters_t;
typedef struct migraphx_arguments* migraphx_arguments_t;
typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_instruction* migraphx_instruction_t;
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
typedef struct migraphx_instructions* migraphx_instructions_t;
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
typedef struct migraphx_modules* migraphx_modules_t;
typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t;
typedef struct migraphx_operation* migraphx_operation_t;
typedef const struct migraphx_operation* const_migraphx_operation_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
typedef struct migraphx_file_options* migraphx_file_options_t;
typedef const struct migraphx_file_options* const_migraphx_file_options_t;
typedef struct migraphx_compile_options* migraphx_compile_options_t;
typedef const struct migraphx_compile_options* const_migraphx_compile_options_t;
typedef struct migraphx_tf_options* migraphx_tf_options_t;
typedef const struct migraphx_tf_options* const_migraphx_tf_options_t;
typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t;
typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t;
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size);
migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type);
migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument);
migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument);
migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x);
migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed);
migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name);
migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
const char* name,
const_migraphx_argument_t argument);
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx);
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input);
migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t shape,
const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program,
const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_print(const_migraphx_program_t program);
migraphx_status migraphx_program_sort(migraphx_program_t program);
migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params);
migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out,
const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options);
migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value);
migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options,
bool value);
migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data,
size_t size,
migraphx_onnx_options_t options);
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options,
const char* name,
size_t* dims,
size_t dims_size);
migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options,
size_t value);
migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size);
migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options);
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
const char* name);
migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name);
migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name);
migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context);
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input);
migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus
}
#endif
#endif
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
#include <vector>
#include <cassert>
#include <iostream>
namespace migraphx {
#ifndef DOXYGEN
inline namespace api { // NOLINT
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
T* result = nullptr;
auto e = f(&result, std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
return result;
}
template <class F, class... Ts>
void call(F f, Ts&&... xs)
{
auto e = f(std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
}
template <class F, class Iterator = std::size_t>
struct iota_iterator
{
Iterator index;
F f;
using difference_type = std::ptrdiff_t;
using reference = decltype(f(std::declval<Iterator>()));
using value_type = typename std::remove_reference<reference>::type;
using pointer = typename std::add_pointer<value_type>::type;
using iterator_category = std::input_iterator_tag;
iota_iterator& operator+=(int n)
{
index += n;
return *this;
}
iota_iterator& operator-=(int n)
{
index += n;
return *this;
}
iota_iterator& operator++()
{
index++;
return *this;
}
iota_iterator& operator--()
{
index--;
return *this;
}
iota_iterator operator++(int) // NOLINT
{
iota_iterator it = *this;
index++;
return it;
}
iota_iterator operator--(int) // NOLINT
{
iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
reference operator*() const { return f(index); }
friend iota_iterator operator+(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index + y.index, x.f);
}
friend iota_iterator operator-(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index - y.index, x.f);
}
friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
};
template <class Derived>
struct array_base
{
const Derived& derived() const { return static_cast<const Derived&>(*this); }
template <class T>
using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T>
using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived>
value_type_t<D> front() const
{
return derived()[0];
}
template <class D = Derived>
value_type_t<D> back() const
{
return derived()[derived().size() - 1];
}
template <class D = Derived>
iterator_t<D> begin() const
{
return {0, {&derived()}};
}
template <class D = Derived>
iterator_t<D> end() const
{
return {derived().size(), {&derived()}};
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template <class T>
struct holder
{
// Friend injection
friend auto migraphx_adl_handle_lookup(holder<T>);
// Function left unimplemented since its only used in non-evaluated
// context
T get() const;
};
template <class C, class T>
struct handle_lookup
{
friend auto migraphx_adl_handle_lookup(holder<T>) { return holder<C>{}; }
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template <class T>
using as_handle = decltype(
migraphx_adl_handle_lookup(holder<std::remove_cv_t<std::remove_pointer_t<T>>>{}).get());
struct own
{
};
struct borrow
{
};
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{
using handle_type = T;
handle_base() : m_handle(nullptr) {}
template <class F, class... Ts>
void make_handle(F f, Ts&&... xs)
{
using type = typename std::remove_cv<T>::type;
set_handle(make<type>(f, std::forward<Ts>(xs)...), own{});
}
const std::shared_ptr<T>& get_handle() const { return m_handle; }
T* get_handle_ptr() const
{
assert(m_handle != nullptr);
return get_handle().get();
}
template <class U>
void set_handle(U* ptr, own)
{
m_handle = std::shared_ptr<U>{ptr, Deleter};
}
template <class U>
void set_handle(U* ptr, borrow)
{
m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
}
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U>
void assign_to_handle(U* x)
{
Assigner(x, this->get_handle_ptr());
}
protected:
std::shared_ptr<T> m_handle;
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base>
struct interface_base : Base
{
interface_base() : Base() {}
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
template <class F, class T, class... Ts>
void make_interface(F f, T& obj, Ts&&... xs)
{
auto copy = [](void** out, void* input) {
return try_([&] {
T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr);
// cppcheck-suppress useSmartPointer
*y = new T(*x); // NOLINT
});
};
auto del = [](void* input) {
return try_([&] {
T* x = reinterpret_cast<T*>(input);
delete x; // NOLINT
});
};
this->make_handle(f, &obj, copy, del, std::forward<Ts>(xs)...);
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
}
template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f)
{
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
});
}
struct no_out_arg
{
};
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
}
template <class T,
class F,
class R,
class X,
class... Xs,
class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{
f(*reinterpret_cast<T*>(obj), result, xs...);
}
template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs)
{
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
}
template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs)
{
f(std::forward<Ts>(xs)...);
}
template <class T, class = std::enable_if_t<std::is_fundamental<T>{} or std::is_enum<T>{}>>
T auto_convert_param(rank<0>, T x)
{
return x;
}
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
{
return as_handle<T>{x, borrow{}};
}
template <class T, class U>
void auto_assign(rank<0>, T* out, U x)
{
return *out = x;
}
template <class T, class U>
auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out))
{
x.assign_to_handle(out);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template <class Base, class T>
using require_interface =
std::enable_if_t<std::is_base_of<Base, T>{} and not std::is_same<T, Base>{} and
std::is_copy_constructible<T>{} and std::is_final<T>{}>;
#ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
// NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
/**
* @brief Describe shape of tensor
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
*
*/
struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
/// Construct a scalar shape
shape(migraphx_shape_datatype_t type)
{
this->make_handle(&migraphx_shape_create_scalar, type);
}
/// Construct a shape with its type and lengths. The strides are
/// automatically computed assumming a packed layout.
shape(migraphx_shape_datatype_t type, std::vector<size_t> plengths)
{
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
}
shape(migraphx_shape_datatype_t type,
std::vector<size_t> plengths,
std::vector<size_t> pstrides)
{
this->make_handle(&migraphx_shape_create_with_strides,
type,
plengths.data(),
plengths.size(),
pstrides.data(),
pstrides.size());
}
std::vector<size_t> lengths() const
{
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return {pout, pout + pout_size};
}
std::vector<size_t> strides() const
{
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return {pout, pout + pout_size};
}
migraphx_shape_datatype_t type() const
{
migraphx_shape_datatype_t pout;
call(&migraphx_shape_type, &pout, this->get_handle_ptr());
return pout;
}
size_t bytes() const
{
size_t pout;
call(&migraphx_shape_bytes, &pout, this->get_handle_ptr());
return pout;
}
friend bool operator==(const shape& px, const shape& py)
{
bool pout;
call(&migraphx_shape_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr());
return pout;
}
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); }
};
/**
* @brief Arguments to be passed to an migraphx arguments
*
* An `argument` represents a raw buffer of data with a shape.
*
*/
struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer)
{
this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer);
}
shape get_shape() const
{
const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout, this->share_handle()};
}
char* data() const
{
char* pout;
call(&migraphx_argument_buffer, &pout, this->get_handle_ptr());
return pout;
}
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{}};
}
friend bool operator==(const argument& px, const argument& py)
{
bool pout;
call(&migraphx_argument_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr());
return pout;
}
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); }
};
/// A target for compilation
struct target : MIGRAPHX_HANDLE_BASE(target)
{
target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target);
/// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); }
};
struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
size_t size() const
{
size_t pout;
call(&migraphx_program_parameter_shapes_size, &pout, this->get_handle_ptr());
return pout;
}
shape operator[](const char* pname) const
{
const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout, this->share_handle()};
}
std::vector<const char*> names() const
{
std::vector<const char*> result(this->size());
if(!result.empty())
{
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
}
return result;
}
};
/// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
/// Construct the parameters from initializer_list
program_parameters(std::initializer_list<std::pair<std::string, argument>> l)
{
this->make_handle(&migraphx_program_parameters_create);
for(auto&& p : l)
this->add(p.first.c_str(), p.second);
}
/// Add a new parameter
void add(const char* pname, const argument& pargument) const
{
call(&migraphx_program_parameters_add,
this->get_handle_ptr(),
pname,
pargument.get_handle_ptr());
}
};
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
size_t size() const
{
size_t pout;
call(&migraphx_arguments_size, &pout, this->get_handle_ptr());
return pout;
}
argument operator[](size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout, this->share_handle()};
}
};
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
size_t size() const
{
size_t pout;
call(&migraphx_shapes_size, &pout, this->get_handle_ptr());
return pout;
}
shape operator[](size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout, this->share_handle()};
}
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
};
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
template <class... Ts>
instructions(Ts... xs)
{
std::array<const_migraphx_instruction_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_instructions_create, a.data(), a.size());
}
};
struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module
{
MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_instruction(const migraphx::operation& op,
const migraphx::instructions& args,
const migraphx::modules& module_args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
template <typename T>
instruction add_literal(const migraphx::shape& s, T* buffer)
{
migraphx_instruction_t literal_ins;
const auto* buffer_ptr = reinterpret_cast<const char*>(buffer);
call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr);
return instruction(literal_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
};
struct context
{
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here
return reinterpret_cast<T>(out);
}
private:
std::shared_ptr<migraphx_context> ctx;
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
void set_offload_copy(bool value = true)
{
call(&migraphx_compile_options_set_offload_copy, this->get_handle_ptr(), value);
}
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
void set_fast_math(bool value = true)
{
call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value);
}
};
/// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program);
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const
{
call(&migraphx_program_compile,
this->get_handle_ptr(),
ptarget.get_handle_ptr(),
poptions.get_handle_ptr());
}
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget) const
{
call(&migraphx_program_compile,
this->get_handle_ptr(),
ptarget.get_handle_ptr(),
migraphx::compile_options{}.get_handle_ptr());
}
/// Return the shapes for the input parameters
program_parameter_shapes get_parameter_shapes() const
{
migraphx_program_parameter_shapes_t pout;
call(&migraphx_program_get_parameter_shapes, &pout, this->get_handle_ptr());
return program_parameter_shapes(pout, own{});
}
/// Get the shapes of all the outputs returned by this program
shapes get_output_shapes() const
{
migraphx_shapes_t pout;
call(&migraphx_program_get_output_shapes, &pout, this->get_handle_ptr());
return shapes(pout, own{});
}
/// Run the program using the inputs passed in
arguments eval(const program_parameters& pparams) const
{
migraphx_arguments_t pout;
call(&migraphx_program_run, &pout, this->get_handle_ptr(), pparams.get_handle_ptr());
return arguments(pout, own{});
}
void print() const { call(&migraphx_program_print, this->get_handle_ptr()); }
program sort()
{
call(&migraphx_program_sort, this->get_handle_ptr());
return *this;
}
friend bool operator==(const program& px, const program& py)
{
bool pout;
call(&migraphx_program_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr());
return pout;
}
module get_main_module()
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu, this->share_handle()};
}
context experimental_get_context()
{
migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx, this->share_handle()};
}
module create_module(const std::string& name)
{
migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format
void set_file_format(const char* format)
{
call(&migraphx_file_options_set_file_format, this->get_handle_ptr(), format);
}
};
/// Load a saved migraphx program from a file
inline program load(const char* filename, const file_options& options)
{
return program(make<migraphx_program>(&migraphx_load, filename, options.get_handle_ptr()),
own{});
}
/// Load a saved migraphx program from a file
inline program load(const char* filename)
{
return program(
make<migraphx_program>(&migraphx_load, filename, migraphx::file_options{}.get_handle_ptr()),
own{});
}
/// Save a program to a file
inline void save(const program& p, const char* filename, const file_options& options)
{
call(&migraphx_save, p.get_handle_ptr(), filename, options.get_handle_ptr());
}
/// Save a program to a file
inline void save(const program& p, const char* filename)
{
call(&migraphx_save, p.get_handle_ptr(), filename, migraphx::file_options{}.get_handle_ptr());
}
/// Options for parsing onnx options
struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
{
call(&migraphx_onnx_options_set_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dim.data(),
dim.size());
}
/// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value)
{
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
}
/// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
}
};
/// Parse an onnx file into a migraphx program
inline program parse_onnx(const char* filename, const migraphx::onnx_options& options)
{
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, options.get_handle_ptr()),
own{});
}
/// Parse an onnx file into a migraphx program
inline program parse_onnx(const char* filename)
{
migraphx::onnx_options options;
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, options.get_handle_ptr()),
own{});
}
/// Parse a buffer of memory as an onnx file
inline program
parse_onnx_buffer(const void* data, size_t size, const migraphx::onnx_options& options)
{
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()),
own{});
}
/// Parse a buffer of memory as an onnx file
inline program parse_onnx_buffer(const void* data, size_t size)
{
migraphx::onnx_options options;
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()),
own{});
}
/// Parse a buffer of memory as an onnx file
inline program parse_onnx_buffer(const std::string& buffer, const migraphx::onnx_options& options)
{
return program(
make<migraphx_program>(
&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()),
own{});
}
/// Parse a buffer of memory as an onnx file
inline program parse_onnx_buffer(const std::string& buffer)
{
migraphx::onnx_options options;
return program(
make<migraphx_program>(
&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()),
own{});
}
/// Options for parsing tf options
struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
{
call(&migraphx_tf_options_set_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dim.data(),
dim.size());
}
/// Change data layout to NHWC (default is NCHW)
void set_nhwc(bool is_nhwc = true)
{
call(&migraphx_tf_options_set_nhwc, this->get_handle_ptr(), is_nhwc);
}
/// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value)
{
call(&migraphx_tf_options_set_default_dim_value, this->get_handle_ptr(), value);
}
/// Set output node names to return specific outputs from graph
void set_output_names(std::vector<const char*> names)
{
call(&migraphx_tf_options_set_output_names,
this->get_handle_ptr(),
names.data(),
names.size());
}
};
/// Parse a tf file into a migraphx program
inline program parse_tf(const char* filename, const migraphx::tf_options& options)
{
return program(make<migraphx_program>(&migraphx_parse_tf, filename, options.get_handle_ptr()),
own{});
}
/// Parse a tf file into a migraphx program
inline program parse_tf(const char* filename)
{
migraphx::tf_options options;
return program(make<migraphx_program>(&migraphx_parse_tf, filename, options.get_handle_ptr()),
own{});
}
struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name)
{
call(&migraphx_quantize_op_names_add, this->get_handle_ptr(), name.c_str());
}
};
/// Quantize program to use fp16
inline void quantize_fp16(const program& prog, const quantize_op_names& names)
{
call(&migraphx_quantize_fp16_with_op_names, prog.get_handle_ptr(), names.get_handle_ptr());
}
/// Quantize program to use fp16
inline void quantize_fp16(const program& prog)
{
call(&migraphx_quantize_fp16, prog.get_handle_ptr());
}
/// Options to be passed when quantizing for int8
struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
/// Add an operator that should be quantized
void add_op_name(const std::string& name)
{
call(&migraphx_quantize_int8_options_add_op_name, this->get_handle_ptr(), name.c_str());
}
/// Add calibrartion data to be used for quantizing
void add_calibration_data(const program_parameters& pp)
{
call(&migraphx_quantize_int8_options_add_calibration_data,
this->get_handle_ptr(),
pp.get_handle_ptr());
}
};
/// Quantize program to use int8
inline void
quantize_int8(const program& prog, const target& ptarget, const quantize_int8_options& options)
{
call(&migraphx_quantize_int8,
prog.get_handle_ptr(),
ptarget.get_handle_ptr(),
options.get_handle_ptr());
}
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
{
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
};
template <class T, class = require_interface<experimental_custom_op_base, T>>
void register_experimental_custom_op(T& obj)
{
experimental_custom_op op{obj};
op.register_op();
}
#ifndef DOXYGEN
} // namespace api
#endif
} // namespace migraphx
#endif
import api
def bad_param_error(msg):
return 'MIGRAPHX_THROW(migraphx_status_bad_param, "{}")'.format(msg)
api.error_type = 'migraphx_status'
api.success_type = 'migraphx_status_success'
api.try_wrap = 'migraphx::try_'
api.bad_param_error = bad_param_error
@api.cwrap('migraphx::shape::type_t')
def shape_type_wrap(p):
if p.returns:
p.add_param('migraphx_shape_datatype_t *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_shape_type(${result})']
else:
p.add_param('migraphx_shape_datatype_t')
p.read = 'migraphx::to_shape_type(${name})'
@api.cwrap('migraphx::compile_options')
def compile_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_compile_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_compile_options(${result})']
else:
p.add_param('migraphx_compile_options *')
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_onnx_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_onnx_options(${result})']
else:
p.add_param('migraphx_onnx_options *')
p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
def auto_handle(*args, **kwargs):
def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
*args, **kwargs)(f)
return with_handle
@auto_handle()
def shape(h):
h.constructor(
'create',
api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>'))
h.constructor(
'create_with_strides',
api.params(type='migraphx::shape::type_t',
lengths='std::vector<size_t>',
strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.method('lengths',
fname='lens',
returns='const std::vector<size_t>&',
const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('bytes', returns='size_t', const=True)
h.method('equal',
api.params(x='const migraphx::shape&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
@auto_handle()
def argument(h):
h.constructor('create',
api.params(shape='const migraphx::shape&', buffer='void*'))
h.method('shape',
fname='get_shape',
cpp_name='get_shape',
returns='const migraphx::shape&',
const=True)
h.method('buffer',
fname='data',
cpp_name='data',
returns='char*',
const=True)
h.method('equal',
api.params(x='const migraphx::argument&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
api.add_function('migraphx_argument_generate',
api.params(s='const migraphx::shape&', seed='size_t'),
fname='migraphx::generate_argument',
returns='migraphx::argument')
@auto_handle()
def target(h):
h.constructor('create',
api.params(name='const char*'),
fname='migraphx::get_target')
@api.handle('migraphx_program_parameter_shapes',
'std::unordered_map<std::string, migraphx::shape>')
def program_parameter_shapes(h):
h.method('size', returns='size_t')
h.method('get',
api.params(name='const char*'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape&')
h.method('names',
invoke='migraphx::get_names(${program_parameter_shapes})',
returns='std::vector<const char*>')
@api.handle('migraphx_program_parameters',
'std::unordered_map<std::string, migraphx::argument>')
def program_parameters(h):
h.constructor('create')
h.method('add',
api.params(name='const char*',
argument='const migraphx::argument&'),
invoke='${program_parameters}[${name}] = ${argument}')
@api.handle('migraphx_arguments', 'std::vector<migraphx::argument>')
def arguments(h):
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::argument&')
@api.handle('migraphx_shapes', 'std::vector<migraphx::shape>')
def shapes(h):
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape&')
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
def instruction(h):
pass
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
def instructions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
def modules(h):
h.constructor('create',
api.params(ptr='migraphx_module_t*', size='size_t'),
fname='migraphx::to_objptr_vector<migraphx::module*>')
@auto_handle(ref=True)
def module(h):
h.constructor('create', api.params(name='std::string'))
h.method('print', invoke='migraphx::print_module($@)', const=True)
h.method('add_instruction',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_instruction_with_mod_args',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>',
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_literal',
api.params(shape='const migraphx::shape&', buffer='const char*'),
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
@auto_handle()
def program(h):
h.constructor('create')
h.method('get_main_module', returns='migraphx::module*')
h.method('create_module',
api.params(name='const char*'),
returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
options='migraphx::compile_options'))
h.method('get_parameter_shapes',
returns='std::unordered_map<std::string, migraphx::shape>')
h.method('get_output_shapes',
invoke='migraphx::get_output_shapes($@)',
returns='std::vector<migraphx::shape>')
h.method('print', invoke='migraphx::print_program($@)', const=True)
h.method('sort')
h.method('run',
api.params(
params='std::unordered_map<std::string, migraphx::argument>'),
invoke='migraphx::run($@)',
returns='std::vector<migraphx::argument>')
h.method('equal',
api.params(x='const migraphx::program&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('experimental_get_context',
invoke='migraphx::get_context($@)',
const=True,
returns='migraphx::context')
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
api.add_function('migraphx_load',
api.params(name='const char*',
options='migraphx::file_options'),
fname='migraphx::load',
returns='migraphx::program')
api.add_function('migraphx_save',
api.params(p='migraphx::program&',
name='const char*',
options='migraphx::file_options'),
fname='migraphx::save')
@auto_handle()
def onnx_options(h):
h.constructor('create')
h.method(
'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
'set_default_dim_value',
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_default_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)',
)
@auto_handle()
def file_options(h):
h.constructor('create')
h.method('set_file_format',
api.params(format='const char*'),
invoke='migraphx::set_file_format($@)')
@auto_handle()
def compile_options(h):
h.constructor('create')
h.method('set_offload_copy',
api.params(value='bool'),
invoke='migraphx::set_offload_copy($@)')
h.method('set_fast_math',
api.params(value='bool'),
invoke='migraphx::set_fast_math($@)')
api.add_function('migraphx_parse_onnx',
api.params(name='const char*',
options='migraphx::onnx_options'),
fname='migraphx::parse_onnx',
returns='migraphx::program')
api.add_function('migraphx_parse_onnx_buffer',
api.params(data='const void*',
size='size_t',
options='migraphx::onnx_options'),
fname='migraphx::parse_onnx_buffer',
returns='migraphx::program')
@auto_handle()
def tf_options(h):
h.constructor('create')
h.method(
'set_nhwc',
api.params(is_nhwc='bool'),
invoke='migraphx::set_nhwc($@)',
)
h.method(
'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
'set_default_dim_value',
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_output_names',
api.params(names='std::vector<const char*>'),
invoke='migraphx::set_output_names($@)',
)
api.add_function('migraphx_parse_tf',
api.params(name='const char*',
options='migraphx::tf_options'),
fname='migraphx::parse_tf',
returns='migraphx::program')
@api.handle('migraphx_quantize_op_names', 'std::vector<std::string>')
def quantize_op_names(h):
h.constructor('create')
h.method('add', api.params(name='const char*'), fname='push_back')
api.add_function('migraphx_quantize_fp16_with_op_names',
api.params(prog='migraphx::program&',
name='std::vector<std::string>&'),
fname='migraphx::quantize_fp16_with_op_names')
api.add_function('migraphx_quantize_fp16',
api.params(prog='migraphx::program&'),
fname='migraphx::quantize_fp16')
@auto_handle()
def quantize_int8_options(h):
h.constructor('create')
h.method(
'add_op_name',
api.params(name='const char*'),
invoke='migraphx::add_op_name($@)',
)
h.method(
'add_calibration_data',
api.params(data='std::unordered_map<std::string, migraphx::argument>'),
invoke='migraphx::add_calibration_data($@)',
)
api.add_function('migraphx_quantize_int8',
api.params(prog='migraphx::program&',
target='migraphx::target',
options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap')
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
h.method('register', invoke='migraphx::register_custom_op($@)')
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta)
{
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
if(a->get_shape().type() != input_type)
{
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
}
}
auto op_res = m.insert_instruction(pos, op, a, b);
if(args.size() == 3)
{
if(not float_equal(beta.at<float>(0), 0.0) && args[2]->get_shape().elements() > 0)
{
auto out_lens = op_res->get_shape().lens();
auto c = args[2];
auto c_lens = c->get_shape().lens();
input_type = c->get_shape().type();
if(out_lens != c_lens)
{
c = m.insert_instruction(
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = m.add_literal(beta);
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
if(beta_c->get_shape().type() != input_type)
{
beta_c = m.insert_instruction(
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
}
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
}
return op_res;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
argument::argument(const shape& s) : m_shape(s)
{
auto buffer = make_shared_array<char>(s.bytes());
assign_buffer({[=]() mutable { return buffer.get(); }});
}
argument::argument(shape s, std::nullptr_t)
: m_shape(std::move(s)), m_data({[] { return nullptr; }})
{
}
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
void argument::assign_buffer(std::function<char*()> d)
{
const shape& s = m_shape;
if(s.type() != shape::tuple_type)
{
m_data = {std::move(d)};
return;
}
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
shapes[i] = ss;
i++;
}
else
{
for(auto&& child : ss.sub_shapes())
self(child);
}
})(s);
}
// Sort by type size
std::vector<std::size_t> order(shapes.size());
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
return shapes[i].type_size();
}));
// Compute offsets
std::unordered_map<std::size_t, std::size_t> offsets;
std::size_t offset = 0;
for(auto i : order)
{
offsets[i] = offset;
offset += shapes[i].bytes();
}
assert(offset == s.bytes());
std::size_t i = 0;
m_data = fix<data_t>([&](auto self, auto ss) {
data_t result;
if(ss.sub_shapes().empty())
{
auto n = offsets[i];
result = {[d, n]() mutable { return d() + n; }};
i++;
return result;
}
std::vector<data_t> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
result.sub = subs;
return result;
})(s);
}
std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes;
std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
return arg.get_shape();
});
return shapes;
}
argument::argument(const std::vector<argument>& args)
: m_shape(to_shapes(args)), m_data(data_t::from_args(args))
{
}
char* argument::data() const
{
assert(m_shape.type() != shape::tuple_type);
assert(not this->empty());
return m_data.get();
}
bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const
{
assert(s.element_space() <= this->get_shape().element_space());
return {s, this->m_data};
}
argument::data_t argument::data_t::share() const
{
data_t result;
if(this->get)
{
auto self = std::make_shared<data_t>(*this);
result.get = [self]() mutable { return self->get(); };
}
std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) {
return d.share();
});
return result;
}
argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
{
data_t result;
std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) {
return arg.m_data;
});
return result;
}
argument argument::copy() const
{
argument result{this->get_shape()};
auto* src = this->data();
std::copy(src, src + this->get_shape().bytes(), result.data());
return result;
}
argument argument::share() const { return {m_shape, m_data.share()}; }
std::vector<argument> argument::get_sub_objects() const
{
std::vector<argument> result;
assert(m_shape.sub_shapes().size() == m_data.sub.size());
std::transform(m_shape.sub_shapes().begin(),
m_shape.sub_shapes().end(),
m_data.sub.begin(),
std::back_inserter(result),
[](auto&& s, auto&& d) {
return argument{s, d};
});
return result;
}
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(program& p) const
void auto_contiguous::apply(module& m) const
{
for(auto ins : iterator_for(p))
std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(m))
{
auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false)))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return m.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
p.replace_instruction(ins, c);
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
}
}
}
......
#include <migraphx/common.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
if(s0 == s1)
return s0;
if(s0.size() > s1.size())
s0.swap(s1);
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
assert(not shapes.empty());
return transform_accumulate(shapes.begin() + 1,
shapes.end(),
shapes.front().lens(),
&compute_broadcasted_lens,
[](auto s) { return s.lens(); });
}
shape::type_t compute_common_type(shape::type_t t1, shape::type_t t2)
{
if(t1 == t2)
return t1;
shape::type_t result;
shape::visit(t1, [&](auto x) {
shape::visit(t2, [&](auto y) {
// Workaround broken warning on gcc 5
(void)x;
(void)y;
using type = std::common_type_t<decltype(x()), decltype(y())>;
result = shape::get_type<type>{};
});
});
return result;
}
shape::type_t compute_common_types(const std::vector<shape>& shapes)
{
assert(not shapes.empty());
return transform_accumulate(
shapes.begin() + 1, shapes.end(), shapes.front().type(), &compute_common_type, [&](auto s) {
return s.type();
});
}
shape common_shape(const std::vector<shape>& shapes)
{
if(shapes.empty())
return {};
return {compute_common_types(shapes), compute_common_lens(shapes)};
}
instruction_ref insert_common_op(module& m,
instruction_ref ins,
const operation& op,
std::vector<instruction_ref> inputs)
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
}
return input;
});
return m.insert_instruction(ins, op, inputs);
}
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
{
return insert_common_op(m, m.end(), op, std::move(inputs));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
assert(not srcs.empty());
tmp_dir td{"compile"};
auto params = flags;
params += " -I.";
auto out = output;
for(const auto& src : srcs)
{
fs::path full_path = td.path / src.path;
fs::path parent_path = full_path.parent_path();
fs::create_directories(parent_path);
write_buffer(full_path.string(), src.content.first, src.len());
if(src.path.extension().string() == ".cpp")
{
params += " " + src.path.filename().string();
if(out.empty())
out = src.path.stem().string() + out_ext;
}
}
params += " -o " + out;
if(not launcher.empty())
{
td.execute(launcher, compiler + " " + params);
}
else
{
td.execute(compiler, params);
}
auto out_path = td.path / out;
if(not fs::exists(out_path))
MIGRAPHX_THROW("Output file missing: " + out);
if(process)
out_path = process(out_path);
return read_buffer(out_path.string());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <algorithm>
#include <string>
#include <vector>
#include <functional>
#include <sstream>
#include <migraphx/errors.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using token = std::pair<const char*, const char*>;
using lexer = std::function<const char*(const char* start, const char* end)>;
template <class P>
auto lex_while(P p)
{
return [=](const char* start, const char* end) {
return std::find_if(start, end, [&](char c) { return not p(c); });
};
}
template <class P>
auto lex_if(P p)
{
return [=](const char* start, const char*) {
if(p(*start))
return start + 1;
return start;
};
}
std::vector<token> tokenize(const char* start, const char* end, const std::vector<lexer>& lexers)
{
std::vector<token> result;
while(start != end)
{
bool error = true;
for(const auto& l : lexers)
{
const auto* next = l(start, end);
if(next != start)
{
result.emplace_back(start, next);
start = next;
error = false;
break;
}
}
if(error)
{
MIGRAPHX_THROW("TOKENIZE: no token found!");
}
}
return result;
}
std::vector<token> json_tokenize(const std::string& s)
{
std::vector<lexer> lexers;
// Quote
lexers.push_back([](const char* start, const char* end) {
if(*start != '\"')
return start;
++start;
while((start != end) and (*start != '\"'))
{
if(*start == '\\')
start++;
start++;
}
return ++start;
});
// Line comments
lexers.push_back([](const char* start, const char* end) {
if(*start == '#')
start++;
else if((start + 1) < end and start[0] == '/' and start[1] == '/')
start += 2;
else
return start;
return std::find_if(start, end, [&](char c) { return c == '\n'; });
});
// Whitespace
lexers.push_back(lex_while(&isspace));
// Punctation
lexers.push_back(lex_if(&ispunct));
// Identifier/number
lexers.push_back(lex_while([](char c) {
return (isalnum(c) != 0 or contains({'_', '.', '+'}, c));
}));
return tokenize(s.data(), s.data() + s.length(), lexers);
}
std::string convert_to_json(const std::string& str)
{
auto tokens = json_tokenize(str);
std::stringstream ss;
for(auto& token : tokens)
{
std::string s(token.first, token.second);
if(starts_with(s, "#") or starts_with(s, "//"))
continue;
if(std::isalpha(s.front()) != 0 and
not contains({"null", "nan", "true", "false", "inf"}, s))
{
ss << "\"" << s << "\"";
}
else
{
ss << s;
}
}
return ss.str();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/iterator_for.hpp>
#include <map>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
cpp_generator::function&
cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g)
{
std::unordered_map<migraphx::instruction_ref, std::string> names;
std::stringstream ss;
auto return_ins = std::prev(m.end());
for(auto ins : iterator_for(m))
{
ss << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n";
if(ins->name() == "@param")
{
names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
}
else if(ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front();
}
else
{
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
}
ss << "return " << names.at(return_ins) << ";\n";
body = ss.str();
return *this;
}
cpp_generator::function& cpp_generator::function::set_types(const module& m)
{
return cpp_generator::function::set_types(m, [](auto s) { return shape::cpp_type(s.type()); });
}
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, parse(p.second)};
});
auto output_shapes = m.get_output_shapes();
assert(not output_shapes.empty());
this->return_type = parse(output_shapes.front());
return *this;
}
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl
{
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
cpp_generator::cpp_generator(cpp_generator&&) noexcept = default;
cpp_generator& cpp_generator::operator=(cpp_generator rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args)
{
auto v = op.to_value();
std::string code;
if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
code = attributes["point_op"].to<std::string>();
}
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
}
std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m)
{
function f;
auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
ins->get_literal().to_string() + ")";
std::vector<std::string> args;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
delim = ',';
}
impl->fs << ") {\n" << f.body << "\n}\n";
return name;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -4,65 +4,56 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(program& p) const
void dead_code_elimination::apply(module& m) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
if(ins == m.begin())
continue;
const auto i = std::prev(ins);
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity
// Skip instruction with empty shape as output unless its a builtin, undefined, identity, or
// allocate
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
not contains({"undefined", "identity", "allocate"}, i->name()))
continue;
assert(bidistance(p, i, last) > 0);
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf));
if(not m.has_instruction(leaf))
return;
if(leaf->outputs().empty())
{
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
assert(bidistance(p, last, leaf) < 0);
assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins);
p.move_instruction(leaf, p.end());
if(leaf->name() != "@param")
m.move_instruction(leaf, m.end());
for(auto arg : args)
self(arg);
}
})(i);
}
p.remove_instructions(std::next(last), p.end());
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2)
{
if(ins1 == ins2)
return false;
auto iter = ins2idom.find(ins2);
while(iter != ins2idom.end())
{
if(ins1 == iter->second)
return true;
assert(iter != ins2idom.find(iter->second));
iter = ins2idom.find(iter->second);
}
return false;
}
struct module_visitor
{
module* mm;
module& get_nodes() const { return *mm; }
const std::vector<instruction_ref>& get_children(instruction_ref ins) { return ins->inputs(); }
};
template <class Visitor>
dominator_info compute_dominator_generic(Visitor v)
{
dominator_info info;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> instr2_doms;
for(instruction_ref ins : iterator_for(v.get_nodes()))
{
const std::vector<instruction_ref>& children = v.get_children(ins);
if(children.size() == 1)
{
info.ins2idom[ins] = children.front();
instr2_doms[ins].insert(children.front());
}
else if(children.size() > 1)
{
auto&& doms = instr2_doms[ins];
doms = instr2_doms[children.front()];
std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) {
auto&& child_doms = instr2_doms[child];
erase_if(doms, [&](auto x) { return not contains(child_doms, x); });
});
auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) {
return std::none_of(doms.begin(), doms.end(), [&](auto dom2) {
if(dom1 == dom2)
return false;
return info.strictly_dominate(dom1, dom2);
});
});
if(iter != doms.end())
info.ins2idom[ins] = *iter;
}
instr2_doms[ins].insert(ins);
}
return info;
}
dominator_info compute_dominator(module& m)
{
return compute_dominator_generic(module_visitor{&m});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
add_executable(driver main.cpp verify.cpp perf.cpp)
add_executable(driver
main.cpp
verify.cpp
perf.cpp
resnet50.cpp
inceptionv3.cpp
alexnet.cpp
marker_roctx.cpp
)
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility
add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
)
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_cpu migraphx_onnx migraphx_tf)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(driver migraphx_gpu)
target_compile_definitions(driver PRIVATE -DHAVE_GPU)
endif()
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
rocm_install_targets(
TARGETS driver
)
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto m0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx0 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
auto mx1 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
auto mx2 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
auto mx3 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
auto mx8 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
migraphx::op::convolution convolution16;
convolution16.padding = {2, 2};
convolution16.stride = {4, 4};
convolution16.dilation = {1, 1};
convolution16.group = 1;
auto mx16 = mm->add_instruction(convolution16, m0, mx15);
migraphx::op::broadcast broadcast17;
broadcast17.axis = 1;
broadcast17.broadcast_lens = {batch, 64, 55, 55};
auto mx17 = mm->add_instruction(broadcast17, mx14);
migraphx::op::add add18;
auto mx18 = mm->add_instruction(add18, mx16, mx17);
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
auto mx20 = mm->add_instruction(pooling20, mx19);
migraphx::op::convolution convolution21;
convolution21.padding = {2, 2};
convolution21.stride = {1, 1};
convolution21.dilation = {1, 1};
convolution21.group = 1;
auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
migraphx::op::broadcast broadcast22;
broadcast22.axis = 1;
broadcast22.broadcast_lens = {batch, 192, 27, 27};
auto mx22 = mm->add_instruction(broadcast22, mx12);
migraphx::op::add add23;
auto mx23 = mm->add_instruction(add23, mx21, mx22);
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
auto mx25 = mm->add_instruction(pooling25, mx24);
migraphx::op::convolution convolution26;
convolution26.padding = {1, 1};
convolution26.stride = {1, 1};
convolution26.dilation = {1, 1};
convolution26.group = 1;
auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
migraphx::op::broadcast broadcast27;
broadcast27.axis = 1;
broadcast27.broadcast_lens = {batch, 384, 13, 13};
auto mx27 = mm->add_instruction(broadcast27, mx10);
migraphx::op::add add28;
auto mx28 = mm->add_instruction(add28, mx26, mx27);
migraphx::op::relu relu29;
auto mx29 = mm->add_instruction(relu29, mx28);
migraphx::op::convolution convolution30;
convolution30.padding = {1, 1};
convolution30.stride = {1, 1};
convolution30.dilation = {1, 1};
convolution30.group = 1;
auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
migraphx::op::broadcast broadcast31;
broadcast31.axis = 1;
broadcast31.broadcast_lens = {batch, 256, 13, 13};
auto mx31 = mm->add_instruction(broadcast31, mx8);
migraphx::op::add add32;
auto mx32 = mm->add_instruction(add32, mx30, mx31);
migraphx::op::relu relu33;
auto mx33 = mm->add_instruction(relu33, mx32);
migraphx::op::convolution convolution34;
convolution34.padding = {1, 1};
convolution34.stride = {1, 1};
convolution34.dilation = {1, 1};
convolution34.group = 1;
auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
migraphx::op::broadcast broadcast35;
broadcast35.axis = 1;
broadcast35.broadcast_lens = {batch, 256, 13, 13};
auto mx35 = mm->add_instruction(broadcast35, mx6);
migraphx::op::add add36;
auto mx36 = mm->add_instruction(add36, mx34, mx35);
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
auto mx38 = mm->add_instruction(pooling38, mx37);
migraphx::op::flatten flatten39;
flatten39.axis = 1;
auto mx39 = mm->add_instruction(flatten39, mx38);
migraphx::op::identity identity40;
auto mx40 = mm->add_instruction(identity40, mx39);
migraphx::op::transpose transpose41;
transpose41.dims = {1, 0};
auto mx41 = mm->add_instruction(transpose41, mx5);
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
auto mx45 = mm->add_instruction(identity45, mx44);
migraphx::op::transpose transpose46;
transpose46.dims = {1, 0};
auto mx46 = mm->add_instruction(transpose46, mx3);
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
transpose50.dims = {1, 0};
auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
......@@ -17,6 +17,7 @@
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
namespace driver {
......@@ -132,10 +133,22 @@ struct argument_parser
return to_string_range(x);
}
template <class T>
auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x))
{
return to_string(x);
}
template <class T>
std::string as_string_value(rank<0>, const T&)
{
throw std::runtime_error("Can't convert to string");
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
return as_string_value(rank<1>{}, x);
}
template <class T, class... Fs>
......@@ -148,10 +161,11 @@ struct argument_parser
return false;
}});
argument& arg = arguments.back();
arg.type = type_name<T>::apply();
arg.default_value = as_string_value(x);
argument& arg = arguments.back();
arg.type = type_name<T>::apply();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x);
}
template <class... Fs>
......@@ -247,6 +261,11 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.metavar = metavar; };
}
MIGRAPHX_DRIVER_STATIC auto type(const std::string& type)
{
return [=](auto&, auto& arg) { arg.type = type; };
}
template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{
......
......@@ -17,6 +17,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
// NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
return m;
}
......@@ -64,7 +65,7 @@ int auto_register_command()
template <class T>
struct command
{
static int static_register;
static const int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
......@@ -77,7 +78,7 @@ struct command
#endif
template <class T>
int command<T>::static_register = auto_register_command<T>(); // NOLINT
const int command<T>::static_register = auto_register_command<T>(); // NOLINT
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment