Unverified Commit 1530ec24 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into add_parity_check_ci

parents 5c98fcb0 c2e01b10
...@@ -26,6 +26,7 @@ add_library(migraphx_c ...@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# migraphx_c is stable API interface library. SO version of this should be # migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken. # bumped when binary compatibility is broken.
......
...@@ -44,7 +44,7 @@ namespace migraphx { ...@@ -44,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
...@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output, ...@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
extern "C" migraphx_status extern "C" migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions, migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr, const const_migraphx_dynamic_dimension_t* ptr,
size_t size) size_t size)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions ...@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
} }
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr, const const_migraphx_instruction_t* ptr,
size_t size) size_t size)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...@@ -166,430 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void ...@@ -166,430 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals); MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input); const_migraphx_optimals_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals,
migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size); const size_t* ptr,
size_t size);
migraphx_status migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension); MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_assign_to(
const_migraphx_dynamic_dimension_t input); migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input);
migraphx_status migraphx_dynamic_dimension_create_min_max( MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max); migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension, migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min, size_t min,
size_t max, size_t max,
migraphx_optimals_t optimals); migraphx_optimals_t optimals);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimension_is_fixed(
migraphx_dynamic_dimension_is_fixed(bool* out, bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension);
const_migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimension_equal(bool* out, migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension, const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x); const_migraphx_dynamic_dimension_t x);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions); migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_dynamic_dimensions_assign_to(
const_migraphx_dynamic_dimensions_t input); migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions, migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr, const const_migraphx_dynamic_dimension_t* ptr,
size_t size); size_t size);
migraphx_status migraphx_dynamic_dimensions_size(size_t* out, MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_t dynamic_dimensions); migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out, MIGRAPHX_C_EXPORT migraphx_status
migraphx_dynamic_dimensions_t dynamic_dimensions, migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
size_t idx); migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_assign_to(migraphx_shape_t output,
const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
size_t lengths_size); size_t lengths_size);
migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
size_t* lengths, size_t* lengths,
size_t lengths_size, size_t lengths_size,
size_t* strides, size_t* strides,
size_t strides_size); size_t strides_size);
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type); migraphx_shape_datatype_t type);
migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type, migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims); migraphx_dynamic_dimensions_t dims);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_lengths(const size_t** out,
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_strides(const size_t** out,
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); size_t* out_size,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
const_migraphx_shape_t shape); const_migraphx_shape_t shape);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_elements(size_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_equal(bool* out,
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); const_migraphx_shape_t shape,
const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i); MIGRAPHX_C_EXPORT migraphx_status migraphx_shape_index(size_t* out,
const_migraphx_shape_t shape,
size_t i);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input); const_migraphx_argument_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create(migraphx_argument_t* argument,
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); const_migraphx_shape_t shape,
void* buffer);
migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument, MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape); const_migraphx_shape_t shape);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument); const_migraphx_argument_t argument);
migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument); MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_buffer(char** out,
const_migraphx_argument_t argument);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_equal(bool* out,
migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x); const_migraphx_argument_t argument,
const_migraphx_argument_t x);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_argument_generate(migraphx_argument_t* out,
migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed); const_migraphx_shape_t s,
size_t seed);
migraphx_status migraphx_target_destroy(migraphx_target_t target); MIGRAPHX_C_EXPORT migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_target_assign_to(migraphx_target_t output,
const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_target_create(migraphx_target_t* target,
const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes); migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_assign_to(
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output, migraphx_program_parameter_shapes_t output, const_migraphx_program_parameter_shapes_t input);
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out,
migraphx_program_parameter_shapes_t program_parameter_shapes, migraphx_program_parameter_shapes_t program_parameter_shapes,
const char* name); const char* name);
migraphx_status migraphx_program_parameter_shapes_names( MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameter_shapes_names(
const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes); const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_parameters_assign_to(
const_migraphx_program_parameters_t input); migraphx_program_parameters_t output, const_migraphx_program_parameters_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters, MIGRAPHX_C_EXPORT migraphx_status
const char* name, migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters,
const_migraphx_argument_t argument); const char* name,
const_migraphx_argument_t argument);
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input); const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_size(size_t* out,
migraphx_arguments_t arguments);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_arguments_get(const_migraphx_argument_t* out,
migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx); migraphx_arguments_t arguments,
size_t idx);
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output,
const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_shapes_get(const_migraphx_shape_t* out,
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); migraphx_shapes_t shapes,
size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction); MIGRAPHX_C_EXPORT migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output, MIGRAPHX_C_EXPORT migraphx_status
const_migraphx_instruction_t input); migraphx_instruction_assign_to(migraphx_instruction_t output, const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions); MIGRAPHX_C_EXPORT migraphx_status
migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_assign_to(
const_migraphx_instructions_t input); migraphx_instructions_t output, const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, MIGRAPHX_C_EXPORT migraphx_status migraphx_instructions_create(
const_migraphx_instruction_t* ptr, migraphx_instructions_t* instructions, const const_migraphx_instruction_t* ptr, size_t size);
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules); MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input); const_migraphx_modules_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_modules_create(migraphx_modules_t* modules,
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size); migraphx_module_t* ptr,
size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module); MIGRAPHX_C_EXPORT migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_operation_t op, migraphx_operation_t op,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status
migraphx_module_t module, migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_operation_t op, migraphx_module_t module,
migraphx_instructions_t args, migraphx_operation_t op,
migraphx_modules_t module_refs); migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t shape, const_migraphx_shape_t shape,
const char* buffer); const char* buffer);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const char* name, const char* name,
const_migraphx_shape_t shape); const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
const_migraphx_shape_t s); const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input); const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program, migraphx_program_t program,
const char* name); const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options); migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_parameter_shapes(
migraphx_program_t program); migraphx_program_parameter_shapes_t* out, migraphx_program_t program);
migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_print(const_migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_print(const_migraphx_program_t program);
migraphx_status migraphx_program_sort(migraphx_program_t program); MIGRAPHX_C_EXPORT migraphx_status migraphx_program_sort(migraphx_program_t program);
migraphx_status migraphx_program_run(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
migraphx_status migraphx_program_run_async(migraphx_arguments_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params, migraphx_program_parameters_t params,
void* s, void* s,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_program_equal(bool* out,
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); const_migraphx_program_t program,
const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_program_experimental_get_context(
const_migraphx_program_t program); migraphx_context_t* out, const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input); const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes, const char* attributes,
...); ...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); MIGRAPHX_C_EXPORT migraphx_status migraphx_operation_name(char* out,
size_t out_size,
migraphx_operation_t operation);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_load(migraphx_program_t* out,
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_save(migraphx_program_t p,
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options); const char* name,
migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_assign_to(
const_migraphx_onnx_options_t input); migraphx_onnx_options_t output, const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims); migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, MIGRAPHX_C_EXPORT migraphx_status
size_t value); migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(
migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd);
const_migraphx_dynamic_dimension_t dd);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_t onnx_options, int64_t value);
int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_file_options_assign_to(
const_migraphx_file_options_t input); migraphx_file_options_t output, const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, MIGRAPHX_C_EXPORT migraphx_status
const char* format); migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_assign_to(
const_migraphx_compile_options_t input); migraphx_compile_options_t output, const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value); migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, MIGRAPHX_C_EXPORT migraphx_status
bool value); migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(
migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options, migraphx_compile_options_t compile_options, bool value);
bool value);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx(migraphx_program_t* out,
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); const char* name,
migraphx_onnx_options_t options);
migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data, const void* data,
size_t size, size_t size,
migraphx_onnx_options_t options); migraphx_onnx_options_t options);
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input); const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options,
bool is_nhwc);
migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_input_parameter_shape(
const char* name, migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size);
size_t* dims,
size_t dims_size);
migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status
size_t value); migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value);
migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, MIGRAPHX_C_EXPORT migraphx_status migraphx_tf_options_set_output_names(
const char** names, migraphx_tf_options_t tf_options, const char** names, size_t names_size);
size_t names_size);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_parse_tf(migraphx_program_t* out,
migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options); const char* name,
migraphx_tf_options_t options);
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_op_names_assign_to(
const_migraphx_quantize_op_names_t input); migraphx_quantize_op_names_t output, const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, MIGRAPHX_C_EXPORT migraphx_status
const char* name); migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name);
migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_op_names_t name); migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, migraphx_quantize_op_names_t name);
migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_assign_to(
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output, migraphx_quantize_int8_options_t output, const_migraphx_quantize_int8_options_t input);
const_migraphx_quantize_int8_options_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_op_name(
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, migraphx_quantize_int8_options_t quantize_int8_options, const char* name);
const char* name);
migraphx_status migraphx_quantize_int8_options_add_calibration_data( MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data); migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
migraphx_status migraphx_quantize_int8(migraphx_program_t prog, MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options); migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out,
migraphx_context_t context);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_assign_to(
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output, migraphx_experimental_custom_op_t output, const_migraphx_experimental_custom_op_t input);
const_migraphx_experimental_custom_op_t input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op, migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
...@@ -597,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -597,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
const char* obj_typename, const char* obj_typename,
const char* name); const char* name);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute(
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute input);
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status migraphx_experimental_custom_op_set_output_alias( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_output_alias(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target( MIGRAPHX_C_EXPORT migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input); migraphx_experimental_custom_op_runs_on_offload_target input);
migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -79,7 +79,8 @@ def dynamic_dimension(h): ...@@ -79,7 +79,8 @@ def dynamic_dimension(h):
def dynamic_dimensions(h): def dynamic_dimensions(h):
h.constructor( h.constructor(
'create', 'create',
api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'), api.params(ptr='const const_migraphx_dynamic_dimension_t*',
size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>') fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t') h.method('size', returns='size_t')
h.method('get', h.method('get',
...@@ -215,7 +216,7 @@ def instruction(h): ...@@ -215,7 +216,7 @@ def instruction(h):
def instructions(h): def instructions(h):
h.constructor( h.constructor(
'create', 'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'), api.params(ptr='const const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>') fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x > dim;
});
if(x < dim)
return start;
return it;
}
template <class Range>
static auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}
struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims,
std::vector<std::vector<std::size_t>>& paxes_map)
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
{
}
const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1;
std::size_t get() const { return *it / rem; }
bool is_end() const { return it == dims->end(); }
void next(std::size_t i = 1) { it += i; }
auto dims_for(std::size_t d) const
{
auto dim_end = compute_end_dim(it, dims->end(), d);
return range(it, dim_end);
}
void add_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
axes_map->push_back(std::move(axes));
}
void add_multi_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
std::transform(axes.begin(),
axes.end(),
std::back_inserter(*axes_map),
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
}
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{
if(rem != 1)
{
assert(start > 0);
naxes++;
start--;
}
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), start);
return axes;
}
};
static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{
assert(state1.get() <= state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);
if(naxes == 0)
return false;
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return false;
auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return true;
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) > 0);
assert(elements(dims1) == elements(dims2));
common_dims cd;
common_dim_state state1{dims1, cd.axes_map1};
common_dim_state state2{dims2, cd.axes_map2};
while(not state1.is_end() and not state2.is_end())
{
auto d1 = state1.get();
auto d2 = state2.get();
if(d1 <= d2)
{
if(not compute_common_dim(cd.dims, state1, state2))
return {};
}
else // if(d1 > d2)
{
if(not compute_common_dim(cd.dims, state2, state1))
return {};
}
}
assert(elements(dims1) == elements(cd.dims));
return cd;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -32,18 +32,23 @@ add_executable(driver ...@@ -32,18 +32,23 @@ add_executable(driver
marker_roctx.cpp marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility if(NOT WIN32)
add_custom_command( # Copy driver for backwards compatibility (Linux only)
TARGET driver add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver> $<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output)
target_compile_definitions(driver PUBLIC MIGRAPHX_ORT_SHA1="${String_output}")
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -338,11 +338,22 @@ struct argument_parser ...@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto file_exist() MIGRAPHX_DRIVER_STATIC auto file_exist()
{ {
return validate([](auto&, auto&, auto& params) { return validate([](auto&, auto&, const auto& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("No argument passed."); throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back())) if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back()); throw std::runtime_error("Path does not exist: " + params.back());
});
}
MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
{
return validate([=](auto&, auto&, const auto& params) {
auto invalid_param = std::find_if(
params.begin(), params.end(), [&](const auto& p) { return names.count(p) == 0; });
if(invalid_param != params.end())
throw std::runtime_error("Invalid argument: " + *invalid_param +
". Valid arguments are {" + to_string_range(names) + "}");
}); });
} }
...@@ -570,8 +581,7 @@ struct argument_parser ...@@ -570,8 +581,7 @@ struct argument_parser
continue; continue;
if(flag[0] != '-') if(flag[0] != '-')
continue; continue;
auto d = std::ptrdiff_t d = levenshtein_distance(flag, input);
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance) if(d < result.distance)
result = result_t{&arg, flag, input, d}; result = result_t{&arg, flag, input, d};
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -81,6 +82,7 @@ struct loader ...@@ -81,6 +82,7 @@ struct loader
{"--model"}, {"--model"},
ap.help("Load model"), ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"), ap.type("resnet50|inceptionv3|alexnet"),
ap.matches({"resnet50", "inceptionv3", "alexnet"}),
ap.group("input")); ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
...@@ -241,6 +243,20 @@ struct loader ...@@ -241,6 +243,20 @@ struct loader
return options; return options;
} }
static std::string get_file_type(const std::string& file)
{
if(ends_with(file, ".onnx"))
return "onnx";
else if(ends_with(file, ".pb"))
return "tf";
else if(ends_with(file, ".json"))
return "json";
else if(ends_with(file, ".py"))
return "py";
else
return "migraphx";
}
program load() program load()
{ {
program p; program p;
...@@ -248,14 +264,7 @@ struct loader ...@@ -248,14 +264,7 @@ struct loader
{ {
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) file_type = get_file_type(file);
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -272,6 +281,10 @@ struct loader ...@@ -272,6 +281,10 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
...@@ -462,13 +475,15 @@ struct compiler ...@@ -462,13 +475,15 @@ struct compiler
{ {
if(is_offload_copy_set(p) and not co.offload_copy) if(is_offload_copy_set(p) and not co.offload_copy)
{ {
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try " std::cout
"passing " << "[WARNING]: MIGraphX program was likely compiled with offload_copy "
"`--enable-offload-copy` if program run fails.\n"; "set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
} }
else if(co.offload_copy) else if(co.offload_copy)
{ {
std::cout << "MIGraphX program was likely compiled without " std::cout << "[WARNING]: MIGraphX program was likely compiled without "
"offload_copy set, Try " "offload_copy set, Try "
"removing " "removing "
"`--enable-offload-copy` flag if passed to driver, if program run " "`--enable-offload-copy` flag if passed to driver, if program run "
...@@ -757,7 +772,7 @@ struct main_command ...@@ -757,7 +772,7 @@ struct main_command
{ {
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl; << "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl; std::cout << get_command_help("Available commands:");
} }
else else
{ {
...@@ -789,6 +804,13 @@ int main(int argc, const char* argv[]) ...@@ -789,6 +804,13 @@ int main(int argc, const char* argv[])
auto&& m = get_commands(); auto&& m = get_commands();
auto cmd = args.front(); auto cmd = args.front();
if(cmd == "ort-sha")
{
std::cout << MIGRAPHX_ORT_SHA1 << std::endl;
return 0;
}
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
{ {
m.at(cmd)(argv[0], {args.begin() + 1, args.end()}); m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
......
...@@ -48,7 +48,7 @@ struct dynamic_loader_impl ...@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), : handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}), manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t)) temp(std::move(t))
{ {
...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
try
{
return dynamic_loader{p};
}
catch(const std::exception&)
{
return nullopt;
}
}
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{ {
} }
......
...@@ -35,6 +35,8 @@ ...@@ -35,6 +35,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS)
static bool try_compute_shape(instruction_ref ins, static bool try_compute_shape(instruction_ref ins,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mods) const std::vector<module_ref>& mods)
...@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(not try_compute_shape(output, input_shapes, mods)) if(not try_compute_shape(output, input_shapes, output->module_inputs()))
{ {
return false; return false;
} }
} }
} }
catch(const std::exception& e)
{
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "Exception: " << e.what() << std::endl;
}
return false;
}
catch(...) catch(...)
{ {
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "Unknown exception" << std::endl;
}
return false; return false;
} }
...@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{ {
if(arg->name() != op_name) if(arg->name() != op_name)
continue; continue;
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "eliminate_contiguous: ";
m.debug_print(ins);
}
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args)) if(try_compute_shape(ins, new_args, mod_args))
......
...@@ -24,11 +24,14 @@ ...@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator> #include <iterator>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
...@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins) ...@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front()); return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(s.elements() != 1 && not(s.scalar())) if(s.elements() != 1 and not(s.scalar()))
return {}; return {};
if(not ins->can_eval()) if(not ins->can_eval())
return {}; return {};
...@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m) ...@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
} }
return changed; return changed;
} }
namespace {
struct find_pointwise_reshape_pointwise
{
auto matcher() const
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous = [](auto... ms) {
return match::arg(0)(match::skip(match::name("contiguous")(match::used_once()))(ms...));
};
auto pointwise = match::name("pointwise")(match::used_once());
auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape");
return match::name("pointwise")(match::any_of[match::inputs()](reshape_pointwise));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto cd = common_dims::compute(ins->get_shape().lens(), x_ins->get_shape().lens());
if(cd.dims.empty())
return;
auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input);
return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c);
};
};
auto x_inputs = x_ins->inputs();
std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins));
auto new_x_ins =
m.insert_instruction(x_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs());
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == reshape_ins)
return new_x_ins;
return reshape_input(ins)(input);
});
auto pw = m.insert_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw);
}
};
} // namespace
void fuse_pointwise::apply(module_pass_manager& mpm) const void fuse_pointwise::apply(module_pass_manager& mpm) const
{ {
...@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const ...@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
} }
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(mpm.get_module(), find_pointwise_reshape_pointwise{});
mpm.run_pass(simplify_reshapes{1});
if(not find_pointwise_modules(mpm.get_module())) if(not find_pointwise_modules(mpm.get_module()))
break; break;
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
......
...@@ -52,7 +52,7 @@ struct fused_reduce ...@@ -52,7 +52,7 @@ struct fused_reduce
{ {
if(mods.size() != 1) if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
auto* sm = mods.front(); const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1) if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported"); MIGRAPHX_THROW("Only one output supported");
auto names = sm->get_parameter_names(); auto names = sm->get_parameter_names();
...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm, ...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
} }
static std::vector<instruction_ref> static std::vector<instruction_ref>
find_inputs(module_ref sm, find_inputs(const_module_ref sm,
const module& parent, const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat ...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return std::ptrdiff_t{1} + std::min({x1, x2, x3}); return std::ptrdiff_t{1} + std::min({x1, x2, x3});
} }
inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
{
const size_t l1 = s1.length();
const size_t l2 = s2.length();
if(l1 < l2)
levenshtein_distance(s2, s1);
std::vector<size_t> d(l2 + 1);
std::iota(d.begin(), d.end(), 0);
for(size_t i = 1; i <= l1; i++)
{
size_t prev_cost = d[0];
d[0] = i;
for(size_t j = 1; j <= l2; j++)
{
if(s1[i - 1] == s2[j - 1])
{
d[j] = prev_cost;
}
else
{
size_t cost_insert_or_delete = std::min(d[j - 1], d[j]);
size_t cost_substitute = prev_cost;
prev_cost = d[j];
d[j] = std::min(cost_substitute, cost_insert_or_delete) + 1;
}
}
}
return d[l2];
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -96,7 +96,7 @@ struct allocation_model ...@@ -96,7 +96,7 @@ struct allocation_model
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -267,7 +267,7 @@ struct allocation_model ...@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -90,7 +90,17 @@ struct param ...@@ -90,7 +90,17 @@ struct param
struct returns struct returns
{ {
std::string name() const { return "@return"; } std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>& arg) const
{
if(arg.empty())
return {};
else if(arg.size() == 1)
return arg[0];
else
return arg;
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,29 +34,51 @@ ...@@ -34,29 +34,51 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Check that deduced type is incrementable, dereferencable, and comparable
template <class, class = void>
struct is_iterator
{
};
template <class T>
struct is_iterator<T,
std::void_t<decltype(++std::declval<T&>()),
decltype(*std::declval<T&>()),
decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
{
};
template <class Iterator>
struct check_shapes struct check_shapes
{ {
const shape* begin; static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
const shape* end; Iterator begin;
Iterator end;
std::string name; std::string name;
bool dynamic_allowed; bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false) check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d) : begin(b), end(e), name(n), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false) check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d) : begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
}
check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
: begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
...@@ -81,8 +103,6 @@ struct check_shapes ...@@ -81,8 +103,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return 0; return 0;
assert(begin != nullptr);
assert(end != nullptr);
return end - begin; return end - begin;
} }
...@@ -131,11 +151,9 @@ struct check_shapes ...@@ -131,11 +151,9 @@ struct check_shapes
*/ */
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() != n) if(begin->ndim() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
...@@ -148,11 +166,9 @@ struct check_shapes ...@@ -148,11 +166,9 @@ struct check_shapes
*/ */
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() > n) if(begin->ndim() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -166,11 +182,9 @@ struct check_shapes ...@@ -166,11 +182,9 @@ struct check_shapes
*/ */
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() < n) if(begin->ndim() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -220,6 +234,16 @@ struct check_shapes ...@@ -220,6 +234,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same layout.
*/
const check_shapes& same_layout() const
{
if(not this->same([](const shape& s) { return find_permutation(s); }))
MIGRAPHX_THROW(prefix() + "Layouts do not match");
return *this;
}
/*! /*!
* Check all shapes are standard. * Check all shapes are standard.
*/ */
...@@ -230,6 +254,16 @@ struct check_shapes ...@@ -230,6 +254,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are scalar.
*/
const check_shapes& scalar() const
{
if(not this->all_of([](const shape& s) { return s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar");
return *this;
}
/*! /*!
* Check all shapes are standard or scalar. * Check all shapes are standard or scalar.
*/ */
...@@ -330,8 +364,6 @@ struct check_shapes ...@@ -330,8 +364,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
auto&& key = f(*begin); auto&& key = f(*begin);
return this->all_of([&](const shape& s) { return f(s) == key; }); return this->all_of([&](const shape& s) { return f(s) == key; });
} }
...@@ -341,8 +373,6 @@ struct check_shapes ...@@ -341,8 +373,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
...@@ -351,17 +381,13 @@ struct check_shapes ...@@ -351,17 +381,13 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return false; return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p); return std::any_of(begin, end, p);
} }
const shape* get(long i) const Iterator get(long i) const
{ {
if(i >= size()) if(i >= size())
MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds"); MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
assert(begin != nullptr);
assert(end != nullptr);
if(i < 0) if(i < 0)
return end - i; return end - i;
return begin + i; return begin + i;
...@@ -394,6 +420,11 @@ struct check_shapes ...@@ -394,6 +420,11 @@ struct check_shapes
} }
}; };
// Deduction guide for std::vector constructor
template <class Op>
check_shapes(const std::vector<shape>&, const Op&, bool d = false)
-> check_shapes<std::vector<shape>::const_iterator>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct MIGRAPHX_EXPORT common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
...@@ -88,7 +88,7 @@ struct concat_optimization ...@@ -88,7 +88,7 @@ struct concat_optimization
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -233,7 +233,7 @@ struct concat_optimization ...@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -118,7 +118,7 @@ struct context ...@@ -118,7 +118,7 @@ struct context
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -373,7 +373,7 @@ struct context ...@@ -373,7 +373,7 @@ struct context
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri ...@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape win_shape{output_shape.type(), win_size}; shape win_shape{output_shape.type(), win_size};
double acc = 0.0; double acc = 0.0;
shape_for_each(win_shape, [&](auto idx_win) { shape_for_each(win_shape, [&](const auto& idx_win) {
auto k = idx_win[0]; auto k = idx_win[0];
const auto in_ch = group_id * wei_c + k; const auto in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end()); std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment