Commit 712f6134 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch and resolve merge conflicts

parents 4a39a0f7 b20e3d4d
......@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp
env.cpp
file_buffer.cpp
fuse_pointwise.cpp
generate.cpp
inline_module.cpp
insert_pad.cpp
......@@ -130,9 +131,11 @@ register_migraphx_ops(
multibroadcast
multinomial
neg
nonmaxsuppression
nonzero
outline
pad
pointwise
pooling
pow
prefix_scan_sum
......@@ -153,6 +156,7 @@ register_migraphx_ops(
rnn_last_cell_output
rnn_last_hs_output
rnn_var_sl_last_output
roialign
round
rsqrt
scalar
......@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests
......@@ -235,6 +242,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx::
DEPENDS
Threads
${PACKAGE_DEPENDS}
)
......
......@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0)
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
......@@ -13,6 +13,7 @@
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths,
size_t lengths_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
......@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides,
size_t strides_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
......@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
return migraphx::try_([&] { destroy((argument)); });
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{
return migraphx::try_([&] { destroy((target)); });
auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] { destroy((program_parameter_shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr)
......@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out);
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{
return migraphx::try_([&] { destroy((program_parameters)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>());
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer");
......@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object);
});
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{
return migraphx::try_([&] { destroy((arguments)); });
auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{
return migraphx::try_([&] { destroy((shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr)
......@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{
return migraphx::try_([&] { destroy((operation)); });
auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes)
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...)
{
return migraphx::try_([&] {
va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes))));
allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
});
va_end(vlist);
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr)
......@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
return migraphx::try_([&] { destroy((onnx_options)); });
auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
return migraphx::try_([&] { destroy((file_options)); });
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
return migraphx::try_([&] { destroy((compile_options)); });
auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
......@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size,
migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
return migraphx::try_([&] { destroy((tf_options)); });
auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
......@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
return migraphx::try_([&] { destroy((quantize_op_names)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
return migraphx::try_([&] { destroy((quantize_int8_options)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
......@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr)
......@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
});
return api_error_result;
}
......@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes);
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
......@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
std::vector<size_t> strides() const
......@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout;
size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size);
return {pout, pout + pout_size};
}
migraphx_shape_datatype_t type() const
......@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout);
return {pout};
}
char* data() const
......@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
return argument(
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{});
return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
own{}};
}
friend bool operator==(const argument& px, const argument& py)
......@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout);
return {pout};
}
std::vector<const char*> names() const
......@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout);
return {pout};
}
struct iterator_read
......@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout);
return {pout};
}
};
};
......@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout);
return {pout};
}
struct iterator_read
......@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout);
return {pout};
}
};
};
......@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes);
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data());
return {out_name.data()};
}
};
......
......@@ -212,7 +212,9 @@ def program(h):
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*', attributes='const char*'),
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
......
......@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const
return result;
}
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
......@@ -26,16 +27,18 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
{
names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
continue;
}
if(ins->name() == "@return")
else if(ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front();
}
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
else
{
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
}
ss << "return " << names.at(return_ins) << ";\n";
body = ss.str();
......@@ -49,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
......@@ -61,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this;
}
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl
{
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -81,38 +104,54 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args)
{
auto v = op.to_value();
return interpolate_string(op.attributes()["point_op"].to<std::string>(),
[&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
std::string code;
if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
code = attributes["point_op"].to<std::string>();
}
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
}
std::string cpp_generator::str() const { return impl->fs.str(); }
......@@ -120,7 +159,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m)
{
function f;
f.set_name(m.name()).set_types(m).set_body(
auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
......@@ -130,7 +174,6 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
return this->generate_point_op(ins->get_operator(), args);
});
return f;
......@@ -139,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
......
......@@ -6,6 +6,7 @@ add_executable(driver
resnet50.cpp
inceptionv3.cpp
alexnet.cpp
marker_roctx.cpp
)
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility
......
......@@ -17,6 +17,7 @@
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
namespace driver {
......@@ -106,10 +107,22 @@ struct argument_parser
return to_string_range(x);
}
template <class T>
auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x))
{
return to_string(x);
}
template <class T>
std::string as_string_value(rank<0>, const T&)
{
throw std::runtime_error("Can't convert to string");
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
return as_string_value(rank<1>{}, x);
}
template <class T, class... Fs>
......@@ -122,10 +135,11 @@ struct argument_parser
return false;
}});
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.default_value = as_string_value(x);
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x);
}
template <class... Fs>
......
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
#include "verify.hpp"
#include "precision.hpp"
#include "perf.hpp"
#include "models.hpp"
#include "marker_roctx.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
......@@ -287,14 +289,12 @@ struct compiler_target
struct compiler
{
static const int q_fp16 = 1;
static const int q_int8 = 2;
loader l;
program_params parameters;
compiler_target ct;
bool offload_copy = false;
bool fast_math = true;
int quantize = 0;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
std::vector<std::string> fill0;
std::vector<std::string> fill1;
......@@ -311,8 +311,8 @@ struct compiler
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
}
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); }
......@@ -324,11 +324,11 @@ struct compiler
if(p.is_compiled())
return p;
auto t = ct.get_target();
if(quantize == q_fp16)
if(quantize == precision::fp16)
{
quantize_fp16(p);
}
else if(quantize == q_int8)
else if(quantize == precision::int8)
{
quantize_int8(p, t, {params(p)});
}
......@@ -376,6 +376,7 @@ struct verify : command<verify>
bool reduce = false;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap)
{
l.parse(ap);
......@@ -395,6 +396,7 @@ struct verify : command<verify>
ap.help("Verify each instruction"),
ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
}
void run()
......@@ -411,15 +413,15 @@ struct verify : command<verify>
if(per_instruction)
{
verify_instructions(p, t, options, tolerance);
verify_instructions(p, t, options, quantize, tolerance);
}
else if(reduce)
{
verify_reduced_program(p, t, options, m, tolerance);
verify_reduced_program(p, t, options, quantize, m, tolerance);
}
else
{
verify_program(l.file, p, t, options, m, tolerance);
verify_program(l.file, p, t, options, quantize, m, tolerance);
}
}
};
......@@ -479,7 +481,24 @@ struct perf : command<perf>
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
p.perf_report(std::cout, n, m, c.l.batch);
}
};
struct roctx : command<roctx>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
auto rtx = create_marker_roctx();
p.mark(m, std::move(rtx));
}
};
......
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
class marker_roctx
{
std::function<void(const char*)> sym_roctx_mark;
std::function<uint64_t(const char*)> sym_roctx_range_start;
std::function<void(uint64_t)> sym_roctx_range_stop;
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id;
public:
marker_roctx()
{
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
sym_roctx_mark("rocTX marker created.");
}
void mark_start(instruction_ref ins_ref)
{
std::string text = "Marker start: " + ins_ref->name();
sym_roctx_range_push(text.c_str());
}
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
};
marker create_marker_roctx() { return marker_roctx(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#include <migraphx/marker.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
marker create_marker_roctx();
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
enum class precision
{
fp32,
fp16,
int8
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
......@@ -6,6 +6,7 @@
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
namespace migraphx {
namespace driver {
......@@ -19,9 +20,16 @@ std::vector<argument> run_ref(program p, const parameter_map& inputs)
return out;
}
std::vector<argument>
run_target(program p, const target& t, const compile_options& options, const parameter_map& inputs)
std::vector<argument> run_target(program p,
const target& t,
const compile_options& options,
precision quantize,
const parameter_map& inputs)
{
if(quantize == precision::fp16)
{
quantize_fp16(p);
}
p.compile(t, options);
parameter_map m;
......@@ -43,24 +51,24 @@ void verify_program(const std::string& name,
const program& p,
const target& t,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
{
auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, inputs);
auto y = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size();
for(std::size_t i = 0; i < output_num; ++i)
{
verify_args(name, x[i], y[i], tolerance);
}
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const program& prog,
const target& t,
compile_options options,
precision quantize,
double tolerance)
{
const auto* mm_prog = prog.get_main_module();
......@@ -92,7 +100,8 @@ void verify_instructions(const program& prog,
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl;
verify_program(ins.name(), p, t, options, create_param_map(p, false), tolerance);
verify_program(
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
}
catch(...)
{
......@@ -106,6 +115,7 @@ void verify_reduced(program p,
int n,
const target& t,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
{
......@@ -114,12 +124,13 @@ void verify_reduced(program p,
mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, inputs, tolerance);
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
}
void verify_reduced_program(const program& p,
const target& t,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
{
......@@ -127,7 +138,7 @@ void verify_reduced_program(const program& p,
auto n = std::distance(mm->begin(), mm->end());
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(p, i, t, options, inputs, tolerance);
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
}
}
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#include "precision.hpp"
#include <migraphx/program.hpp>
namespace migraphx {
......@@ -11,15 +12,18 @@ void verify_program(const std::string& name,
const program& p,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
double tolerance = 100);
void verify_instructions(const program& prog,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
double tolerance = 80);
void verify_reduced_program(const program& p,
const target& t,
compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {},
double tolerance = 80);
......
......@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
dlerror();
void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name);
......
......@@ -11,11 +11,13 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
static bool try_compute_shape(instruction_ref ins,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mods)
{
try
{
shape new_shape = ins->get_operator().compute_shape(inputs);
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
......@@ -45,7 +47,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return (arg == ins) ? new_shape : arg->get_shape();
});
if(!try_compute_shape(output, input_shapes))
if(!try_compute_shape(output, input_shapes, mods))
{
return false;
}
......@@ -59,10 +61,12 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return true;
}
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
static bool try_compute_shape(instruction_ref ins,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{
auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs);
return try_compute_shape(ins, inputs, mods);
}
void eliminate_contiguous::apply(module& p) const
......@@ -82,7 +86,7 @@ void eliminate_contiguous::apply(module& p) const
auto new_args = args;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args))
if(try_compute_shape(ins, new_args, ins->module_inputs()))
{
instruction::replace_argument(ins, arg, prev);
}
......
......@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {
"convert", "get_tuple_elem", "if", "loop"};
"convert", "get_tuple_elem", "if", "loop", "roialign"};
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
......
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins)
{
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
return r;
}
static void create_pointwise_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
pointwise_inputs.push_back(input);
param_map[input] =
pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()});
i++;
}
else
{
param_map[input] = pm->add_literal(scalar);
}
}
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
{
assert(contains(output->inputs(), ins));
module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm->end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
assert(pm->get_parameter_names().size() == ins->inputs().size());
assert(xm->get_parameter_names().size() == output->inputs().size());
std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
assert(param != pm->end());
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i));
assert(param != xm->end());
if(input == ins)
{
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate paramter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
return inputs;
}
static bool find_pointwise_modules(module& m)
{
bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty() and ins != last)
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto input = *it;
auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
changed = true;
}
return changed;
}
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment