Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e2eb6036
"src/vscode:/vscode.git/clone" did not exist on "25bac567de2f7809b4cac173a0e3c21696e07bc2"
Commit
e2eb6036
authored
Apr 13, 2022
by
Paul
Browse files
Merge
parents
298c93d5
1e0bbd78
Changes
267
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1187 additions
and
81 deletions
+1187
-81
examples/nlp/python_bert_squad/run_onnx_squad.py
examples/nlp/python_bert_squad/run_onnx_squad.py
+1
-0
examples/vision/python_yolov4/yolov4_inference.ipynb
examples/vision/python_yolov4/yolov4_inference.ipynb
+7
-7
hip-clang.docker
hip-clang.docker
+1
-1
src/CMakeLists.txt
src/CMakeLists.txt
+7
-2
src/api/api.cpp
src/api/api.cpp
+540
-16
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+146
-2
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+329
-23
src/api/migraphx.py
src/api/migraphx.py
+60
-0
src/argument.cpp
src/argument.cpp
+5
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+27
-0
src/compile_src.cpp
src/compile_src.cpp
+8
-1
src/cpp_generator.cpp
src/cpp_generator.cpp
+9
-1
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+3
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+14
-14
src/driver/main.cpp
src/driver/main.cpp
+8
-0
src/driver/perf.cpp
src/driver/perf.cpp
+1
-1
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2
-2
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+6
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+5
-4
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+8
-2
No files found.
examples/nlp/python_bert_squad/run_onnx_squad.py
View file @
e2eb6036
# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors.
# Copyright 2018 The Google AI Language Team Authors.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
examples/vision/python_yolov4/yolov4_inference.ipynb
View file @
e2eb6036
...
@@ -50,10 +50,10 @@
...
@@ -50,10 +50,10 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"if not os.path.exists(\"yolov4_fp16.m
sgpack
\"):\n",
"if not os.path.exists(\"yolov4_fp16.m
xr
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.m
sgpack
\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.m
xr
\n",
"if not os.path.exists(\"yolov4.m
sgpack
\"):\n",
"if not os.path.exists(\"yolov4.m
xr
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.m
sgpack
"
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.m
xr
"
]
]
},
},
{
{
...
@@ -115,8 +115,8 @@
...
@@ -115,8 +115,8 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"# Load serialized model (either single- or half-precision)\n",
"# Load serialized model (either single- or half-precision)\n",
"model = migraphx.load(\"yolov4.m
sgpack
\", format=\"msgpack\")\n",
"model = migraphx.load(\"yolov4.m
xr
\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.m
sgpack
\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.m
xr
\", format=\"msgpack\")\n",
"\n",
"\n",
"# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
"# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
"input_name = next(iter(model.get_parameter_shapes()))\n",
"input_name = next(iter(model.get_parameter_shapes()))\n",
...
@@ -192,4 +192,4 @@
...
@@ -192,4 +192,4 @@
},
},
"nbformat": 4,
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 2
}
}
\ No newline at end of file
hip-clang.docker
View file @
e2eb6036
...
@@ -12,7 +12,7 @@ RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5
...
@@ -12,7 +12,7 @@ RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
apt-utils
\
apt-utils
\
build-essential
\
build-essential
\
clang-format-
5.
0
\
clang-format-
1
0
\
cmake
\
cmake
\
curl
\
curl
\
doxygen
\
doxygen
\
...
...
src/CMakeLists.txt
View file @
e2eb6036
...
@@ -38,6 +38,7 @@ add_library(migraphx
...
@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp
msgpack.cpp
normalize_attributes.cpp
normalize_attributes.cpp
normalize_ops.cpp
normalize_ops.cpp
op_enums.cpp
operation.cpp
operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
opt/memory_coloring_impl.cpp
...
@@ -114,6 +115,7 @@ register_migraphx_ops(
...
@@ -114,6 +115,7 @@ register_migraphx_ops(
identity
identity
if_op
if_op
im2col
im2col
isnan
leaky_relu
leaky_relu
less
less
load
load
...
@@ -161,6 +163,9 @@ register_migraphx_ops(
...
@@ -161,6 +163,9 @@ register_migraphx_ops(
rsqrt
rsqrt
scalar
scalar
scatter
scatter
scatternd_none
scatternd_add
scatternd_mul
sigmoid
sigmoid
sign
sign
sinh
sinh
...
@@ -211,7 +216,6 @@ target_link_libraries(migraphx PRIVATE msgpackc-cxx)
...
@@ -211,7 +216,6 @@ target_link_libraries(migraphx PRIVATE msgpackc-cxx)
target_link_libraries
(
migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>
)
target_link_libraries
(
migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>
)
add_library
(
migraphx_all_targets INTERFACE
)
add_library
(
migraphx_all_targets INTERFACE
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
set
(
PACKAGE_DEPENDS
)
set
(
PACKAGE_DEPENDS
)
...
@@ -222,6 +226,7 @@ add_subdirectory(tf)
...
@@ -222,6 +226,7 @@ add_subdirectory(tf)
add_subdirectory
(
py
)
add_subdirectory
(
py
)
add_subdirectory
(
targets/ref
)
add_subdirectory
(
targets/ref
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
if
(
MIGRAPHX_ENABLE_CPU
)
if
(
MIGRAPHX_ENABLE_CPU
)
add_subdirectory
(
targets/cpu
)
add_subdirectory
(
targets/cpu
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_cpu
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_cpu
)
...
@@ -239,7 +244,7 @@ if(HAVE_HALF_EXPR)
...
@@ -239,7 +244,7 @@ if(HAVE_HALF_EXPR)
endif
()
endif
()
rocm_export_targets
(
rocm_export_targets
(
TARGETS migraphx::migraphx
migraphx_all_targets
TARGETS migraphx::migraphx
_c
NAMESPACE migraphx::
NAMESPACE migraphx::
DEPENDS
DEPENDS
Threads
Threads
...
...
src/api/api.cpp
View file @
e2eb6036
This diff is collapsed.
Click to expand it.
src/api/include/migraphx/migraphx.h
View file @
e2eb6036
...
@@ -25,7 +25,8 @@ extern "C" {
...
@@ -25,7 +25,8 @@ extern "C" {
#endif
#endif
// return code, more to be added later
// return code, more to be added later
typedef
enum
{
typedef
enum
{
migraphx_status_success
=
0
,
migraphx_status_success
=
0
,
migraphx_status_bad_param
=
1
,
migraphx_status_bad_param
=
1
,
migraphx_status_unknown_target
=
3
,
migraphx_status_unknown_target
=
3
,
...
@@ -35,7 +36,8 @@ typedef enum {
...
@@ -35,7 +36,8 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
/// An enum to represent the different data type inputs
typedef
enum
{
typedef
enum
{
migraphx_shape_tuple_type
,
migraphx_shape_tuple_type
,
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
)
}
migraphx_shape_datatype_t
;
}
migraphx_shape_datatype_t
;
...
@@ -62,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
...
@@ -62,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef
struct
migraphx_shapes
*
migraphx_shapes_t
;
typedef
struct
migraphx_shapes
*
migraphx_shapes_t
;
typedef
const
struct
migraphx_shapes
*
const_migraphx_shapes_t
;
typedef
const
struct
migraphx_shapes
*
const_migraphx_shapes_t
;
typedef
struct
migraphx_instruction
*
migraphx_instruction_t
;
typedef
const
struct
migraphx_instruction
*
const_migraphx_instruction_t
;
typedef
struct
migraphx_instructions
*
migraphx_instructions_t
;
typedef
const
struct
migraphx_instructions
*
const_migraphx_instructions_t
;
typedef
struct
migraphx_modules
*
migraphx_modules_t
;
typedef
const
struct
migraphx_modules
*
const_migraphx_modules_t
;
typedef
struct
migraphx_module
*
migraphx_module_t
;
typedef
struct
migraphx_module
*
migraphx_module_t
;
typedef
const
struct
migraphx_module
*
const_migraphx_module_t
;
typedef
const
struct
migraphx_module
*
const_migraphx_module_t
;
...
@@ -89,8 +100,24 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
...
@@ -89,8 +100,24 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef
struct
migraphx_quantize_int8_options
*
migraphx_quantize_int8_options_t
;
typedef
struct
migraphx_quantize_int8_options
*
migraphx_quantize_int8_options_t
;
typedef
const
struct
migraphx_quantize_int8_options
*
const_migraphx_quantize_int8_options_t
;
typedef
const
struct
migraphx_quantize_int8_options
*
const_migraphx_quantize_int8_options_t
;
typedef
struct
migraphx_context
*
migraphx_context_t
;
typedef
const
struct
migraphx_context
*
const_migraphx_context_t
;
typedef
struct
migraphx_experimental_custom_op
*
migraphx_experimental_custom_op_t
;
typedef
const
struct
migraphx_experimental_custom_op
*
const_migraphx_experimental_custom_op_t
;
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
void
*
obj
,
migraphx_shapes_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
size_t
*
lengths
,
...
@@ -121,6 +148,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
...
@@ -121,6 +148,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
const_migraphx_argument_t
input
);
migraphx_status
migraphx_status
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
...
@@ -137,11 +167,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
...
@@ -137,11 +167,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
migraphx_status
migraphx_program_parameter_shapes_destroy
(
migraphx_status
migraphx_program_parameter_shapes_destroy
(
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_t
output
,
const_migraphx_program_parameter_shapes_t
input
);
migraphx_status
migraphx_program_parameter_shapes_size
(
migraphx_status
migraphx_program_parameter_shapes_size
(
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
...
@@ -156,6 +192,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
...
@@ -156,6 +192,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status
migraphx_status
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_status
migraphx_program_parameters_assign_to
(
migraphx_program_parameters_t
output
,
const_migraphx_program_parameters_t
input
);
migraphx_status
migraphx_status
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
...
@@ -165,6 +204,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
...
@@ -165,6 +204,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
const_migraphx_arguments_t
input
);
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_status
...
@@ -172,18 +214,73 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
...
@@ -172,18 +214,73 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
migraphx_program_t
program
);
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
);
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
migraphx_compile_options_t
options
);
...
@@ -205,8 +302,14 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
...
@@ -205,8 +302,14 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
const_migraphx_operation_t
input
);
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
const
char
*
name
,
const
char
*
name
,
const
char
*
attributes
,
const
char
*
attributes
,
...
@@ -222,6 +325,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
...
@@ -222,6 +325,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
migraphx_status
migraphx_onnx_options_assign_to
(
migraphx_onnx_options_t
output
,
const_migraphx_onnx_options_t
input
);
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
...
@@ -236,6 +342,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
...
@@ -236,6 +342,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
migraphx_status
migraphx_file_options_assign_to
(
migraphx_file_options_t
output
,
const_migraphx_file_options_t
input
);
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
...
@@ -243,6 +352,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
...
@@ -243,6 +352,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
migraphx_status
migraphx_compile_options_assign_to
(
migraphx_compile_options_t
output
,
const_migraphx_compile_options_t
input
);
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
migraphx_status
migraphx_status
...
@@ -261,6 +373,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
...
@@ -261,6 +373,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
const_migraphx_tf_options_t
input
);
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
...
@@ -282,6 +397,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
...
@@ -282,6 +397,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_assign_to
(
migraphx_quantize_op_names_t
output
,
const_migraphx_quantize_op_names_t
input
);
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
...
@@ -295,6 +413,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
...
@@ -295,6 +413,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status
migraphx_status
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_status
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_t
output
,
const_migraphx_quantize_int8_options_t
input
);
migraphx_status
migraphx_status
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
...
@@ -309,6 +431,28 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
...
@@ -309,6 +431,28 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t
target
,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
migraphx_quantize_int8_options_t
options
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_status
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_t
output
,
const_migraphx_experimental_custom_op_t
input
);
migraphx_status
migraphx_experimental_custom_op_create
(
migraphx_experimental_custom_op_t
*
experimental_custom_op
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
name
);
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
#ifdef __cplusplus
#ifdef __cplusplus
}
}
#endif
#endif
...
...
src/api/include/migraphx/migraphx.hpp
View file @
e2eb6036
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <memory>
#include <memory>
#include <exception>
#include <exception>
...
@@ -13,6 +15,16 @@ namespace migraphx {
...
@@ -13,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
inline
namespace
api
{
// NOLINT
#endif
#endif
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
template
<
class
T
,
class
F
,
class
...
Ts
>
template
<
class
T
,
class
F
,
class
...
Ts
>
T
*
make
(
F
f
,
Ts
&&
...
xs
)
T
*
make
(
F
f
,
Ts
&&
...
xs
)
{
{
...
@@ -152,6 +164,35 @@ struct array_base
...
@@ -152,6 +164,35 @@ struct array_base
}
}
};
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template
<
class
T
>
struct
holder
{
// Friend injection
friend
auto
migraphx_adl_handle_lookup
(
holder
<
T
>
);
// Function left unimplemented since its only used in non-evaluated
// context
T
get
()
const
;
};
template
<
class
C
,
class
T
>
struct
handle_lookup
{
friend
auto
migraphx_adl_handle_lookup
(
holder
<
T
>
)
{
return
holder
<
C
>
{};
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template
<
class
T
>
using
as_handle
=
decltype
(
migraphx_adl_handle_lookup
(
holder
<
std
::
remove_cv_t
<
std
::
remove_pointer_t
<
T
>>>
{}).
get
());
struct
own
struct
own
{
{
};
};
...
@@ -159,8 +200,8 @@ struct borrow
...
@@ -159,8 +200,8 @@ struct borrow
{
{
};
};
template
<
class
T
,
class
D
,
D
Deleter
>
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
{
handle_base
()
:
m_handle
(
nullptr
)
{}
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
...
@@ -190,17 +231,158 @@ struct handle_base
...
@@ -190,17 +231,158 @@ struct handle_base
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
}
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
{
Assigner
(
x
,
this
->
get_handle_ptr
());
}
protected:
protected:
std
::
shared_ptr
<
T
>
m_handle
;
std
::
shared_ptr
<
T
>
m_handle
;
};
};
template
<
class
Base
>
struct
interface_base
:
Base
{
interface_base
()
:
Base
()
{}
protected:
template
<
class
F
>
static
migraphx_status
try_
(
F
f
)
// NOLINT
{
try
{
f
();
return
migraphx_status_success
;
}
catch
(...)
{
return
migraphx_status_unknown_error
;
}
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
make_interface
(
F
f
,
T
&
obj
,
Ts
&&
...
xs
)
{
auto
copy
=
[](
void
**
out
,
void
*
input
)
{
return
try_
([
&
]
{
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
*
y
=
new
T
(
*
x
);
// NOLINT
});
};
auto
del
=
[](
void
*
input
)
{
return
try_
([
&
]
{
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
delete
x
;
// NOLINT
});
};
this
->
make_handle
(
f
,
&
obj
,
copy
,
del
,
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
,
class
Setter
,
class
F
>
void
set_fp
(
Setter
setter
,
F
pf
)
{
static
F
f
=
pf
;
(
void
)
f
;
// avoid warning on gcc
call
(
setter
,
this
->
get_handle_ptr
(),
[](
auto
...
xs
)
->
migraphx_status
{
return
try_
([
&
]
{
call_cast_arg
<
T
>
(
rank
<
1
>
{},
f
,
xs
...);
});
});
}
template
<
class
T
,
class
Setter
,
class
F
>
void
set_auto_fp
(
Setter
setter
,
F
f
)
{
return
set_fp
<
T
>
(
setter
,
[
=
](
T
&
obj
,
auto
out
,
auto
...
xs
)
{
auto_invoke
(
f
,
out
,
obj
,
auto_convert_param
(
rank
<
2
>
{},
xs
)...);
});
}
struct
no_out_arg
{
};
template
<
class
T
,
class
F
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
0
>
,
F
f
,
X
*
obj
,
Xs
...
xs
)
{
f
(
reinterpret_cast
<
T
*>
(
obj
),
no_out_arg
{},
xs
...);
}
template
<
class
T
,
class
F
,
class
R
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
1
>
,
F
f
,
R
result
,
X
*
obj
,
Xs
...
xs
)
{
f
(
*
reinterpret_cast
<
T
*>
(
obj
),
result
,
xs
...);
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
T
*
out
,
Ts
&&
...
xs
)
{
auto_assign
(
rank
<
2
>
{},
out
,
f
(
std
::
forward
<
Ts
>
(
xs
)...));
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
no_out_arg
,
Ts
&&
...
xs
)
{
f
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
,
class
=
std
::
enable_if_t
<
std
::
is_fundamental
<
T
>{}
or
std
::
is_enum
<
T
>
{}
>>
T
auto_convert_param
(
rank
<
0
>
,
T
x
)
{
return
x
;
}
template
<
class
T
>
auto
auto_convert_param
(
rank
<
1
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
})
{
return
as_handle
<
T
>
{
x
};
}
template
<
class
T
>
auto
auto_convert_param
(
rank
<
2
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
,
borrow
{}})
{
return
as_handle
<
T
>
{
x
,
borrow
{}};
}
template
<
class
T
,
class
U
>
void
auto_assign
(
rank
<
0
>
,
T
*
out
,
U
x
)
{
return
*
out
=
x
;
}
template
<
class
T
,
class
U
>
auto
auto_assign
(
rank
<
1
>
,
T
*
out
,
U
x
)
->
decltype
(
x
.
assign_to_handle
(
out
))
{
x
.
assign_to_handle
(
out
);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template
<
class
Base
,
class
T
>
using
require_interface
=
std
::
enable_if_t
<
std
::
is_base_of
<
Base
,
T
>
{}
and
not
std
::
is_same
<
T
,
Base
>
{}
and
std
::
is_copy_constructible
<
T
>
{}
and
std
::
is_final
<
T
>
{}
>
;
#ifdef DOXYGEN
#ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else
#else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \
handle_base<name, \
decltype(&migraphx_##name##_destroy), \
const_ migraphx_##name, \
migraphx_##name##_destroy>
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
...
@@ -485,12 +667,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -485,12 +667,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
}
std
::
string
name
()
{
std
::
array
<
char
,
1024
>
out_name
;
call
(
&
migraphx_operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
return
{
out_name
.
data
()};
}
};
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
{
std
::
array
<
const_migraphx_instruction_t
,
sizeof
...(
Ts
)
>
a
{
xs
.
get_handle_ptr
()...};
this
->
make_handle
(
&
migraphx_instructions_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
struct
module
{
{
migraphx_module_t
mm
;
migraphx_module_t
mm
;
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
,
const
migraphx
::
modules
&
module_args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
};
struct
context
{
migraphx_context_t
ctx
;
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
};
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
...
@@ -519,7 +805,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
...
@@ -519,7 +805,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed
/// A program represents the all computation graphs to be compiled and executed
struct
program
:
MIGRAPHX_HANDLE_BASE
(
program
)
struct
program
:
MIGRAPHX_HANDLE_BASE
(
program
)
{
{
program
()
{}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
...
@@ -589,27 +875,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -589,27 +875,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
};
return
module
{
p_modu
};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
context
experimental_get_context
()
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
{
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
}
}
std
::
string
name
(
)
module
create_module
(
const
std
::
string
&
name
)
{
{
std
::
array
<
char
,
1024
>
out_name
;
migraphx_module_t
p_modu
;
call
(
&
migraphx_
operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
call
(
&
migraphx_
program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
()
,
name
.
data
()
);
return
{
out_name
.
data
()
};
return
module
{
p_modu
};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
};
};
// options for migraphx file format options
// options for migraphx file format options
...
@@ -850,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
...
@@ -850,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options
.
get_handle_ptr
());
options
.
get_handle_ptr
());
}
}
struct
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
};
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
{
template
<
class
T
>
experimental_custom_op
(
T
&
obj
)
{
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
obj
.
name
().
c_str
());
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
};
template
<
class
T
,
class
=
require_interface
<
experimental_custom_op_base
,
T
>
>
void
register_experimental_custom_op
(
T
&
obj
)
{
experimental_custom_op
op
{
obj
};
op
.
register_op
();
}
#ifndef DOXYGEN
#ifndef DOXYGEN
}
// namespace api
}
// namespace api
#endif
#endif
...
...
src/api/migraphx.py
View file @
e2eb6036
...
@@ -178,14 +178,55 @@ def shapes(h):
...
@@ -178,14 +178,55 @@ def shapes(h):
returns
=
'const migraphx::shape&'
)
returns
=
'const migraphx::shape&'
)
@
api
.
handle
(
'migraphx_instruction'
,
'migraphx::instruction_ref'
)
def
instruction
(
h
):
pass
@
api
.
handle
(
'migraphx_instructions'
,
'std::vector<migraphx::instruction_ref>'
)
def
instructions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
@
api
.
handle
(
'migraphx_modules'
,
'std::vector<migraphx::module*>'
)
def
modules
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'migraphx_module_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_objptr_vector<migraphx::module*>'
)
@
auto_handle
(
ref
=
True
)
@
auto_handle
(
ref
=
True
)
def
module
(
h
):
def
module
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'std::string'
))
h
.
method
(
'print'
,
invoke
=
'migraphx::print_module($@)'
,
const
=
True
)
h
.
method
(
'print'
,
invoke
=
'migraphx::print_module($@)'
,
const
=
True
)
h
.
method
(
'add_instruction'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_instruction_with_mod_args'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
,
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_return'
,
api
.
params
(
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
@
auto_handle
()
@
auto_handle
()
def
program
(
h
):
def
program
(
h
):
h
.
constructor
(
'create'
)
h
.
method
(
'get_main_module'
,
returns
=
'migraphx::module*'
)
h
.
method
(
'get_main_module'
,
returns
=
'migraphx::module*'
)
h
.
method
(
'create_module'
,
api
.
params
(
name
=
'const char*'
),
returns
=
'migraphx::module*'
)
h
.
method
(
h
.
method
(
'compile'
,
'compile'
,
api
.
params
(
target
=
'migraphx::target'
,
api
.
params
(
target
=
'migraphx::target'
,
...
@@ -207,6 +248,10 @@ def program(h):
...
@@ -207,6 +248,10 @@ def program(h):
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'experimental_get_context'
,
invoke
=
'migraphx::get_context($@)'
,
const
=
True
,
returns
=
'migraphx::context'
)
@
auto_handle
()
@
auto_handle
()
...
@@ -353,3 +398,18 @@ api.add_function('migraphx_quantize_int8',
...
@@ -353,3 +398,18 @@ api.add_function('migraphx_quantize_int8',
target
=
'migraphx::target'
,
target
=
'migraphx::target'
,
options
=
'migraphx::quantize_int8_options'
),
options
=
'migraphx::quantize_int8_options'
),
fname
=
'migraphx::quantize_int8_wrap'
)
fname
=
'migraphx::quantize_int8_wrap'
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
'migraphx::experimental_custom_op'
)
def
experimental_custom_op
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'const char*'
))
h
.
virtual
(
'compute_shape'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'migraphx::shape'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
src/argument.cpp
View file @
e2eb6036
...
@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
...
@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const
shape
&
argument
::
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
argument
::
get_shape
()
const
{
return
this
->
m_shape
;
}
argument
argument
::
reshape
(
const
shape
&
s
)
const
{
return
{
s
,
this
->
m_data
};
}
argument
argument
::
reshape
(
const
shape
&
s
)
const
{
assert
(
s
.
element_space
()
<=
this
->
get_shape
().
element_space
());
return
{
s
,
this
->
m_data
};
}
argument
::
data_t
argument
::
data_t
::
share
()
const
argument
::
data_t
argument
::
data_t
::
share
()
const
{
{
...
...
src/auto_contiguous.cpp
View file @
e2eb6036
...
@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void
auto_contiguous
::
apply
(
module
&
p
)
const
void
auto_contiguous
::
apply
(
module
&
p
)
const
{
{
std
::
string
key
=
"require_std_shape"
;
for
(
auto
ins
:
reverse_iterator_for
(
p
))
{
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
if
((
attr
.
get
(
key
,
false
)))
{
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
if
(
in
->
name
()
==
"contiguous"
)
{
return
in
;
}
return
p
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
if
(
new_args
!=
args
)
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// for last instruction that is NOT a return
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
{
...
...
src/compile_src.cpp
View file @
e2eb6036
...
@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
...
@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
params
+=
" -o "
+
out
;
params
+=
" -o "
+
out
;
td
.
execute
(
compiler
,
params
);
if
(
not
launcher
.
empty
())
{
td
.
execute
(
launcher
,
compiler
+
" "
+
params
);
}
else
{
td
.
execute
(
compiler
,
params
);
}
auto
out_path
=
td
.
path
/
out
;
auto
out_path
=
td
.
path
/
out
;
if
(
not
fs
::
exists
(
out_path
))
if
(
not
fs
::
exists
(
out_path
))
...
...
src/cpp_generator.cpp
View file @
e2eb6036
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
std
::
stringstream
fs
{};
std
::
stringstream
fs
{};
std
::
size_t
function_count
=
0
;
std
::
size_t
function_count
=
0
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
shape
)
>
fresult
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
};
};
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fresult
(
const
std
::
function
<
std
::
string
(
shape
)
>&
f
)
{
impl
->
fresult
=
f
;
}
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
{
{
impl
->
point_op_map
[
op_name
]
=
code
;
impl
->
point_op_map
[
op_name
]
=
code
;
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
std
::
back_inserter
(
args
),
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
return
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
auto
s
=
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
if
(
impl
->
fresult
)
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
else
return
s
;
});
});
return
f
;
return
f
;
}
}
...
...
src/driver/alexnet.cpp
View file @
e2eb6036
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu19
;
migraphx
::
op
::
relu
relu19
;
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
migraphx
::
op
::
pooling
pooling20
;
migraphx
::
op
::
pooling
pooling20
;
pooling20
.
mode
=
"max"
;
pooling20
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
lengths
=
{
3
,
3
};
pooling20
.
lengths
=
{
3
,
3
};
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu24
;
migraphx
::
op
::
relu
relu24
;
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
migraphx
::
op
::
pooling
pooling25
;
migraphx
::
op
::
pooling
pooling25
;
pooling25
.
mode
=
"max"
;
pooling25
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
lengths
=
{
3
,
3
};
pooling25
.
lengths
=
{
3
,
3
};
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu37
;
migraphx
::
op
::
relu
relu37
;
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
migraphx
::
op
::
pooling
pooling38
;
migraphx
::
op
::
pooling
pooling38
;
pooling38
.
mode
=
"max"
;
pooling38
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
lengths
=
{
3
,
3
};
pooling38
.
lengths
=
{
3
,
3
};
...
...
src/driver/inceptionv3.cpp
View file @
e2eb6036
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu492
;
migraphx
::
op
::
relu
relu492
;
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
migraphx
::
op
::
pooling
pooling493
;
migraphx
::
op
::
pooling
pooling493
;
pooling493
.
mode
=
"max"
;
pooling493
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
lengths
=
{
3
,
3
};
pooling493
.
lengths
=
{
3
,
3
};
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu499
;
migraphx
::
op
::
relu
relu499
;
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
migraphx
::
op
::
pooling
pooling500
;
migraphx
::
op
::
pooling
pooling500
;
pooling500
.
mode
=
"max"
;
pooling500
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
lengths
=
{
3
,
3
};
pooling500
.
lengths
=
{
3
,
3
};
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu518
;
migraphx
::
op
::
relu
relu518
;
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
migraphx
::
op
::
pooling
pooling519
;
migraphx
::
op
::
pooling
pooling519
;
pooling519
.
mode
=
"
average
"
;
pooling519
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
lengths
=
{
3
,
3
};
pooling519
.
lengths
=
{
3
,
3
};
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu541
;
migraphx
::
op
::
relu
relu541
;
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
migraphx
::
op
::
pooling
pooling542
;
migraphx
::
op
::
pooling
pooling542
;
pooling542
.
mode
=
"
average
"
;
pooling542
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
lengths
=
{
3
,
3
};
pooling542
.
lengths
=
{
3
,
3
};
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu564
;
migraphx
::
op
::
relu
relu564
;
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
migraphx
::
op
::
pooling
pooling565
;
migraphx
::
op
::
pooling
pooling565
;
pooling565
.
mode
=
"
average
"
;
pooling565
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
lengths
=
{
3
,
3
};
pooling565
.
lengths
=
{
3
,
3
};
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu581
;
migraphx
::
op
::
relu
relu581
;
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
migraphx
::
op
::
pooling
pooling582
;
migraphx
::
op
::
pooling
pooling582
;
pooling582
.
mode
=
"max"
;
pooling582
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
lengths
=
{
3
,
3
};
pooling582
.
lengths
=
{
3
,
3
};
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu610
;
migraphx
::
op
::
relu
relu610
;
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
migraphx
::
op
::
pooling
pooling611
;
migraphx
::
op
::
pooling
pooling611
;
pooling611
.
mode
=
"
average
"
;
pooling611
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
lengths
=
{
3
,
3
};
pooling611
.
lengths
=
{
3
,
3
};
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu642
;
migraphx
::
op
::
relu
relu642
;
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
migraphx
::
op
::
pooling
pooling643
;
migraphx
::
op
::
pooling
pooling643
;
pooling643
.
mode
=
"
average
"
;
pooling643
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
lengths
=
{
3
,
3
};
pooling643
.
lengths
=
{
3
,
3
};
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu674
;
migraphx
::
op
::
relu
relu674
;
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
migraphx
::
op
::
pooling
pooling675
;
migraphx
::
op
::
pooling
pooling675
;
pooling675
.
mode
=
"
average
"
;
pooling675
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
lengths
=
{
3
,
3
};
pooling675
.
lengths
=
{
3
,
3
};
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu706
;
migraphx
::
op
::
relu
relu706
;
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
migraphx
::
op
::
pooling
pooling707
;
migraphx
::
op
::
pooling
pooling707
;
pooling707
.
mode
=
"
average
"
;
pooling707
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
lengths
=
{
3
,
3
};
pooling707
.
lengths
=
{
3
,
3
};
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu729
;
migraphx
::
op
::
relu
relu729
;
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
migraphx
::
op
::
pooling
pooling730
;
migraphx
::
op
::
pooling
pooling730
;
pooling730
.
mode
=
"max"
;
pooling730
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
lengths
=
{
3
,
3
};
pooling730
.
lengths
=
{
3
,
3
};
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757
.
axis
=
1
;
concat757
.
axis
=
1
;
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
migraphx
::
op
::
pooling
pooling758
;
migraphx
::
op
::
pooling
pooling758
;
pooling758
.
mode
=
"
average
"
;
pooling758
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
lengths
=
{
3
,
3
};
pooling758
.
lengths
=
{
3
,
3
};
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788
.
axis
=
1
;
concat788
.
axis
=
1
;
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
migraphx
::
op
::
pooling
pooling789
;
migraphx
::
op
::
pooling
pooling789
;
pooling789
.
mode
=
"
average
"
;
pooling789
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
lengths
=
{
3
,
3
};
pooling789
.
lengths
=
{
3
,
3
};
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793
.
axis
=
1
;
concat793
.
axis
=
1
;
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
migraphx
::
op
::
pooling
pooling794
;
migraphx
::
op
::
pooling
pooling794
;
pooling794
.
mode
=
"
average
"
;
pooling794
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
...
...
src/driver/main.cpp
View file @
e2eb6036
...
@@ -505,8 +505,10 @@ struct roctx : command<roctx>
...
@@ -505,8 +505,10 @@ struct roctx : command<roctx>
struct
op
:
command
<
op
>
struct
op
:
command
<
op
>
{
{
bool
show_ops
=
false
;
bool
show_ops
=
false
;
std
::
string
op_name
{};
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
ap
(
op_name
,
{},
ap
.
metavar
(
"<MIGraphX operator name>"
));
ap
(
show_ops
,
ap
(
show_ops
,
{
"--list"
,
"-l"
},
{
"--list"
,
"-l"
},
ap
.
help
(
"List all the operators of MIGraphX"
),
ap
.
help
(
"List all the operators of MIGraphX"
),
...
@@ -519,6 +521,12 @@ struct op : command<op>
...
@@ -519,6 +521,12 @@ struct op : command<op>
for
(
const
auto
&
name
:
get_operators
())
for
(
const
auto
&
name
:
get_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
std
::
cout
<<
name
<<
std
::
endl
;
}
}
else
{
auto
op
=
load_op
(
op_name
);
std
::
cout
<<
op_name
<<
": "
<<
std
::
endl
;
std
::
cout
<<
to_pretty_json_string
(
op
.
to_value
())
<<
std
::
endl
;
}
}
}
};
};
...
...
src/driver/perf.cpp
View file @
e2eb6036
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace
MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
}
// namespace migraphx
}
// namespace migraphx
src/driver/resnet50.cpp
View file @
e2eb6036
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu269
;
migraphx
::
op
::
relu
relu269
;
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
migraphx
::
op
::
pooling
pooling270
;
migraphx
::
op
::
pooling
pooling270
;
pooling270
.
mode
=
"max"
;
pooling270
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
lengths
=
{
3
,
3
};
pooling270
.
lengths
=
{
3
,
3
};
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu438
;
migraphx
::
op
::
relu
relu438
;
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
migraphx
::
op
::
pooling
pooling439
;
migraphx
::
op
::
pooling
pooling439
;
pooling439
.
mode
=
"
average
"
;
pooling439
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
lengths
=
{
7
,
7
};
pooling439
.
lengths
=
{
7
,
7
};
...
...
src/eliminate_common_subexpression.cpp
View file @
e2eb6036
...
@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
...
@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
p
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
auto
outputs
=
eq
->
outputs
();
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
p
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
...
...
src/eliminate_contiguous.cpp
View file @
e2eb6036
...
@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue
;
continue
;
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
==
op_name
)
if
(
arg
->
name
()
==
op_name
)
{
{
auto
new_args
=
args
;
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
ins
->
module_inputs
()
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
}
...
...
src/eliminate_data_type.cpp
View file @
e2eb6036
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void
eliminate_data_type
::
apply
(
module
&
m
)
const
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
"convert"
,
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
};
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
...
...
Prev
1
2
3
4
5
6
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment