Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1530ec24
Unverified
Commit
1530ec24
authored
Sep 18, 2023
by
Ted Themistokleous
Committed by
GitHub
Sep 18, 2023
Browse files
Merge branch 'develop' into add_parity_check_ci
parents
5c98fcb0
c2e01b10
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
774 additions
and
335 deletions
+774
-335
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+3
-3
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+298
-266
src/api/migraphx.py
src/api/migraphx.py
+3
-2
src/common_dims.cpp
src/common_dims.cpp
+156
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+11
-6
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+14
-4
src/driver/main.cpp
src/driver/main.cpp
+35
-13
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+13
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+20
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+54
-1
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+2
-2
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+38
-0
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/builtin.hpp
src/include/migraphx/builtin.hpp
+11
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+58
-27
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
src/include/migraphx/convolution.hpp
src/include/migraphx/convolution.hpp
+2
-2
No files found.
src/api/CMakeLists.txt
View file @
1530ec24
...
...
@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp
)
set_target_properties
(
migraphx_c PROPERTIES EXPORT_NAME c
)
migraphx_generate_export_header
(
migraphx_c DIRECTORY migraphx/api
)
# migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken.
...
...
src/api/api.cpp
View file @
1530ec24
...
...
@@ -44,7 +44,7 @@ namespace migraphx {
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
...
...
@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
extern
"C"
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const_migraphx_dynamic_dimension_t
*
ptr
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
...
@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
}
extern
"C"
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
...
...
src/api/include/migraphx/migraphx.h
View file @
1530ec24
...
...
@@ -26,6 +26,9 @@
#include <stdlib.h>
#include <stdbool.h>
#include <migraphx/api/export.h>
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...
...
@@ -166,430 +169,460 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
migraphx_status
migraphx_optimals_destroy
(
migraphx_optimals_t
optimals
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_destroy
(
migraphx_optimals_t
optimals
);
migraphx_status
migraphx_optimals_assign_to
(
migraphx_optimals_t
output
,
const_migraphx_optimals_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_assign_to
(
migraphx_optimals_t
output
,
const_migraphx_optimals_t
input
);
migraphx_status
migraphx_optimals_create
(
migraphx_optimals_t
*
optimals
,
const
size_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_optimals_create
(
migraphx_optimals_t
*
optimals
,
const
size_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_dynamic_dimension_destroy
(
migraphx_dynamic_dimension_t
dynamic_dimension
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_destroy
(
migraphx_dynamic_dimension_t
dynamic_dimension
);
migraphx_status
migraphx_dynamic_dimension_assign_to
(
migraphx_dynamic_dimension_t
output
,
const_migraphx_dynamic_dimension_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_assign_to
(
migraphx_dynamic_dimension_t
output
,
const_migraphx_dynamic_dimension_t
input
);
migraphx_status
migraphx_dynamic_dimension_create_min_max
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_create_min_max
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals
(
migraphx_dynamic_dimension_t
*
dynamic_dimension
,
size_t
min
,
size_t
max
,
migraphx_optimals_t
optimals
);
migraphx_status
migraphx_dynamic_dimension_is_fixed
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_is_fixed
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimension_equal
(
bool
*
out
,
const_migraphx_dynamic_dimension_t
dynamic_dimension
,
const_migraphx_dynamic_dimension_t
x
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_destroy
(
migraphx_dynamic_dimensions_t
dynamic_dimensions
);
migraphx_status
migraphx_dynamic_dimensions_assign_to
(
migraphx_dynamic_dimensions_t
output
,
const_migraphx_dynamic_dimensions_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_assign_to
(
migraphx_dynamic_dimensions_t
output
,
const_migraphx_dynamic_dimensions_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_create
(
migraphx_dynamic_dimensions_t
*
dynamic_dimensions
,
const_migraphx_dynamic_dimension_t
*
ptr
,
const
const_migraphx_dynamic_dimension_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_dynamic_dimensions_size
(
size_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_size
(
size_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
);
migraphx_status
migraphx_dynamic_dimensions_get
(
const_migraphx_dynamic_dimension_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
,
size_t
idx
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_dynamic_dimensions_get
(
const_migraphx_dynamic_dimension_t
*
out
,
migraphx_dynamic_dimensions_t
dynamic_dimensions
,
size_t
idx
);
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
);
migraphx_status
migraphx_shape_create_with_strides
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
,
size_t
*
strides
,
size_t
strides_size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_with_strides
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
lengths_size
,
size_t
*
strides
,
size_t
strides_size
);
migraphx_status
migraphx_shape_create_scalar
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_scalar
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
);
migraphx_status
migraphx_shape_create_dynamic
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
migraphx_dynamic_dimensions_t
dims
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_create_dynamic
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
migraphx_dynamic_dimensions_t
dims
);
migraphx_status
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_lengths
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_strides
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_strides
(
const
size_t
**
out
,
size_t
*
out_size
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_dyn_dims
(
migraphx_dynamic_dimensions_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_dyn_dims
(
migraphx_dynamic_dimensions_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_type
(
migraphx_shape_datatype_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_elements
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_bytes
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_ndim
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_ndim
(
size_t
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_equal
(
bool
*
out
,
const_migraphx_shape_t
shape
,
const_migraphx_shape_t
x
);
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_standard
(
bool
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_dynamic
(
bool
*
out
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_dynamic
(
bool
*
out
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shape_index
(
size_t
*
out
,
const_migraphx_shape_t
shape
,
size_t
i
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
const_migraphx_argument_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
const_migraphx_argument_t
input
);
migraphx_status
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
migraphx_status
migraphx_argument_create_empty
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_create_empty
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
const_migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_shape
(
const_migraphx_shape_t
*
out
,
const_migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_buffer
(
char
**
out
,
const_migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_buffer
(
char
**
out
,
const_migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_equal
(
bool
*
out
,
const_migraphx_argument_t
argument
,
const_migraphx_argument_t
x
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_equal
(
bool
*
out
,
const_migraphx_argument_t
argument
,
const_migraphx_argument_t
x
);
migraphx_status
migraphx_argument_generate
(
migraphx_argument_t
*
out
,
const_migraphx_shape_t
s
,
size_t
seed
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_argument_generate
(
migraphx_argument_t
*
out
,
const_migraphx_shape_t
s
,
size_t
seed
);
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
migraphx_status
migraphx_program_parameter_shapes_destroy
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_destroy
(
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_t
output
,
const_migraphx_program_parameter_shapes_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_t
output
,
const_migraphx_program_parameter_shapes_t
input
);
migraphx_status
migraphx_program_parameter_shapes_size
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_size
(
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
,
const
char
*
name
);
migraphx_status
migraphx_program_parameter_shapes_names
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameter_shapes_names
(
const
char
**
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_status
migraphx_program_parameters_assign_to
(
migraphx_program_parameters_t
output
,
const_migraphx_program_parameters_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_assign_to
(
migraphx_program_parameters_t
output
,
const_migraphx_program_parameters_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
migraphx_status
migraphx_program_parameters_add
(
migraphx_program_parameters_t
program_parameters
,
const
char
*
name
,
const_migraphx_argument_t
argument
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_parameters_add
(
migraphx_program_parameters_t
program_parameters
,
const
char
*
name
,
const_migraphx_argument_t
argument
);
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
const_migraphx_arguments_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
const_migraphx_arguments_t
input
);
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_get
(
const_migraphx_argument_t
*
out
,
migraphx_arguments_t
arguments
,
size_t
idx
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_arguments_get
(
const_migraphx_argument_t
*
out
,
migraphx_arguments_t
arguments
,
size_t
idx
);
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
s
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_module_add_allocation
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
s
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
);
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
migraphx_status
migraphx_program_get_parameter_shapes
(
migraphx_program_parameter_shapes_t
*
out
,
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_parameter_shapes
(
migraphx_program_parameter_shapes_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_get_output_shapes
(
migraphx_shapes_t
*
out
,
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_get_output_shapes
(
migraphx_shapes_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_print
(
const_migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_print
(
const_migraphx_program_t
program
);
migraphx_status
migraphx_program_sort
(
migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_sort
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_run
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_run
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
);
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
,
void
*
s
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_run_async
(
migraphx_arguments_t
*
out
,
migraphx_program_t
program
,
migraphx_program_parameters_t
params
,
void
*
s
,
const
char
*
name
);
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
const_migraphx_operation_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
const_migraphx_operation_t
input
);
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
const
char
*
name
,
const
char
*
attributes
,
...);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
const
char
*
name
,
const
char
*
attributes
,
...);
migraphx_status
migraphx_operation_name
(
char
*
out
,
size_t
out_size
,
migraphx_operation_t
operation
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_operation_name
(
char
*
out
,
size_t
out_size
,
migraphx_operation_t
operation
);
migraphx_status
migraphx_load
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_file_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_load
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_file_options_t
options
);
migraphx_status
migraphx_save
(
migraphx_program_t
p
,
const
char
*
name
,
migraphx_file_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_save
(
migraphx_program_t
p
,
const
char
*
name
,
migraphx_file_options_t
options
);
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
migraphx_status
migraphx_onnx_options_assign_to
(
migraphx_onnx_options_t
output
,
const_migraphx_onnx_options_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_assign_to
(
migraphx_onnx_options_t
output
,
const_migraphx_onnx_options_t
input
);
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
migraphx_status
migraphx_onnx_options_set_dyn_input_parameter_shape
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_dyn_input_parameter_shape
(
migraphx_onnx_options_t
onnx_options
,
const
char
*
name
,
migraphx_dynamic_dimensions_t
dims
);
migraphx_status
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
size_t
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_dim_value
(
migraphx_onnx_options_t
onnx_options
,
size_t
value
);
migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value
(
migraphx_onnx_options_t
onnx_options
,
const_migraphx_dynamic_dimension_t
dd
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value
(
migraphx_onnx_options_t
onnx_options
,
const_migraphx_dynamic_dimension_t
dd
);
migraphx_status
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_t
onnx_options
,
int64_t
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_onnx_options_set_default_loop_iterations
(
migraphx_onnx_options_t
onnx_options
,
int64_t
value
);
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
migraphx_status
migraphx_file_options_assign_to
(
migraphx_file_options_t
output
,
const_migraphx_file_options_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_assign_to
(
migraphx_file_options_t
output
,
const_migraphx_file_options_t
input
);
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
const
char
*
format
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
const
char
*
format
);
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
migraphx_status
migraphx_compile_options_assign_to
(
migraphx_compile_options_t
output
,
const_migraphx_compile_options_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_assign_to
(
migraphx_compile_options_t
output
,
const_migraphx_compile_options_t
input
);
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_set_offload_copy
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_status
migraphx_compile_options_set_fast_math
(
migraphx_compile_options_t
compile_options
,
bool
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_set_fast_math
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_status
migraphx_compile_options_set_exhaustive_tune_flag
(
migraphx_compile_options_t
compile_options
,
bool
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_compile_options_set_exhaustive_tune_flag
(
migraphx_compile_options_t
compile_options
,
bool
value
);
migraphx_status
migraphx_parse_onnx
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_onnx_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_onnx
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_onnx_options_t
options
);
migraphx_status
migraphx_parse_onnx_buffer
(
migraphx_program_t
*
out
,
const
void
*
data
,
size_t
size
,
migraphx_onnx_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_onnx_buffer
(
migraphx_program_t
*
out
,
const
void
*
data
,
size_t
size
,
migraphx_onnx_options_t
options
);
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
const_migraphx_tf_options_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
const_migraphx_tf_options_t
input
);
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
migraphx_status
migraphx_tf_options_set_input_parameter_shape
(
migraphx_tf_options_t
tf_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_input_parameter_shape
(
migraphx_tf_options_t
tf_options
,
const
char
*
name
,
size_t
*
dims
,
size_t
dims_size
);
migraphx_status
migraphx_tf_options_set_default_dim_value
(
migraphx_tf_options_t
tf_options
,
size_t
value
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_default_dim_value
(
migraphx_tf_options_t
tf_options
,
size_t
value
);
migraphx_status
migraphx_tf_options_set_output_names
(
migraphx_tf_options_t
tf_options
,
const
char
**
names
,
size_t
names_size
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_tf_options_set_output_names
(
migraphx_tf_options_t
tf_options
,
const
char
**
names
,
size_t
names_size
);
migraphx_status
migraphx_parse_tf
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_tf_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_parse_tf
(
migraphx_program_t
*
out
,
const
char
*
name
,
migraphx_tf_options_t
options
);
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_assign_to
(
migraphx_quantize_op_names_t
output
,
const_migraphx_quantize_op_names_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_assign_to
(
migraphx_quantize_op_names_t
output
,
const_migraphx_quantize_op_names_t
input
);
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
const
char
*
name
);
migraphx_status
migraphx_quantize_fp16_with_op_names
(
migraphx_program_t
prog
,
migraphx_quantize_op_names_t
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_fp16_with_op_names
(
migraphx_program_t
prog
,
migraphx_quantize_op_names_t
name
);
migraphx_status
migraphx_quantize_fp16
(
migraphx_program_t
prog
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_fp16
(
migraphx_program_t
prog
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_status
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_t
output
,
const_migraphx_quantize_int8_options_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_t
output
,
const_migraphx_quantize_int8_options_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
migraphx_status
migraphx_quantize_int8_options_add_op_name
(
migraphx_quantize_int8_options_t
quantize_int8_options
,
const
char
*
name
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_add_op_name
(
migraphx_quantize_int8_options_t
quantize_int8_options
,
const
char
*
name
);
migraphx_status
migraphx_quantize_int8_options_add_calibration_data
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8_options_add_calibration_data
(
migraphx_quantize_int8_options_t
quantize_int8_options
,
migraphx_program_parameters_t
data
);
migraphx_status
migraphx_quantize_int8
(
migraphx_program_t
prog
,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_quantize_int8
(
migraphx_program_t
prog
,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_context_get_queue
(
void
**
out
,
migraphx_context_t
context
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_status
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_t
output
,
const_migraphx_experimental_custom_op_t
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_t
output
,
const_migraphx_experimental_custom_op_t
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_create
(
migraphx_experimental_custom_op_t
*
experimental_custom_op
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
...
...
@@ -597,21 +630,20 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
const
char
*
obj_typename
,
const
char
*
name
);
migraphx_status
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute
input
);
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_compute
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute
input
);
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_output_alias
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_output_alias
input
);
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_set_runs_on_offload_target
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_runs_on_offload_target
input
);
migraphx_status
MIGRAPHX_C_EXPORT
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
#ifdef __cplusplus
...
...
src/api/migraphx.py
View file @
1530ec24
...
...
@@ -79,7 +79,8 @@ def dynamic_dimension(h):
def
dynamic_dimensions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const const_migraphx_dynamic_dimension_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>'
)
h
.
method
(
'size'
,
returns
=
'size_t'
)
h
.
method
(
'get'
,
...
...
@@ -215,7 +216,7 @@ def instruction(h):
def
instructions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
api
.
params
(
ptr
=
'const
const
_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
...
...
src/common_dims.cpp
0 → 100644
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
dim
;
});
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
{
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
static
bool
compute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
if
(
naxes
==
0
)
return
false
;
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
false
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
==
1
?
naxes
:
naxes
+
1
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
true
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
>
0
);
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
state1
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/driver/CMakeLists.txt
View file @
1530ec24
...
...
@@ -32,18 +32,23 @@ add_executable(driver
marker_roctx.cpp
)
set_target_properties
(
driver PROPERTIES OUTPUT_NAME migraphx-driver
)
# Copy driver for backwards compatibility
add_custom_command
(
TARGET driver
if
(
NOT WIN32
)
# Copy driver for backwards compatibility (Linux only)
add_custom_command
(
TARGET driver
POST_BUILD COMMAND
${
CMAKE_COMMAND
}
-E copy
$<TARGET_FILE:driver>
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
BYPRODUCTS
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
set_directory_properties
(
PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
)
set_directory_properties
(
PROPERTIES ADDITIONAL_CLEAN_FILES
${
CMAKE_RUNTIME_OUTPUT_DIRECTORY
}
/driver
)
endif
()
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
rocm_install_targets
(
TARGETS driver
...
...
src/driver/argument_parser.hpp
View file @
1530ec24
...
...
@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
{
return
validate
([](
auto
&
,
auto
&
,
auto
&
params
)
{
return
validate
([](
auto
&
,
auto
&
,
const
auto
&
params
)
{
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
throw
std
::
runtime_error
(
"Path does not exist: "
+
params
.
back
());
});
}
MIGRAPHX_DRIVER_STATIC
auto
matches
(
const
std
::
unordered_set
<
std
::
string
>&
names
)
{
return
validate
([
=
](
auto
&
,
auto
&
,
const
auto
&
params
)
{
auto
invalid_param
=
std
::
find_if
(
params
.
begin
(),
params
.
end
(),
[
&
](
const
auto
&
p
)
{
return
names
.
count
(
p
)
==
0
;
});
if
(
invalid_param
!=
params
.
end
())
throw
std
::
runtime_error
(
"Invalid argument: "
+
*
invalid_param
+
". Valid arguments are {"
+
to_string_range
(
names
)
+
"}"
);
});
}
...
...
@@ -570,8 +581,7 @@ struct argument_parser
continue
;
if
(
flag
[
0
]
!=
'-'
)
continue
;
auto
d
=
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
std
::
ptrdiff_t
d
=
levenshtein_distance
(
flag
,
input
);
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
...
...
src/driver/main.cpp
View file @
1530ec24
...
...
@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
...
...
@@ -81,6 +82,7 @@ struct loader
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
matches
({
"resnet50"
,
"inceptionv3"
,
"alexnet"
}),
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
...
...
@@ -241,6 +243,20 @@ struct loader
return
options
;
}
static
std
::
string
get_file_type
(
const
std
::
string
&
file
)
{
if
(
ends_with
(
file
,
".onnx"
))
return
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
return
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
return
"json"
;
else
if
(
ends_with
(
file
,
".py"
))
return
"py"
;
else
return
"migraphx"
;
}
program
load
()
{
program
p
;
...
...
@@ -248,14 +264,7 @@ struct loader
{
if
(
file_type
.
empty
())
{
if
(
ends_with
(
file
,
".onnx"
))
file_type
=
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
file_type
=
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
file_type
=
"json"
;
else
file_type
=
"migraphx"
;
file_type
=
get_file_type
(
file
);
}
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
if
(
file_type
==
"onnx"
)
...
...
@@ -272,6 +281,10 @@ struct loader
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
}
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
else
if
(
file_type
==
"migraphx"
)
{
p
=
migraphx
::
load
(
file
);
...
...
@@ -462,13 +475,15 @@ struct compiler
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
std
::
cout
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
}
else
if
(
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
...
@@ -757,7 +772,7 @@ struct main_command
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
);
}
else
{
...
...
@@ -789,6 +804,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/dynamic_loader.cpp
View file @
1530ec24
...
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
LAZY
),
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
GLOBAL
|
RTLD_NOW
),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
temp
(
std
::
move
(
t
))
{
...
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return
p
;
}
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
try
{
return
dynamic_loader
{
p
};
}
catch
(
const
std
::
exception
&
)
{
return
nullopt
;
}
}
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
{
}
...
...
src/eliminate_contiguous.cpp
View file @
1530ec24
...
...
@@ -35,6 +35,8 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
)
static
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
...
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
output
->
module_inputs
()
))
{
return
false
;
}
}
}
catch
(
const
std
::
exception
&
e
)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
return
false
;
}
catch
(...)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Unknown exception"
<<
std
::
endl
;
}
return
false
;
}
...
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{
if
(
arg
->
name
()
!=
op_name
)
continue
;
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"eliminate_contiguous: "
;
m
.
debug_print
(
ins
);
}
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
...
src/fuse_pointwise.cpp
View file @
1530ec24
...
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
...
@@ -41,7 +44,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
if
(
s
.
elements
()
!=
1
&&
not
(
s
.
scalar
()))
if
(
s
.
elements
()
!=
1
and
not
(
s
.
scalar
()))
return
{};
if
(
not
ins
->
can_eval
())
return
{};
...
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
return
changed
;
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
...
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/fuse_reduce.cpp
View file @
1530ec24
...
...
@@ -52,7 +52,7 @@ struct fused_reduce
{
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
const
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
...
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
}
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
find_inputs
(
const_
module_ref
sm
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
...
...
src/include/migraphx/algorithm.hpp
View file @
1530ec24
...
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace
migraphx
{
...
...
@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
inline
size_t
levenshtein_distance
(
const
std
::
string
&
s1
,
const
std
::
string
&
s2
)
{
const
size_t
l1
=
s1
.
length
();
const
size_t
l2
=
s2
.
length
();
if
(
l1
<
l2
)
levenshtein_distance
(
s2
,
s1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
size_t
prev_cost
=
d
[
0
];
d
[
0
]
=
i
;
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
{
if
(
s1
[
i
-
1
]
==
s2
[
j
-
1
])
{
d
[
j
]
=
prev_cost
;
}
else
{
size_t
cost_insert_or_delete
=
std
::
min
(
d
[
j
-
1
],
d
[
j
]);
size_t
cost_substitute
=
prev_cost
;
prev_cost
=
d
[
j
];
d
[
j
]
=
std
::
min
(
cost_substitute
,
cost_insert_or_delete
)
+
1
;
}
}
}
return
d
[
l2
];
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/allocation_model.hpp
View file @
1530ec24
...
...
@@ -96,7 +96,7 @@ struct allocation_model
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/builtin.hpp
View file @
1530ec24
...
...
@@ -90,7 +90,17 @@ struct param
struct
returns
{
std
::
string
name
()
const
{
return
"@return"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
arg
)
const
{
if
(
arg
.
empty
())
return
{};
else
if
(
arg
.
size
()
==
1
)
return
arg
[
0
];
else
return
arg
;
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPHX_THROW
(
"builtin"
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -34,29 +34,51 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
{
const
shape
*
begin
;
const
shape
*
end
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
...
...
@@ -81,8 +103,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
}
...
...
@@ -131,11 +151,9 @@ struct check_shapes
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
...
...
@@ -148,11 +166,9 @@ struct check_shapes
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -166,11 +182,9 @@ struct check_shapes
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -220,6 +234,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
* Check all shapes are standard.
*/
...
...
@@ -230,6 +254,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
* Check all shapes are standard or scalar.
*/
...
...
@@ -330,8 +364,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
...
...
@@ -341,8 +373,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
...
...
@@ -351,17 +381,13 @@ struct check_shapes
{
if
(
begin
==
end
)
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
return
end
-
i
;
return
begin
+
i
;
...
...
@@ -394,6 +420,11 @@ struct check_shapes
}
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/concat_opt.hpp
View file @
1530ec24
...
...
@@ -88,7 +88,7 @@ struct concat_optimization
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/context.hpp
View file @
1530ec24
...
...
@@ -118,7 +118,7 @@ struct context
{
using
std
::
swap
;
auto
*
derived
=
this
->
any_cast
<
PrivateDetailTypeErasedT
>
();
if
(
derived
and
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
derived
and
private_detail_te_handle_mem_var
.
u
se_count
()
==
1
)
{
*
derived
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
}
...
...
@@ -373,7 +373,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
not
private_detail_te_handle_mem_var
.
u
nique
()
)
if
(
private_detail_te_handle_mem_var
.
u
se_count
()
>
1
)
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/convolution.hpp
View file @
1530ec24
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape
win_shape
{
output_shape
.
type
(),
win_size
};
double
acc
=
0.0
;
shape_for_each
(
win_shape
,
[
&
](
auto
idx_win
)
{
shape_for_each
(
win_shape
,
[
&
](
const
auto
&
idx_win
)
{
auto
k
=
idx_win
[
0
];
const
auto
in_ch
=
group_id
*
wei_c
+
k
;
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
...
...
Prev
1
2
3
4
5
6
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment