Unverified Commit 23cb7917 authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into blas_tuning

parents b5fcc0bc ea32ca70
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp> #include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -43,7 +44,7 @@ namespace migraphx { ...@@ -43,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
}
void set_default_loop_iterations(onnx_options& options, int64_t value) void set_default_loop_iterations(onnx_options& options, int64_t value)
{ {
options.max_loop_iterations = value; options.max_loop_iterations = value;
...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_dyn_input_parameter_shape(onnx_options& options,
const char* name,
std::vector<shape::dynamic_dimension> dyn_dims)
{
options.map_dyn_input_dims[std::string(name)] = std::move(dyn_dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims) void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{ {
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& ...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result; return result;
} }
template <class T>
std::set<T> make_set(const T* x, std::size_t n)
{
return {x, x + n};
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{ {
if(names.empty()) if(names.empty())
...@@ -346,7 +365,10 @@ const Target* object_cast(const U* x) ...@@ -346,7 +365,10 @@ const Target* object_cast(const U* x)
template <class T, class... Ts, class Target = std::remove_pointer_t<T>> template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs) Target* allocate(Ts&&... xs)
{ {
return new Target(std::forward<Ts>(xs)...); // NOLINT if constexpr(std::is_aggregate<Target>{})
return new Target{std::forward<Ts>(xs)...}; // NOLINT
else
return new Target(std::forward<Ts>(xs)...); // NOLINT
} }
template <class T> template <class T>
...@@ -409,6 +431,39 @@ struct manage_generic_ptr ...@@ -409,6 +431,39 @@ struct manage_generic_ptr
D deleter = nullptr; D deleter = nullptr;
}; };
extern "C" struct migraphx_optimals;
struct migraphx_optimals
{
template <class... Ts>
migraphx_optimals(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::set<size_t> object;
};
extern "C" struct migraphx_dynamic_dimension;
struct migraphx_dynamic_dimension
{
template <class... Ts>
migraphx_dynamic_dimension(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::shape::dynamic_dimension object;
};
extern "C" struct migraphx_dynamic_dimensions;
struct migraphx_dynamic_dimensions
{
template <class... Ts>
migraphx_dynamic_dimensions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::shape::dynamic_dimension> object;
};
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
struct migraphx_shape struct migraphx_shape
{ {
...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op ...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op
} }
}; };
extern "C" migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] { destroy((optimals)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*optimals = object_cast<migraphx_optimals_t>(
allocate<std::set<size_t>>(migraphx::make_set<size_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimension)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output,
const_migraphx_dynamic_dimension_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] {
if(optimals == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter optimals: Null pointer");
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max), (optimals->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_is_fixed(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
*out = (dynamic_dimension->object).is_fixed();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((dynamic_dimension->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimensions)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
const_migraphx_dynamic_dimensions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const const_migraphx_dynamic_dimension_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimensions = object_cast<migraphx_dynamic_dimensions_t>(
allocate<std::vector<migraphx::shape::dynamic_dimension>>(
migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = (dynamic_dimensions->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = object_cast<const_migraphx_dynamic_dimension_t>(
&((dynamic_dimensions->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, ...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), (dims->object)));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_dynamic_dimensions_t>((shape->object).dyn_dims());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape) const_migraphx_shape_t shape)
{ {
...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap ...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).ndim();
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{ {
...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha ...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).dynamic();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i) extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s ...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(allocate<migraphx::argument>((shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
...@@ -1176,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions ...@@ -1176,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
} }
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr, const const_migraphx_instruction_t* ptr,
size_t size) size_t size)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( ...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_dyn_input_parameter_shape((onnx_options->object), (name), (dims->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{ {
...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options ...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options,
const_migraphx_dynamic_dimension_t dd)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dd == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dd: Null pointer");
migraphx::set_default_dyn_dim_value((onnx_options->object), (dd->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value) int64_t value)
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...@@ -66,6 +69,15 @@ typedef enum ...@@ -66,6 +69,15 @@ typedef enum
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
typedef struct migraphx_optimals* migraphx_optimals_t;
typedef const struct migraphx_optimals* const_migraphx_optimals_t;
typedef struct migraphx_dynamic_dimension* migraphx_dynamic_dimension_t;
typedef const struct migraphx_dynamic_dimension* const_migraphx_dynamic_dimension_t;
typedef struct migraphx_dynamic_dimensions* migraphx_dynamic_dimensions_t;
typedef const struct migraphx_dynamic_dimensions* const_migraphx_dynamic_dimensions_t;
typedef struct migraphx_shape* migraphx_shape_t; typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t; typedef const struct migraphx_shape* const_migraphx_shape_t;
...@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void ...@@ -157,360 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals,
const size_t* ptr,
size_t size);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_assign_to(
migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_is_fixed(
bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const const_migraphx_dynamic_dimension_t* ptr,
size_t size);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size);
migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
migraphx_shape_datatype_t type, const_migraphx_shape_t input);
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type); migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_datatype_t type,
size_t* lengths,
size_t lengths_size,
size_t* strides,
size_t strides_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_datatype_t type);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_lengths(const size_t** out,
size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_strides(const size_t** out,
size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); const_migraphx_shape_t shape);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_elements(size_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
const_migraphx_argument_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_equal(bool* out,
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); const_migraphx_shape_t shape,
const_migraphx_shape_t x);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
const_migraphx_argument_t argument);
migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_index(size_t* out,
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x); const_migraphx_shape_t shape,
size_t i);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed);
migraphx_status migraphx_target_destroy(migraphx_target_t target); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create(migraphx_argument_t* argument,
const_migraphx_shape_t shape,
void* buffer);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape);
migraphx_status migraphx_program_parameter_shapes_destroy( MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_buffer(char** out,
const_migraphx_argument_t argument);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_equal(bool* out,
const_migraphx_argument_t argument,
const_migraphx_argument_t x);
MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_generate(migraphx_argument_t* out,
const_migraphx_shape_t s,
size_t seed);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_destroy(migraphx_target_t target);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input);
MIGRAPHX_C_EXPORT migraphx_status migraphx_target_create(migraphx_target_t* target,
const char* name);
MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_assign_to(
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output, migraphx_program_parameter_shapes_t output, const_migraphx_program_parameter_shapes_t input);
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes, migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name); const char* name);
migraphx_status migraphx_program_parameter_shapes_names( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes); const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameters_assign_to(
const_migraphx_program_parameters_t input); migraphx_program_parameters_t output, const_migraphx_program_parameters_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters, MIGRAPHX_C_EXPORT migraphx_status
const char* name, migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
const_migraphx_argument_t argument); const char* name,
const_migraphx_argument_t argument);
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input); const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_size(size_t* out,
migraphx_arguments_t arguments);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_get(const_migraphx_argument_t* out,
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx); migraphx_arguments_t arguments,
size_t idx);
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_get(const_migraphx_shape_t* out,
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); migraphx_shapes_t shapes,
size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction); MIGRAPHX_C_EXPORT migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output, MIGRAPHX_C_EXPORT migraphx_status
const_migraphx_instruction_t input); migraphx_instruction_assign_to(migraphx_instruction_t output, const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions); MIGRAPHX_C_EXPORT migraphx_status
migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
const_migraphx_instructions_t input); migraphx_instructions_t output, const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_create(
const_migraphx_instruction_t* ptr, migraphx_instructions_t* instructions, const const_migraphx_instruction_t* ptr, size_t size);
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules); MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input); const_migraphx_modules_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_create(migraphx_modules_t* modules,
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size); migraphx_module_t* ptr,
size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_operation_t op, migraphx_operation_t op,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status
migraphx_module_t module, migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_operation_t op, migraphx_module_t module,
migraphx_instructions_t args, migraphx_operation_t op,
migraphx_modules_t module_refs); migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t shape, const_migraphx_shape_t shape,
const char* buffer); const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const char* name, const char* name,
const_migraphx_shape_t shape); const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t s); const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input); const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program, migraphx_program_t program,
const char* name); const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options); migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_parameter_shapes(
migraphx_program_t program); migraphx_program_parameter_shapes_t* out, migraphx_program_t program);
migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_print(const_migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_print(const_migraphx_program_t program);
migraphx_status migraphx_program_sort(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_sort(migraphx_program_t program);
migraphx_status migraphx_program_run(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
migraphx_status migraphx_program_run_async(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params, migraphx_program_parameters_t params,
void* s, void* s,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_equal(bool* out,
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); const_migraphx_program_t program,
const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_experimental_get_context(
const_migraphx_program_t program); migraphx_context_t* out, const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input); const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...); ...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_name(char* out,
size_t out_size,
migraphx_operation_t operation);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_load(migraphx_program_t* out,
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_save(migraphx_program_t p,
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_assign_to(
const_migraphx_onnx_options_t input); migraphx_onnx_options_t output, const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
size_t value); migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value);
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(
migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, int64_t value);
int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_file_options_assign_to(
const_migraphx_file_options_t input); migraphx_file_options_t output, const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, MIGRAPHX_C_EXPORT migraphx_status
const char* format); migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_assign_to(
const_migraphx_compile_options_t input); migraphx_compile_options_t output, const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value); migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, MIGRAPHX_C_EXPORT migraphx_status
bool value); migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(
migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options, migraphx_compile_options_t compile_options, bool value);
bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx(migraphx_program_t* out,
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); const char* name,
migraphx_onnx_options_t options);
migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data, const void* data,
size_t size, size_t size,
migraphx_onnx_options_t options); migraphx_onnx_options_t options);
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input); const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc);
migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_input_parameter_shape(
const char* name, migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size);
size_t* dims,
size_t dims_size);
migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status
size_t value); migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value);
migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_output_names(
const char** names, migraphx_tf_options_t tf_options, const char** names, size_t names_size);
size_t names_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_tf(migraphx_program_t* out,
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options); const char* name,
migraphx_tf_options_t options);
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_op_names_assign_to(
const_migraphx_quantize_op_names_t input); migraphx_quantize_op_names_t output, const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, MIGRAPHX_C_EXPORT migraphx_status
const char* name); migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name);
migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_t name); migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, migraphx_quantize_op_names_t name);
migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_assign_to(
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output, migraphx_quantize_int8_options_t output, const_migraphx_quantize_int8_options_t input);
const_migraphx_quantize_int8_options_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_op_name(
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, migraphx_quantize_int8_options_t quantize_int8_options, const char* name);
const char* name);
migraphx_status migraphx_quantize_int8_options_add_calibration_data( MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data); migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
migraphx_status migraphx_quantize_int8(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options); migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out,
migraphx_context_t context);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_assign_to(
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output, migraphx_experimental_custom_op_t output, const_migraphx_experimental_custom_op_t input);
const_migraphx_experimental_custom_op_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op, migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
...@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -518,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
const char* obj_typename, const char* obj_typename,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute(
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute input);
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status migraphx_experimental_custom_op_set_output_alias( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_output_alias(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input); migraphx_experimental_custom_op_runs_on_offload_target input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -571,10 +571,90 @@ using require_interface = ...@@ -571,10 +571,90 @@ using require_interface =
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const) #define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
/**
* Container to hold optimal dynamic dimension values.
*/
struct optimals : MIGRAPHX_HANDLE_BASE(optimals)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(optimals)
optimals(std::initializer_list<size_t> init_list)
{
this->make_handle(&migraphx_optimals_create, init_list.begin(), init_list.size());
}
};
/**
* @brief Dynamic dimension object.
* @details minimum, maximum, and optimal dimensions
*/
struct dynamic_dimension : MIGRAPHX_CONST_HANDLE_BASE(dynamic_dimension)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimension)
dynamic_dimension(size_t min, size_t max)
{
this->make_handle(&migraphx_dynamic_dimension_create_min_max, min, max);
}
dynamic_dimension(size_t min, size_t max, const optimals& opts)
{
this->make_handle(
&migraphx_dynamic_dimension_create_min_max_optimals, min, max, opts.get_handle_ptr());
}
bool is_fixed() const
{
bool result = false;
call(&migraphx_dynamic_dimension_is_fixed, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
bool pout;
call(&migraphx_dynamic_dimension_equal, &pout, x.get_handle_ptr(), y.get_handle_ptr());
return pout;
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return not(x == y);
}
};
/**
* Container to hold dynamic_dimension objects.
*/
struct dynamic_dimensions : MIGRAPHX_HANDLE_BASE(dynamic_dimensions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimensions)
template <class... Ts>
dynamic_dimensions(Ts... xs)
{
std::array<const_migraphx_dynamic_dimension_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_dynamic_dimensions_create, a.data(), a.size());
}
size_t size() const
{
size_t pout;
call(&migraphx_dynamic_dimensions_size, &pout, this->get_handle_ptr());
return pout;
}
dynamic_dimension operator[](size_t pidx) const
{
const_migraphx_dynamic_dimension_t pout;
call(&migraphx_dynamic_dimensions_get, &pout, this->get_handle_ptr(), pidx);
return {pout, this->share_handle()};
}
};
/** /**
* @brief Describe shape of tensor * @brief Describe shape of tensor
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides * @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
*
*/ */
struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
} }
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(migraphx_shape_datatype_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape(migraphx_shape_datatype_t type, shape(migraphx_shape_datatype_t type,
std::vector<size_t> plengths, std::vector<size_t> plengths,
std::vector<size_t> pstrides) std::vector<size_t> pstrides)
...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
pstrides.size()); pstrides.size());
} }
shape(migraphx_shape_datatype_t type, const dynamic_dimensions& dyn_dims)
{
this->make_handle(&migraphx_shape_create_dynamic, type, dyn_dims.get_handle_ptr());
}
std::vector<size_t> lengths() const std::vector<size_t> lengths() const
{ {
const size_t* pout; const size_t* pout;
...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return {pout, pout + pout_size}; return {pout, pout + pout_size};
} }
/// Get the dynamic dimensions of the shape
dynamic_dimensions dyn_dims() const
{
migraphx_dynamic_dimensions_t pout;
call(&migraphx_shape_dyn_dims, &pout, this->get_handle_ptr());
return {pout, own{}};
}
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
{ {
migraphx_shape_datatype_t pout; migraphx_shape_datatype_t pout;
...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return result; return result;
} }
/// Is the shape dynamic
bool dynamic() const
{
bool result = false;
call(&migraphx_shape_dynamic, &result, this->get_handle_ptr());
return result;
}
// map element index to space index // map element index to space index
size_t index(size_t i) const size_t index(size_t i) const
{ {
...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape)
{
this->make_handle(&migraphx_argument_create_empty, pshape.get_handle_ptr());
}
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
{ {
this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer); this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer);
...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
dim.size()); dim.size());
} }
void set_dyn_input_parameter_shape(const std::string& name, const dynamic_dimensions& dyn_dims)
{
call(&migraphx_onnx_options_set_dyn_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dyn_dims.get_handle_ptr());
}
/// When there is a dimension parameter, then use this default value /// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value) void set_default_dim_value(unsigned int value)
{ {
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value); call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
} }
void set_default_dyn_dim_value(const dynamic_dimension& dd)
{
call(&migraphx_onnx_options_set_default_dyn_dim_value,
this->get_handle_ptr(),
dd.get_handle_ptr());
}
/// Set default max iteration number for the loop operator /// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value) void set_default_loop_iterations(int64_t value)
{ {
...@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -1359,13 +1487,17 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base struct experimental_custom_op_base
{ {
experimental_custom_op_base() = default;
experimental_custom_op_base(const experimental_custom_op_base&) = default;
experimental_custom_op_base& operator=(const experimental_custom_op_base&) = default;
virtual ~experimental_custom_op_base() = default;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0; virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual std::vector<size_t> output_alias(shapes) const { return {}; } virtual std::vector<size_t> output_alias(shapes) const { return {}; }
// TODO: Return target string instead of bool // TODO: Return target string instead of bool
virtual bool runs_on_offload_target() const = 0; virtual bool runs_on_offload_target() const = 0;
virtual ~experimental_custom_op_base() = default;
}; };
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)> struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
......
...@@ -45,56 +45,49 @@ def shape_type_wrap(p): ...@@ -45,56 +45,49 @@ def shape_type_wrap(p):
p.read = 'migraphx::to_shape_type(${name})' p.read = 'migraphx::to_shape_type(${name})'
@api.cwrap('migraphx::compile_options') def auto_handle(*args, **kwargs):
def compile_options_type_wrap(p): def with_handle(f):
if p.returns: return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
p.add_param('migraphx_compile_options *') *args, **kwargs)(f)
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_compile_options(${result})']
else:
p.add_param('migraphx_compile_options *')
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
return with_handle
@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_onnx_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_onnx_options(${result})']
else:
p.add_param('migraphx_onnx_options *')
p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.handle('migraphx_optimals', 'std::set<size_t>')
def optimals(h):
h.constructor('create',
api.params(ptr='const size_t*', size='size_t'),
fname='migraphx::make_set<size_t>')
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
def dynamic_dimension(h):
h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
h.constructor(
'create_min_max_optimals',
api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
h.method('is_fixed', returns='bool', const=True)
h.method('equal',
api.params(x='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
def auto_handle(*args, **kwargs):
def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
*args, **kwargs)(f)
return with_handle @api.handle('migraphx_dynamic_dimensions',
'std::vector<migraphx::shape::dynamic_dimension>')
def dynamic_dimensions(h):
h.constructor(
'create',
api.params(ptr='const const_migraphx_dynamic_dimension_t*',
size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape::dynamic_dimension&')
@auto_handle() @auto_handle()
...@@ -109,20 +102,29 @@ def shape(h): ...@@ -109,20 +102,29 @@ def shape(h):
lengths='std::vector<size_t>', lengths='std::vector<size_t>',
strides='std::vector<size_t>')) strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t')) h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.constructor(
'create_dynamic',
api.params(type='migraphx::shape::type_t',
dims='std::vector<migraphx::shape::dynamic_dimension>'))
h.method('lengths', h.method('lengths',
fname='lens', fname='lens',
returns='const std::vector<size_t>&', returns='const std::vector<size_t>&',
const=True) const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True) h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('dyn_dims',
returns='std::vector<migraphx::shape::dynamic_dimension>',
const=True)
h.method('type', returns='migraphx::shape::type_t', const=True) h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True) h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True) h.method('bytes', returns='size_t', const=True)
h.method('ndim', returns='size_t', const=True)
h.method('equal', h.method('equal',
api.params(x='const migraphx::shape&'), api.params(x='const migraphx::shape&'),
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True) h.method('standard', returns='bool', const=True)
h.method('dynamic', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True) h.method('index', api.params(i='size_t'), returns='size_t', const=True)
...@@ -130,6 +132,7 @@ def shape(h): ...@@ -130,6 +132,7 @@ def shape(h):
def argument(h): def argument(h):
h.constructor('create', h.constructor('create',
api.params(shape='const migraphx::shape&', buffer='void*')) api.params(shape='const migraphx::shape&', buffer='void*'))
h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
h.method('shape', h.method('shape',
fname='get_shape', fname='get_shape',
cpp_name='get_shape', cpp_name='get_shape',
...@@ -213,7 +216,7 @@ def instruction(h): ...@@ -213,7 +216,7 @@ def instruction(h):
def instructions(h): def instructions(h):
h.constructor( h.constructor(
'create', 'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'), api.params(ptr='const const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>') fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
...@@ -325,11 +328,22 @@ def onnx_options(h): ...@@ -325,11 +328,22 @@ def onnx_options(h):
api.params(name='const char*', dims='std::vector<size_t>'), api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)', invoke='migraphx::set_input_parameter_shape($@)',
) )
h.method(
'set_dyn_input_parameter_shape',
api.params(name='const char*',
dims='std::vector<migraphx::shape::dynamic_dimension>'),
invoke='migraphx::set_dyn_input_parameter_shape($@)',
)
h.method( h.method(
'set_default_dim_value', 'set_default_dim_value',
api.params(value='size_t'), api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)', invoke='migraphx::set_default_dim_value($@)',
) )
h.method(
'set_default_dyn_dim_value',
api.params(dd='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::set_default_dyn_dim_value($@)',
)
h.method( h.method(
'set_default_loop_iterations', 'set_default_loop_iterations',
api.params(value='int64_t'), api.params(value='int64_t'),
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -31,20 +31,6 @@ ...@@ -31,20 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1) std::vector<std::size_t> s1)
{ {
...@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,32 +63,38 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform( std::transform(s0.dyn_dims().cbegin(),
s0.dyn_dims().cbegin(), s0.dyn_dims().cend(),
s0.dyn_dims().cend(), s1.dyn_dims().cbegin() + offset,
s1.dyn_dims().cbegin() + offset, out_dims.begin() + offset,
out_dims.begin() + offset, [&](auto a, auto b) {
[&](auto a, auto b) { if(a == b or b == 1)
if(a == b) {
{ return a;
return a; }
} else if(a == 1)
else if(a == 1 or b == 1) {
{ return b;
// setting optimals to empty, may need to be changed }
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max)}; else
} {
else MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
{ migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" + migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
migraphx::to_string_range(s0.dyn_dims()) + "} and {" + }
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!"); });
}
});
return out_dims; return out_dims;
} }
// Compute the common (broadcasted) dimensions of a list of fixed shapes std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{
auto ret_shape = shapes.at(0);
std::for_each(shapes.cbegin() + 1, shapes.cend(), [&](auto s) {
ret_shape = shape{ret_shape.type(), compute_broadcasted_dyn_dims(ret_shape, s)};
});
return ret_shape.dyn_dims();
}
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{ {
assert(not shapes.empty()); assert(not shapes.empty());
...@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> ...@@ -154,34 +146,30 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
if(std::any_of( if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); })) inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{ {
// currently only handles the binary case auto input_shapes = to_shapes(inputs);
if(inputs.size() != 2) auto c_type = compute_common_types(input_shapes);
{ auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs if any are dynamic shape");
}
auto c_type = compute_common_types(to_shapes(inputs));
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape auto s0 = inputs[0]->get_shape();
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims) if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims)
{ {
inputs[0] = m.insert_instruction( inputs[0] = m.insert_instruction(
ins, ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
}
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1],
inputs[0]);
} }
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims)
{
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
}
return input;
});
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
{ {
......
...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
body.insert(0, "(void)" + pname + ";\n");
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname) cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{ {
params.push_back({pname, "T" + pname}); params.push_back({pname, "T" + pname});
...@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op, ...@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
else if(with_char(::isdigit)(key[0])) else if(with_char(::isdigit)(key[0]))
{ {
auto i = std::stoul(key); auto i = std::stoul(key);
if(i >= args.size())
MIGRAPHX_THROW("Invalid argument index: " + key);
return args.at(i); return args.at(i);
} }
else if(v.contains(key)) else if(v.contains(key))
...@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m, ...@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
f.set_name(name).set_types(m).set_body( f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string { m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" + {
ins->get_literal().to_string() + ")"; std::string string_literal;
ins->get_literal().visit([&](auto v) {
assert(v.size() == 1);
auto x = v.front();
if(std::isinf(x))
{
string_literal = "__builtin_huge_val()";
if(x < 0)
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(x))
string_literal = "__builtin_nan()";
else
string_literal = ins->get_literal().to_string();
});
return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")";
}
auto s = g(ins, names); auto s = g(ins, names);
if(impl->fresult) if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')'; return impl->fresult(ins->get_shape()) + '(' + s + ')';
......
...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined()) not i->is_undefined())
continue; continue;
......
...@@ -32,18 +32,20 @@ add_executable(driver ...@@ -32,18 +32,20 @@ add_executable(driver
marker_roctx.cpp marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility if(NOT WIN32)
add_custom_command( # Copy driver for backwards compatibility (Linux only)
TARGET driver add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver> $<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -338,11 +338,22 @@ struct argument_parser ...@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto file_exist() MIGRAPHX_DRIVER_STATIC auto file_exist()
{ {
return validate([](auto&, auto&, auto& params) { return validate([](auto&, auto&, const auto& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("No argument passed."); throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back())) if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back()); throw std::runtime_error("Path does not exist: " + params.back());
});
}
MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
{
return validate([=](auto&, auto&, const auto& params) {
auto invalid_param = std::find_if(
params.begin(), params.end(), [&](const auto& p) { return names.count(p) == 0; });
if(invalid_param != params.end())
throw std::runtime_error("Invalid argument: " + *invalid_param +
". Valid arguments are {" + to_string_range(names) + "}");
}); });
} }
...@@ -570,8 +581,7 @@ struct argument_parser ...@@ -570,8 +581,7 @@ struct argument_parser
continue; continue;
if(flag[0] != '-') if(flag[0] != '-')
continue; continue;
auto d = std::ptrdiff_t d = levenshtein_distance(flag, input);
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance) if(d < result.distance)
result = result_t{&arg, flag, input, d}; result = result_t{&arg, flag, input, d};
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -81,6 +82,7 @@ struct loader ...@@ -81,6 +82,7 @@ struct loader
{"--model"}, {"--model"},
ap.help("Load model"), ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"), ap.type("resnet50|inceptionv3|alexnet"),
ap.matches({"resnet50", "inceptionv3", "alexnet"}),
ap.group("input")); ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
...@@ -241,6 +243,20 @@ struct loader ...@@ -241,6 +243,20 @@ struct loader
return options; return options;
} }
static std::string get_file_type(const std::string& file)
{
if(ends_with(file, ".onnx"))
return "onnx";
else if(ends_with(file, ".pb"))
return "tf";
else if(ends_with(file, ".json"))
return "json";
else if(ends_with(file, ".py"))
return "py";
else
return "migraphx";
}
program load() program load()
{ {
program p; program p;
...@@ -248,14 +264,7 @@ struct loader ...@@ -248,14 +264,7 @@ struct loader
{ {
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) file_type = get_file_type(file);
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -272,6 +281,10 @@ struct loader ...@@ -272,6 +281,10 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
...@@ -415,7 +428,8 @@ struct compiler ...@@ -415,7 +428,8 @@ struct compiler
program_params parameters; program_params parameters;
compiler_target ct; compiler_target ct;
compile_options co; compile_options co;
precision quantize = precision::fp32; bool to_fp16 = false;
bool to_int8 = false;
std::vector<std::string> fill0; std::vector<std::string> fill0;
std::vector<std::string> fill1; std::vector<std::string> fill1;
...@@ -436,13 +450,8 @@ struct compiler ...@@ -436,13 +450,8 @@ struct compiler
{"--exhaustive-tune"}, {"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"), ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true)); ap.set_value(true));
ap(co.split_single_dyn_dim, ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
{"--split-single-dyn-dim"}, ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap.help("If there is a single non-fixed dynamic dimension in the model, then split to "
"static submodules"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
} }
auto params(const program& p) auto params(const program& p)
...@@ -450,20 +459,46 @@ struct compiler ...@@ -450,20 +459,46 @@ struct compiler
return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
} }
auto host_params(const program& p)
{
return parameters.generate(p, ct.get_target(), true, l.batch);
}
program compile() program compile()
{ {
auto p = l.load(); auto p = l.load();
// Dont compile if its already been compiled // Dont compile if its already been compiled
if(p.is_compiled()) if(p.is_compiled())
{
if(ct.target_name == "gpu")
{
if(is_offload_copy_set(p) and not co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
}
else if(co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"fails.\n";
}
}
return p; return p;
}
auto t = ct.get_target(); auto t = ct.get_target();
if(quantize == precision::fp16) if(to_fp16)
{ {
quantize_fp16(p); quantize_fp16(p);
} }
else if(quantize == precision::int8) if(to_int8)
{ {
quantize_int8(p, t, {params(p)}); quantize_int8(p, t, {host_params(p)});
} }
p.compile(t, co); p.compile(t, co);
l.save(p); l.save(p);
...@@ -522,17 +557,23 @@ struct verify : command<verify> ...@@ -522,17 +557,23 @@ struct verify : command<verify>
auto t = c.ct.get_target(); auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true, c.l.batch); auto m = c.parameters.generate(p, t, true, c.l.batch);
auto quantize = precision::fp32;
if(c.to_fp16)
quantize = precision::fp16;
if(c.to_int8)
quantize = precision::int8;
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, c.quantize, tolerance); verify_instructions(p, t, c.co, quantize, tolerance);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, c.quantize, m, tolerance); verify_reduced_program(p, t, c.co, quantize, m, tolerance);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, quantize, m, tolerance);
} }
} }
}; };
...@@ -662,6 +703,26 @@ struct onnx : command<onnx> ...@@ -662,6 +703,26 @@ struct onnx : command<onnx>
} }
}; };
struct tf : command<tf>
{
bool show_ops = false;
void parse(argument_parser& ap)
{
ap(show_ops,
{"--list", "-l"},
ap.help("List all tf operators supported by MIGraphX"),
ap.set_value(true));
}
void run() const
{
if(show_ops)
{
for(const auto& name : get_tf_operators())
std::cout << name << std::endl;
}
}
};
struct main_command struct main_command
{ {
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow, static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
...@@ -709,7 +770,7 @@ struct main_command ...@@ -709,7 +770,7 @@ struct main_command
{ {
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl; << "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl; std::cout << get_command_help("Available commands:");
} }
else else
{ {
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -97,6 +99,38 @@ target get_target(bool gpu) ...@@ -97,6 +99,38 @@ target get_target(bool gpu)
return make_target("cpu"); return make_target("cpu");
} }
bool is_offload_copy_set(const program& p)
{
assert(p.is_compiled());
const module* mm = p.get_main_module();
std::vector<std::string> param_names = mm->get_parameter_names();
std::unordered_set<instruction_ref> param_ins;
std::transform(param_names.begin(),
param_names.end(),
std::inserter(param_ins, param_ins.begin()),
[&](const auto& i) { return mm->get_parameter(i); });
for(const auto& i : *mm)
{
if(i.name() == "hip::copy_to_gpu")
{
auto copy_arg = instruction::get_output_alias(i.inputs().front(), true);
param_ins.erase(copy_arg);
}
else if(i.name() == "@return")
{
auto return_args = i.inputs();
for(const auto& j : return_args)
{
auto alias_ins = instruction::get_output_alias(j, true);
if((alias_ins->name() == "@param" && param_ins.erase(alias_ins) == 0) or
(alias_ins->name() != "hip::copy_from_gpu"))
return false;
}
}
}
return param_ins.empty();
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload = ...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
/**
* @brief Checks if MIGraphX program compiled for "GPU" has offload_copy set of not. This is
intended to print a HINT for the users and would not always correctly classify compiled program as
with or without offload_copy in all cases.
* @param p Compiled MIGraphX program for GPU backend
* @return true if program is classified as compiled with "offload_copy" set
*/
bool is_offload_copy_set(const program& p);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -48,7 +48,7 @@ struct dynamic_loader_impl ...@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), : handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}), manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t)) temp(std::move(t))
{ {
...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
try
{
return dynamic_loader{p};
}
catch(const std::exception&)
{
return nullopt;
}
}
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{ {
} }
......
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <iterator> #include <iterator>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,7 +41,7 @@ static literal get_scalar(instruction_ref ins) ...@@ -39,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front()); return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(s.elements() != 1 && not(s.scalar())) if(s.elements() != 1 and not(s.scalar()))
return {}; return {};
if(not ins->can_eval()) if(not ins->can_eval())
return {}; return {};
...@@ -67,7 +69,6 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -67,7 +69,6 @@ static void create_pointwise_modules(module_pass_manager& mpm)
continue; continue;
if(ins->get_operator().name() == "layout") if(ins->get_operator().name() == "layout")
continue; continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
...@@ -193,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const ...@@ -193,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{ {
create_pointwise_modules(mpm); create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
{
return;
}
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
if(not find_pointwise_modules(mpm.get_module())) if(not find_pointwise_modules(mpm.get_module()))
......
...@@ -52,7 +52,7 @@ struct fused_reduce ...@@ -52,7 +52,7 @@ struct fused_reduce
{ {
if(mods.size() != 1) if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
auto* sm = mods.front(); const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1) if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported"); MIGRAPHX_THROW("Only one output supported");
auto names = sm->get_parameter_names(); auto names = sm->get_parameter_names();
...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm, ...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
} }
static std::vector<instruction_ref> static std::vector<instruction_ref>
find_inputs(module_ref sm, find_inputs(const_module_ref sm,
const module& parent, const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value) argument fill_argument(shape s, double value)
{ {
argument result; argument result;
if(s.type() == shape::tuple_type) if(s.type() == shape::tuple_type)
......
...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct adjust_allocation struct MIGRAPHX_EXPORT adjust_allocation
{ {
allocation_model model; allocation_model model;
std::string name() const { return "adjust_allocation"; } std::string name() const { return "adjust_allocation"; }
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat ...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return std::ptrdiff_t{1} + std::min({x1, x2, x3}); return std::ptrdiff_t{1} + std::min({x1, x2, x3});
} }
inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
{
const size_t l1 = s1.length();
const size_t l2 = s2.length();
if(l1 < l2)
levenshtein_distance(s2, s1);
std::vector<size_t> d(l2 + 1);
std::iota(d.begin(), d.end(), 0);
for(size_t i = 1; i <= l1; i++)
{
size_t prev_cost = d[0];
d[0] = i;
for(size_t j = 1; j <= l2; j++)
{
if(s1[i - 1] == s2[j - 1])
{
d[j] = prev_cost;
}
else
{
size_t cost_insert_or_delete = std::min(d[j - 1], d[j]);
size_t cost_substitute = prev_cost;
prev_cost = d[j];
d[j] = std::min(cost_substitute, cost_insert_or_delete) + 1;
}
}
}
return d[l2];
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -60,7 +60,7 @@ struct allocation_model ...@@ -60,7 +60,7 @@ struct allocation_model
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct allocation_model struct MIGRAPHX_EXPORT allocation_model
{ {
// //
std::string name() const; std::string name() const;
...@@ -96,7 +96,7 @@ struct allocation_model ...@@ -96,7 +96,7 @@ struct allocation_model
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -267,7 +267,7 @@ struct allocation_model ...@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -39,7 +39,8 @@ struct stream_race ...@@ -39,7 +39,8 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm); MIGRAPHX_EXPORT std::vector<stream_race> analyze_streams(const module& m,
const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment