Commit d49d4f66 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents 7d248d46 4d82d761
......@@ -80,10 +80,10 @@ jobs:
uses: pat-s/always-upload-cache@v2.1.3
with:
path: cppcheck-cache
key: cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }}
key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
restore-keys: |
cppcheck-cache-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache-
cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }}
cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-
- name: Build the Docker image
run: docker build . --file hip-clang.docker --tag migraphx
......
......@@ -190,6 +190,7 @@ rocm_enable_cppcheck(
shadowVariable
unsafeClassDivZero
definePrefix:*test/include/test.hpp
ctuOneDefinitionRuleViolation:*test/*
useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp
......
......@@ -154,7 +154,7 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[((?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)*) (\w) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<pattern><![CDATA[((?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)*) (\w+) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
......@@ -163,7 +163,7 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[auto (\w) ; \1 = static_cast < (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<pattern><![CDATA[auto (\w+) ; \1 = static_cast < (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]></pattern>
<message>
<id>RedundantCast</id>
<severity>style</severity>
......
......@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp
env.cpp
file_buffer.cpp
fuse_pointwise.cpp
generate.cpp
inline_module.cpp
insert_pad.cpp
......@@ -134,6 +135,7 @@ register_migraphx_ops(
nonzero
outline
pad
pointwise
pooling
pow
prefix_scan_sum
......@@ -199,6 +201,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests
......@@ -260,6 +265,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx::
DEPENDS
Threads
${PACKAGE_DEPENDS}
)
......
......@@ -13,6 +13,7 @@
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -155,18 +156,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -368,7 +381,8 @@ struct migraphx_quantize_int8_options
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
auto api_error_result = migraphx::try_([&] { destroy((shape)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -376,13 +390,14 @@ extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
size_t* lengths,
size_t lengths_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)),
(std::vector<size_t>(lengths, lengths + lengths_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
......@@ -392,7 +407,7 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
size_t* strides,
size_t strides_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(lengths == nullptr and lengths_size != 0)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer");
if(strides == nullptr and strides_size != 0)
......@@ -402,21 +417,23 @@ extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t*
(std::vector<size_t>(lengths, lengths + lengths_size)),
(std::vector<size_t>(strides, strides + strides_size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -425,12 +442,13 @@ migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr or out_size == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
......@@ -439,127 +457,141 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
*out = api_result.data();
*out_size = api_result.size();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = migraphx::to_shape_type((shape->object).type());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).bytes();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((shape->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
return migraphx::try_([&] { destroy((argument)); });
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(
allocate<migraphx::argument>((shape->object), (buffer)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((argument->object).get_shape()));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
*out = (argument->object).data();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(argument == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((argument->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_argument_t>(migraphx::generate_argument((s->object), (seed)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target)
{
return migraphx::try_([&] { destroy((target)); });
auto api_error_result = migraphx::try_([&] { destroy((target)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*target = object_cast<migraphx_target_t>(
allocate<migraphx::target>(migraphx::get_target((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] { destroy((program_parameter_shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameter_shapes_size(size_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out = (program_parameter_shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -567,19 +599,20 @@ migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameter_shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameter_shapes: Null pointer");
*out =
object_cast<const_migraphx_shape_t>(&((program_parameter_shapes->object).at((name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(program_parameter_shapes == nullptr)
......@@ -588,21 +621,24 @@ extern "C" migraphx_status migraphx_program_parameter_shapes_names(
auto&& api_result = migraphx::get_names((program_parameter_shapes->object));
std::copy(api_result.begin(), api_result.end(), out);
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters)
{
return migraphx::try_([&] { destroy((program_parameters)); });
auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*program_parameters = object_cast<migraphx_program_parameters_t>(
allocate<std::unordered_map<std::string, migraphx::argument>>());
});
return api_error_result;
}
extern "C" migraphx_status
......@@ -610,7 +646,7 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
const char* name,
const_migraphx_argument_t argument)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program_parameters == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter program_parameters: Null pointer");
......@@ -618,85 +654,95 @@ migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer");
(program_parameters->object)[(name)] = (argument->object);
});
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments)
{
return migraphx::try_([&] { destroy((arguments)); });
auto api_error_result = migraphx::try_([&] { destroy((arguments)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = (arguments->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(arguments == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer");
*out = object_cast<const_migraphx_argument_t>(&((arguments->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes)
{
return migraphx::try_([&] { destroy((shapes)); });
auto api_error_result = migraphx::try_([&] { destroy((shapes)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = (shapes->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(shapes == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer");
*out = object_cast<const_migraphx_shape_t>(&((shapes->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
auto api_error_result = migraphx::try_([&] { destroy((program)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr)
......@@ -705,91 +751,105 @@ extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(program->object).compile((target->object), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out =
allocate<migraphx_program_parameter_shapes_t>((program->object).get_parameter_shapes());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_shapes_t>(migraphx::get_output_shapes((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print_program((program->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(params == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
*out = allocate<migraphx_arguments_t>(migraphx::run((program->object), (params->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((program->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{
return migraphx::try_([&] { destroy((operation)); });
auto api_error_result = migraphx::try_([&] { destroy((operation)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_create(migraphx_operation_t* operation, const char* name, const char* attributes)
extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
...)
{
return migraphx::try_([&] {
va_list vlist;
va_start(vlist, attributes);
auto api_error_result = migraphx::try_([&] {
*operation = object_cast<migraphx_operation_t>(
allocate<migraphx::operation>(migraphx::create_op((name), (attributes))));
allocate<migraphx::operation>(migraphx::create_op((name), (attributes), (vlist))));
});
va_end(vlist);
return api_error_result;
}
extern "C" migraphx_status
migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(operation == nullptr)
......@@ -798,46 +858,51 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::save((p->object), (name), (options->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
return migraphx::try_([&] { destroy((onnx_options)); });
auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -845,96 +910,107 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
return migraphx::try_([&] { destroy((file_options)); });
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
return migraphx::try_([&] { destroy((compile_options)); });
auto api_error_result = migraphx::try_([&] { destroy((compile_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
......@@ -942,40 +1018,44 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size,
migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options)
{
return migraphx::try_([&] { destroy((tf_options)); });
auto api_error_result = migraphx::try_([&] { destroy((tf_options)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*tf_options = object_cast<migraphx_tf_options_t>(allocate<migraphx::tf_options>());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_nhwc((tf_options->object), (is_nhwc));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(dims == nullptr and dims_size != 0)
......@@ -983,23 +1063,25 @@ extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape(
migraphx::set_input_parameter_shape(
(tf_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
migraphx::set_default_dim_value((tf_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options,
const char** names,
size_t names_size)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(tf_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer");
if(names == nullptr and names_size != 0)
......@@ -1007,96 +1089,106 @@ extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_opti
migraphx::set_output_names((tf_options->object),
(std::vector<const char*>(names, names + names_size)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_tf((name), (options->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
return migraphx::try_([&] { destroy((quantize_op_names)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_op_names =
object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_op_names == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_op_names: Null pointer");
(quantize_op_names->object).push_back((name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
migraphx::quantize_fp16_with_op_names((prog->object), (name->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
migraphx::quantize_fp16((prog->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
return migraphx::try_([&] { destroy((quantize_int8_options)); });
auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
*quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
allocate<migraphx::quantize_int8_options>());
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
migraphx::add_op_name((quantize_int8_options->object), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(quantize_int8_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter quantize_int8_options: Null pointer");
......@@ -1104,13 +1196,14 @@ extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
migraphx::add_calibration_data((quantize_int8_options->object), (data->object));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options)
{
return migraphx::try_([&] {
auto api_error_result = migraphx::try_([&] {
if(prog == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
if(target == nullptr)
......@@ -1119,4 +1212,5 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object));
});
return api_error_result;
}
......@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes);
const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
......@@ -599,9 +599,10 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes);
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
......
......@@ -212,7 +212,9 @@ def program(h):
@auto_handle()
def operation(h):
h.constructor('create',
api.params(name='const char*', attributes='const char*'),
api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op')
h.method('name', returns='std::string')
......
......@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const
return result;
}
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -6,6 +6,7 @@ add_executable(driver
resnet50.cpp
inceptionv3.cpp
alexnet.cpp
marker_roctx.cpp
)
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility
......
......@@ -3,6 +3,7 @@
#include "verify.hpp"
#include "perf.hpp"
#include "models.hpp"
#include "marker_roctx.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
......@@ -483,6 +484,23 @@ struct perf : command<perf>
}
};
struct roctx : command<roctx>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
auto rtx = create_marker_roctx();
p.mark(m, std::move(rtx));
}
};
struct op : command<op>
{
bool show_ops = false;
......
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
class marker_roctx
{
std::function<void(const char*)> sym_roctx_mark;
std::function<uint64_t(const char*)> sym_roctx_range_start;
std::function<void(uint64_t)> sym_roctx_range_stop;
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id;
public:
marker_roctx()
{
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
sym_roctx_mark("rocTX marker created.");
}
void mark_start(instruction_ref ins_ref)
{
std::string text = "Marker start: " + ins_ref->name();
sym_roctx_range_push(text.c_str());
}
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
};
marker create_marker_roctx() { return marker_roctx(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#include <migraphx/marker.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
marker create_marker_roctx();
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
......@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
dlerror();
void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name);
......
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins)
{
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
return r;
}
static void create_pointwise_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
auto* pm = mpm.create_module("pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
pointwise_inputs.push_back(input);
param_map[input] = pm->add_parameter("x" + std::to_string(param_map.size()),
shape{input->get_shape().type()});
}
else
{
param_map[input] = pm->add_literal(scalar);
}
}
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
{
module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm->end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i));
if(input == ins)
{
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate paramter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
return inputs;
}
static bool find_pointwise_modules(module& m)
{
bool changed = false;
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty())
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true;
}
return changed;
}
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -67,6 +67,9 @@ struct argument : raw_data<argument>
std::vector<argument> get_sub_objects() const;
/// Return the ith element
argument element(std::size_t i) const;
private:
void assign_buffer(std::function<char*()> d);
struct data_t
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct fuse_pointwise
{
std::string name() const { return "fuse_pointwise"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
......@@ -152,30 +152,4 @@ struct instruction
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(&*x);
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif
......@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const migraphx::instruction_ref& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return migraphx::as_address(x) == migraphx::as_address(y);
}
};
} // namespace std
#endif
#ifndef MIGRAPHX_GUARD_MARKER_HPP
#define MIGRAPHX_GUARD_MARKER_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// Marker is an interface to general marking functions, such as rocTX markers.
#else
/*
* Type-erased interface for:
*
* struct marker
* {
* void mark_start(instruction_ref ins_ref) ;
* void mark_start(const program& prog) ;
* void mark_stop(instruction_ref ins) ;
* void mark_stop(const program& prog) ;
* };
*
*/
struct marker
{
// Constructors
marker() = default;
template <typename PrivateDetailTypeErasedT>
marker(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
marker& operator=(PrivateDetailTypeErasedT value)
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
marker rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void mark_start(instruction_ref ins_ref)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(ins_ref);
}
void mark_start(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(prog);
}
void mark_stop(instruction_ref ins)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(ins);
}
void mark_stop(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(prog);
}
friend bool is_shared(const marker& private_detail_x, const marker& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void mark_start(instruction_ref ins_ref) = 0;
virtual void mark_start(const program& prog) = 0;
virtual void mark_stop(instruction_ref ins) = 0;
virtual void mark_stop(const program& prog) = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void mark_start(instruction_ref ins_ref) override
{
private_detail_te_value.mark_start(ins_ref);
}
void mark_start(const program& prog) override { private_detail_te_value.mark_start(prog); }
void mark_stop(instruction_ref ins) override { private_detail_te_value.mark_stop(ins); }
void mark_stop(const program& prog) override { private_detail_te_value.mark_stop(prog); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(marker& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const marker& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment