Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e2eb6036
Commit
e2eb6036
authored
Apr 13, 2022
by
Paul
Browse files
Merge
parents
298c93d5
1e0bbd78
Changes
267
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1187 additions
and
81 deletions
+1187
-81
examples/nlp/python_bert_squad/run_onnx_squad.py
examples/nlp/python_bert_squad/run_onnx_squad.py
+1
-0
examples/vision/python_yolov4/yolov4_inference.ipynb
examples/vision/python_yolov4/yolov4_inference.ipynb
+7
-7
hip-clang.docker
hip-clang.docker
+1
-1
src/CMakeLists.txt
src/CMakeLists.txt
+7
-2
src/api/api.cpp
src/api/api.cpp
+540
-16
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+146
-2
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+329
-23
src/api/migraphx.py
src/api/migraphx.py
+60
-0
src/argument.cpp
src/argument.cpp
+5
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+27
-0
src/compile_src.cpp
src/compile_src.cpp
+8
-1
src/cpp_generator.cpp
src/cpp_generator.cpp
+9
-1
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+3
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+14
-14
src/driver/main.cpp
src/driver/main.cpp
+8
-0
src/driver/perf.cpp
src/driver/perf.cpp
+1
-1
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2
-2
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+6
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+5
-4
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+8
-2
No files found.
examples/nlp/python_bert_squad/run_onnx_squad.py
View file @
e2eb6036
# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
examples/vision/python_yolov4/yolov4_inference.ipynb
View file @
e2eb6036
...
...
@@ -50,10 +50,10 @@
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists(\"yolov4_fp16.m
sgpack
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.m
sgpack
\n",
"if not os.path.exists(\"yolov4.m
sgpack
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.m
sgpack
"
"if not os.path.exists(\"yolov4_fp16.m
xr
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.m
xr
\n",
"if not os.path.exists(\"yolov4.m
xr
\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.m
xr
"
]
},
{
...
...
@@ -115,8 +115,8 @@
"outputs": [],
"source": [
"# Load serialized model (either single- or half-precision)\n",
"model = migraphx.load(\"yolov4.m
sgpack
\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.m
sgpack
\", format=\"msgpack\")\n",
"model = migraphx.load(\"yolov4.m
xr
\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.m
xr
\", format=\"msgpack\")\n",
"\n",
"# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
"input_name = next(iter(model.get_parameter_shapes()))\n",
...
...
@@ -192,4 +192,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
}
hip-clang.docker
View file @
e2eb6036
...
...
@@ -12,7 +12,7 @@ RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--allow-unauthenticated
\
apt-utils
\
build-essential
\
clang-format-
5.
0
\
clang-format-
1
0
\
cmake
\
curl
\
doxygen
\
...
...
src/CMakeLists.txt
View file @
e2eb6036
...
...
@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp
normalize_attributes.cpp
normalize_ops.cpp
op_enums.cpp
operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
...
...
@@ -114,6 +115,7 @@ register_migraphx_ops(
identity
if_op
im2col
isnan
leaky_relu
less
load
...
...
@@ -161,6 +163,9 @@ register_migraphx_ops(
rsqrt
scalar
scatter
scatternd_none
scatternd_add
scatternd_mul
sigmoid
sign
sinh
...
...
@@ -211,7 +216,6 @@ target_link_libraries(migraphx PRIVATE msgpackc-cxx)
target_link_libraries
(
migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>
)
add_library
(
migraphx_all_targets INTERFACE
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
set
(
PACKAGE_DEPENDS
)
...
...
@@ -222,6 +226,7 @@ add_subdirectory(tf)
add_subdirectory
(
py
)
add_subdirectory
(
targets/ref
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
if
(
MIGRAPHX_ENABLE_CPU
)
add_subdirectory
(
targets/cpu
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_cpu
)
...
...
@@ -239,7 +244,7 @@ if(HAVE_HALF_EXPR)
endif
()
rocm_export_targets
(
TARGETS migraphx::migraphx
migraphx_all_targets
TARGETS migraphx::migraphx
_c
NAMESPACE migraphx::
DEPENDS
Threads
...
...
src/api/api.cpp
View file @
e2eb6036
This diff is collapsed.
Click to expand it.
src/api/include/migraphx/migraphx.h
View file @
e2eb6036
...
...
@@ -25,7 +25,8 @@ extern "C" {
#endif
// return code, more to be added later
typedef
enum
{
typedef
enum
{
migraphx_status_success
=
0
,
migraphx_status_bad_param
=
1
,
migraphx_status_unknown_target
=
3
,
...
...
@@ -35,7 +36,8 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef
enum
{
typedef
enum
{
migraphx_shape_tuple_type
,
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
)
}
migraphx_shape_datatype_t
;
...
...
@@ -62,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef
struct
migraphx_shapes
*
migraphx_shapes_t
;
typedef
const
struct
migraphx_shapes
*
const_migraphx_shapes_t
;
typedef
struct
migraphx_instruction
*
migraphx_instruction_t
;
typedef
const
struct
migraphx_instruction
*
const_migraphx_instruction_t
;
typedef
struct
migraphx_instructions
*
migraphx_instructions_t
;
typedef
const
struct
migraphx_instructions
*
const_migraphx_instructions_t
;
typedef
struct
migraphx_modules
*
migraphx_modules_t
;
typedef
const
struct
migraphx_modules
*
const_migraphx_modules_t
;
typedef
struct
migraphx_module
*
migraphx_module_t
;
typedef
const
struct
migraphx_module
*
const_migraphx_module_t
;
...
...
@@ -89,8 +100,24 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef
struct
migraphx_quantize_int8_options
*
migraphx_quantize_int8_options_t
;
typedef
const
struct
migraphx_quantize_int8_options
*
const_migraphx_quantize_int8_options_t
;
typedef
struct
migraphx_context
*
migraphx_context_t
;
typedef
const
struct
migraphx_context
*
const_migraphx_context_t
;
typedef
struct
migraphx_experimental_custom_op
*
migraphx_experimental_custom_op_t
;
typedef
const
struct
migraphx_experimental_custom_op
*
const_migraphx_experimental_custom_op_t
;
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_compute_shape
)(
migraphx_shape_t
out
,
void
*
obj
,
migraphx_shapes_t
inputs
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_copy
)(
void
**
out
,
void
*
input
);
typedef
migraphx_status
(
*
migraphx_experimental_custom_op_delete
)(
void
*
input
);
migraphx_status
migraphx_shape_destroy
(
migraphx_shape_t
shape
);
migraphx_status
migraphx_shape_assign_to
(
migraphx_shape_t
output
,
const_migraphx_shape_t
input
);
migraphx_status
migraphx_shape_create
(
migraphx_shape_t
*
shape
,
migraphx_shape_datatype_t
type
,
size_t
*
lengths
,
...
...
@@ -121,6 +148,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status
migraphx_argument_destroy
(
migraphx_argument_t
argument
);
migraphx_status
migraphx_argument_assign_to
(
migraphx_argument_t
output
,
const_migraphx_argument_t
input
);
migraphx_status
migraphx_argument_create
(
migraphx_argument_t
*
argument
,
const_migraphx_shape_t
shape
,
void
*
buffer
);
...
...
@@ -137,11 +167,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status
migraphx_target_destroy
(
migraphx_target_t
target
);
migraphx_status
migraphx_target_assign_to
(
migraphx_target_t
output
,
const_migraphx_target_t
input
);
migraphx_status
migraphx_target_create
(
migraphx_target_t
*
target
,
const
char
*
name
);
migraphx_status
migraphx_program_parameter_shapes_destroy
(
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
migraphx_status
migraphx_program_parameter_shapes_assign_to
(
migraphx_program_parameter_shapes_t
output
,
const_migraphx_program_parameter_shapes_t
input
);
migraphx_status
migraphx_program_parameter_shapes_size
(
size_t
*
out
,
migraphx_program_parameter_shapes_t
program_parameter_shapes
);
...
...
@@ -156,6 +192,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status
migraphx_program_parameters_destroy
(
migraphx_program_parameters_t
program_parameters
);
migraphx_status
migraphx_program_parameters_assign_to
(
migraphx_program_parameters_t
output
,
const_migraphx_program_parameters_t
input
);
migraphx_status
migraphx_program_parameters_create
(
migraphx_program_parameters_t
*
program_parameters
);
...
...
@@ -165,6 +204,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status
migraphx_arguments_destroy
(
migraphx_arguments_t
arguments
);
migraphx_status
migraphx_arguments_assign_to
(
migraphx_arguments_t
output
,
const_migraphx_arguments_t
input
);
migraphx_status
migraphx_arguments_size
(
size_t
*
out
,
migraphx_arguments_t
arguments
);
migraphx_status
...
...
@@ -172,18 +214,73 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status
migraphx_shapes_destroy
(
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_assign_to
(
migraphx_shapes_t
output
,
const_migraphx_shapes_t
input
);
migraphx_status
migraphx_shapes_size
(
size_t
*
out
,
migraphx_shapes_t
shapes
);
migraphx_status
migraphx_shapes_get
(
const_migraphx_shape_t
*
out
,
migraphx_shapes_t
shapes
,
size_t
idx
);
migraphx_status
migraphx_instruction_destroy
(
migraphx_instruction_t
instruction
);
migraphx_status
migraphx_instruction_assign_to
(
migraphx_instruction_t
output
,
const_migraphx_instruction_t
input
);
migraphx_status
migraphx_instructions_destroy
(
migraphx_instructions_t
instructions
);
migraphx_status
migraphx_instructions_assign_to
(
migraphx_instructions_t
output
,
const_migraphx_instructions_t
input
);
migraphx_status
migraphx_instructions_create
(
migraphx_instructions_t
*
instructions
,
const_migraphx_instruction_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_modules_destroy
(
migraphx_modules_t
modules
);
migraphx_status
migraphx_modules_assign_to
(
migraphx_modules_t
output
,
const_migraphx_modules_t
input
);
migraphx_status
migraphx_modules_create
(
migraphx_modules_t
*
modules
,
migraphx_module_t
*
ptr
,
size_t
size
);
migraphx_status
migraphx_module_create
(
migraphx_module_t
*
module
,
char
*
name
);
migraphx_status
migraphx_module_print
(
const_migraphx_module_t
module
);
migraphx_status
migraphx_module_add_instruction
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_module_add_instruction_with_mod_args
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_operation_t
op
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
const_migraphx_shape_t
shape
);
migraphx_status
migraphx_module_add_return
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_instructions_t
args
);
migraphx_status
migraphx_program_destroy
(
migraphx_program_t
program
);
migraphx_status
migraphx_program_assign_to
(
migraphx_program_t
output
,
const_migraphx_program_t
input
);
migraphx_status
migraphx_program_create
(
migraphx_program_t
*
program
);
migraphx_status
migraphx_program_get_main_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
);
migraphx_status
migraphx_program_create_module
(
migraphx_module_t
*
out
,
migraphx_program_t
program
,
const
char
*
name
);
migraphx_status
migraphx_program_compile
(
migraphx_program_t
program
,
migraphx_target_t
target
,
migraphx_compile_options_t
options
);
...
...
@@ -205,8 +302,14 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_program_equal
(
bool
*
out
,
const_migraphx_program_t
program
,
const_migraphx_program_t
x
);
migraphx_status
migraphx_program_experimental_get_context
(
migraphx_context_t
*
out
,
const_migraphx_program_t
program
);
migraphx_status
migraphx_operation_destroy
(
migraphx_operation_t
operation
);
migraphx_status
migraphx_operation_assign_to
(
migraphx_operation_t
output
,
const_migraphx_operation_t
input
);
migraphx_status
migraphx_operation_create
(
migraphx_operation_t
*
operation
,
const
char
*
name
,
const
char
*
attributes
,
...
...
@@ -222,6 +325,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status
migraphx_onnx_options_destroy
(
migraphx_onnx_options_t
onnx_options
);
migraphx_status
migraphx_onnx_options_assign_to
(
migraphx_onnx_options_t
output
,
const_migraphx_onnx_options_t
input
);
migraphx_status
migraphx_onnx_options_create
(
migraphx_onnx_options_t
*
onnx_options
);
migraphx_status
migraphx_onnx_options_set_input_parameter_shape
(
...
...
@@ -236,6 +342,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status
migraphx_file_options_destroy
(
migraphx_file_options_t
file_options
);
migraphx_status
migraphx_file_options_assign_to
(
migraphx_file_options_t
output
,
const_migraphx_file_options_t
input
);
migraphx_status
migraphx_file_options_create
(
migraphx_file_options_t
*
file_options
);
migraphx_status
migraphx_file_options_set_file_format
(
migraphx_file_options_t
file_options
,
...
...
@@ -243,6 +352,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status
migraphx_compile_options_destroy
(
migraphx_compile_options_t
compile_options
);
migraphx_status
migraphx_compile_options_assign_to
(
migraphx_compile_options_t
output
,
const_migraphx_compile_options_t
input
);
migraphx_status
migraphx_compile_options_create
(
migraphx_compile_options_t
*
compile_options
);
migraphx_status
...
...
@@ -261,6 +373,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status
migraphx_tf_options_destroy
(
migraphx_tf_options_t
tf_options
);
migraphx_status
migraphx_tf_options_assign_to
(
migraphx_tf_options_t
output
,
const_migraphx_tf_options_t
input
);
migraphx_status
migraphx_tf_options_create
(
migraphx_tf_options_t
*
tf_options
);
migraphx_status
migraphx_tf_options_set_nhwc
(
migraphx_tf_options_t
tf_options
,
bool
is_nhwc
);
...
...
@@ -282,6 +397,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status
migraphx_quantize_op_names_destroy
(
migraphx_quantize_op_names_t
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_assign_to
(
migraphx_quantize_op_names_t
output
,
const_migraphx_quantize_op_names_t
input
);
migraphx_status
migraphx_quantize_op_names_create
(
migraphx_quantize_op_names_t
*
quantize_op_names
);
migraphx_status
migraphx_quantize_op_names_add
(
migraphx_quantize_op_names_t
quantize_op_names
,
...
...
@@ -295,6 +413,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status
migraphx_quantize_int8_options_destroy
(
migraphx_quantize_int8_options_t
quantize_int8_options
);
migraphx_status
migraphx_quantize_int8_options_assign_to
(
migraphx_quantize_int8_options_t
output
,
const_migraphx_quantize_int8_options_t
input
);
migraphx_status
migraphx_quantize_int8_options_create
(
migraphx_quantize_int8_options_t
*
quantize_int8_options
);
...
...
@@ -309,6 +431,28 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t
target
,
migraphx_quantize_int8_options_t
options
);
migraphx_status
migraphx_context_finish
(
const_migraphx_context_t
context
);
migraphx_status
migraphx_experimental_custom_op_destroy
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
migraphx_status
migraphx_experimental_custom_op_assign_to
(
migraphx_experimental_custom_op_t
output
,
const_migraphx_experimental_custom_op_t
input
);
migraphx_status
migraphx_experimental_custom_op_create
(
migraphx_experimental_custom_op_t
*
experimental_custom_op
,
void
*
obj
,
migraphx_experimental_custom_op_copy
c
,
migraphx_experimental_custom_op_delete
d
,
const
char
*
name
);
migraphx_status
migraphx_experimental_custom_op_set_compute_shape
(
migraphx_experimental_custom_op_t
obj
,
migraphx_experimental_custom_op_compute_shape
input
);
migraphx_status
migraphx_experimental_custom_op_register
(
migraphx_experimental_custom_op_t
experimental_custom_op
);
#ifdef __cplusplus
}
#endif
...
...
src/api/include/migraphx/migraphx.hpp
View file @
e2eb6036
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
...
...
@@ -13,6 +15,16 @@ namespace migraphx {
inline
namespace
api
{
// NOLINT
#endif
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
template
<
class
T
,
class
F
,
class
...
Ts
>
T
*
make
(
F
f
,
Ts
&&
...
xs
)
{
...
...
@@ -152,6 +164,35 @@ struct array_base
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template
<
class
T
>
struct
holder
{
// Friend injection
friend
auto
migraphx_adl_handle_lookup
(
holder
<
T
>
);
// Function left unimplemented since its only used in non-evaluated
// context
T
get
()
const
;
};
template
<
class
C
,
class
T
>
struct
handle_lookup
{
friend
auto
migraphx_adl_handle_lookup
(
holder
<
T
>
)
{
return
holder
<
C
>
{};
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template
<
class
T
>
using
as_handle
=
decltype
(
migraphx_adl_handle_lookup
(
holder
<
std
::
remove_cv_t
<
std
::
remove_pointer_t
<
T
>>>
{}).
get
());
struct
own
{
};
...
...
@@ -159,8 +200,8 @@ struct borrow
{
};
template
<
class
T
,
class
D
,
D
Deleter
>
struct
handle_base
template
<
class
Derived
,
class
T
,
class
D
,
D
Deleter
,
class
A
,
A
Assigner
>
struct
handle_base
:
handle_lookup
<
Derived
,
std
::
remove_cv_t
<
T
>>
{
handle_base
()
:
m_handle
(
nullptr
)
{}
template
<
class
F
,
class
...
Ts
>
...
...
@@ -190,17 +231,158 @@ struct handle_base
m_handle
=
std
::
shared_ptr
<
U
>
{
ptr
,
[](
U
*
)
{}};
}
template
<
class
U
>
void
assign_to_handle
(
U
*
x
)
{
Assigner
(
x
,
this
->
get_handle_ptr
());
}
protected:
std
::
shared_ptr
<
T
>
m_handle
;
};
template
<
class
Base
>
struct
interface_base
:
Base
{
interface_base
()
:
Base
()
{}
protected:
template
<
class
F
>
static
migraphx_status
try_
(
F
f
)
// NOLINT
{
try
{
f
();
return
migraphx_status_success
;
}
catch
(...)
{
return
migraphx_status_unknown_error
;
}
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
make_interface
(
F
f
,
T
&
obj
,
Ts
&&
...
xs
)
{
auto
copy
=
[](
void
**
out
,
void
*
input
)
{
return
try_
([
&
]
{
T
**
y
=
reinterpret_cast
<
T
**>
(
out
);
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
assert
(
x
!=
nullptr
and
y
!=
nullptr
and
*
y
==
nullptr
);
*
y
=
new
T
(
*
x
);
// NOLINT
});
};
auto
del
=
[](
void
*
input
)
{
return
try_
([
&
]
{
T
*
x
=
reinterpret_cast
<
T
*>
(
input
);
delete
x
;
// NOLINT
});
};
this
->
make_handle
(
f
,
&
obj
,
copy
,
del
,
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
,
class
Setter
,
class
F
>
void
set_fp
(
Setter
setter
,
F
pf
)
{
static
F
f
=
pf
;
(
void
)
f
;
// avoid warning on gcc
call
(
setter
,
this
->
get_handle_ptr
(),
[](
auto
...
xs
)
->
migraphx_status
{
return
try_
([
&
]
{
call_cast_arg
<
T
>
(
rank
<
1
>
{},
f
,
xs
...);
});
});
}
template
<
class
T
,
class
Setter
,
class
F
>
void
set_auto_fp
(
Setter
setter
,
F
f
)
{
return
set_fp
<
T
>
(
setter
,
[
=
](
T
&
obj
,
auto
out
,
auto
...
xs
)
{
auto_invoke
(
f
,
out
,
obj
,
auto_convert_param
(
rank
<
2
>
{},
xs
)...);
});
}
struct
no_out_arg
{
};
template
<
class
T
,
class
F
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
0
>
,
F
f
,
X
*
obj
,
Xs
...
xs
)
{
f
(
reinterpret_cast
<
T
*>
(
obj
),
no_out_arg
{},
xs
...);
}
template
<
class
T
,
class
F
,
class
R
,
class
X
,
class
...
Xs
,
class
=
std
::
enable_if_t
<
std
::
is_void
<
X
>{}
>>
static
void
call_cast_arg
(
rank
<
1
>
,
F
f
,
R
result
,
X
*
obj
,
Xs
...
xs
)
{
f
(
*
reinterpret_cast
<
T
*>
(
obj
),
result
,
xs
...);
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
T
*
out
,
Ts
&&
...
xs
)
{
auto_assign
(
rank
<
2
>
{},
out
,
f
(
std
::
forward
<
Ts
>
(
xs
)...));
}
template
<
class
F
,
class
T
,
class
...
Ts
>
void
auto_invoke
(
F
f
,
no_out_arg
,
Ts
&&
...
xs
)
{
f
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
,
class
=
std
::
enable_if_t
<
std
::
is_fundamental
<
T
>{}
or
std
::
is_enum
<
T
>
{}
>>
T
auto_convert_param
(
rank
<
0
>
,
T
x
)
{
return
x
;
}
template
<
class
T
>
auto
auto_convert_param
(
rank
<
1
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
})
{
return
as_handle
<
T
>
{
x
};
}
template
<
class
T
>
auto
auto_convert_param
(
rank
<
2
>
,
T
x
)
->
decltype
(
as_handle
<
T
>
{
x
,
borrow
{}})
{
return
as_handle
<
T
>
{
x
,
borrow
{}};
}
template
<
class
T
,
class
U
>
void
auto_assign
(
rank
<
0
>
,
T
*
out
,
U
x
)
{
return
*
out
=
x
;
}
template
<
class
T
,
class
U
>
auto
auto_assign
(
rank
<
1
>
,
T
*
out
,
U
x
)
->
decltype
(
x
.
assign_to_handle
(
out
))
{
x
.
assign_to_handle
(
out
);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template
<
class
Base
,
class
T
>
using
require_interface
=
std
::
enable_if_t
<
std
::
is_base_of
<
Base
,
T
>
{}
and
not
std
::
is_same
<
T
,
Base
>
{}
and
std
::
is_copy_constructible
<
T
>
{}
and
std
::
is_final
<
T
>
{}
>
;
#ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy>
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
...
...
@@ -485,12 +667,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
}
std
::
string
name
()
{
std
::
array
<
char
,
1024
>
out_name
;
call
(
&
migraphx_operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
return
{
out_name
.
data
()};
}
};
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
instruction
(
migraphx_instruction
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
instructions
(
migraphx_instructions
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
instructions
(
migraphx_instructions
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
{
std
::
array
<
const_migraphx_instruction_t
,
sizeof
...(
Ts
)
>
a
{
xs
.
get_handle_ptr
()...};
this
->
make_handle
(
&
migraphx_instructions_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
modules
(
migraphx_modules
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
modules
(
migraphx_modules
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
{
std
::
array
<
migraphx_module_t
,
sizeof
...(
Ts
)
>
a
=
{
xs
.
mm
...};
this
->
make_handle
(
&
migraphx_modules_create
,
a
.
data
(),
a
.
size
());
}
};
struct
module
{
migraphx_module_t
mm
;
module
(
const
migraphx_module_t
&
m
)
:
mm
(
m
)
{}
void
print
()
const
{
call
(
&
migraphx_module_print
,
mm
);
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_instruction
(
const
migraphx
::
operation
&
op
,
const
migraphx
::
instructions
&
args
,
const
migraphx
::
modules
&
module_args
)
{
migraphx_instruction_t
op_ins
;
call
(
&
migraphx_module_add_instruction_with_mod_args
,
&
op_ins
,
mm
,
op
.
get_handle_ptr
(),
args
.
get_handle_ptr
(),
module_args
.
get_handle_ptr
());
return
instruction
(
op_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
call
(
&
migraphx_module_add_parameter
,
&
param_ins
,
mm
,
name
.
c_str
(),
s
.
get_handle_ptr
());
return
instruction
(
param_ins
,
own
{});
}
instruction
add_return
(
const
migraphx
::
instructions
&
args
)
{
migraphx_instruction_t
ret_ins
;
call
(
&
migraphx_module_add_return
,
&
ret_ins
,
mm
,
args
.
get_handle_ptr
());
return
instruction
(
ret_ins
,
own
{});
}
};
struct
context
{
migraphx_context_t
ctx
;
void
finish
()
const
{
call
(
&
migraphx_context_finish
,
ctx
);
}
};
struct
compile_options
:
MIGRAPHX_HANDLE_BASE
(
compile_options
)
...
...
@@ -519,7 +805,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed
struct
program
:
MIGRAPHX_HANDLE_BASE
(
program
)
{
program
()
{}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
(
migraphx_program
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
...
...
@@ -589,27 +875,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
};
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
operation
(
migraphx_operation
*
p
,
own
)
{
this
->
set_handle
(
p
,
own
{});
}
operation
(
migraphx_operation
*
p
,
borrow
)
{
this
->
set_handle
(
p
,
borrow
{});
}
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
context
experimental_get_context
()
{
this
->
make_handle
(
&
migraphx_operation_create
,
name
,
attributes
,
xs
...);
migraphx_context_t
ctx
;
call
(
&
migraphx_program_experimental_get_context
,
&
ctx
,
this
->
get_handle_ptr
());
return
context
{
ctx
};
}
std
::
string
name
(
)
module
create_module
(
const
std
::
string
&
name
)
{
std
::
array
<
char
,
1024
>
out_name
;
call
(
&
migraphx_
operation_name
,
out_name
.
data
(),
1024
,
this
->
get_handle_ptr
());
return
{
out_name
.
data
()
};
migraphx_module_t
p_modu
;
call
(
&
migraphx_
program_create_module
,
&
p_modu
,
this
->
get_handle_ptr
()
,
name
.
data
()
);
return
module
{
p_modu
};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
};
// options for migraphx file format options
...
...
@@ -850,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options
.
get_handle_ptr
());
}
struct
experimental_custom_op_base
{
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
shapes
inputs
)
const
=
0
;
virtual
~
experimental_custom_op_base
()
=
default
;
};
struct
experimental_custom_op
:
interface_base
<
MIGRAPHX_HANDLE_BASE
(
experimental_custom_op
)
>
{
template
<
class
T
>
experimental_custom_op
(
T
&
obj
)
{
this
->
make_interface
(
&
migraphx_experimental_custom_op_create
,
obj
,
obj
.
name
().
c_str
());
MIGRAPHX_INTERFACE_LIFT
(
T
,
experimental_custom_op
,
compute_shape
);
}
void
register_op
()
{
call
(
&
migraphx_experimental_custom_op_register
,
this
->
get_handle_ptr
());
}
};
template
<
class
T
,
class
=
require_interface
<
experimental_custom_op_base
,
T
>
>
void
register_experimental_custom_op
(
T
&
obj
)
{
experimental_custom_op
op
{
obj
};
op
.
register_op
();
}
#ifndef DOXYGEN
}
// namespace api
#endif
...
...
src/api/migraphx.py
View file @
e2eb6036
...
...
@@ -178,14 +178,55 @@ def shapes(h):
returns
=
'const migraphx::shape&'
)
@
api
.
handle
(
'migraphx_instruction'
,
'migraphx::instruction_ref'
)
def
instruction
(
h
):
pass
@
api
.
handle
(
'migraphx_instructions'
,
'std::vector<migraphx::instruction_ref>'
)
def
instructions
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'const_migraphx_instruction_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_obj_vector<const_migraphx_instruction_t>'
)
@
api
.
handle
(
'migraphx_modules'
,
'std::vector<migraphx::module*>'
)
def
modules
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
ptr
=
'migraphx_module_t*'
,
size
=
'size_t'
),
fname
=
'migraphx::to_objptr_vector<migraphx::module*>'
)
@
auto_handle
(
ref
=
True
)
def
module
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'std::string'
))
h
.
method
(
'print'
,
invoke
=
'migraphx::print_module($@)'
,
const
=
True
)
h
.
method
(
'add_instruction'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_instruction_with_mod_args'
,
api
.
params
(
op
=
'migraphx::operation'
,
args
=
'std::vector<migraphx::instruction_ref>'
,
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_return'
,
api
.
params
(
args
=
'std::vector<migraphx::instruction_ref>'
),
returns
=
'migraphx::instruction_ref'
)
@
auto_handle
()
def
program
(
h
):
h
.
constructor
(
'create'
)
h
.
method
(
'get_main_module'
,
returns
=
'migraphx::module*'
)
h
.
method
(
'create_module'
,
api
.
params
(
name
=
'const char*'
),
returns
=
'migraphx::module*'
)
h
.
method
(
'compile'
,
api
.
params
(
target
=
'migraphx::target'
,
...
...
@@ -207,6 +248,10 @@ def program(h):
invoke
=
'migraphx::equal($@)'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'experimental_get_context'
,
invoke
=
'migraphx::get_context($@)'
,
const
=
True
,
returns
=
'migraphx::context'
)
@
auto_handle
()
...
...
@@ -353,3 +398,18 @@ api.add_function('migraphx_quantize_int8',
target
=
'migraphx::target'
,
options
=
'migraphx::quantize_int8_options'
),
fname
=
'migraphx::quantize_int8_wrap'
)
@
auto_handle
(
ref
=
True
)
def
context
(
h
):
h
.
method
(
'finish'
,
const
=
True
)
@
api
.
interface
(
'migraphx_experimental_custom_op'
,
'migraphx::experimental_custom_op'
)
def
experimental_custom_op
(
h
):
h
.
constructor
(
'create'
,
api
.
params
(
name
=
'const char*'
))
h
.
virtual
(
'compute_shape'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'migraphx::shape'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
src/argument.cpp
View file @
e2eb6036
...
...
@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const
shape
&
argument
::
get_shape
()
const
{
return
this
->
m_shape
;
}
argument
argument
::
reshape
(
const
shape
&
s
)
const
{
return
{
s
,
this
->
m_data
};
}
argument
argument
::
reshape
(
const
shape
&
s
)
const
{
assert
(
s
.
element_space
()
<=
this
->
get_shape
().
element_space
());
return
{
s
,
this
->
m_data
};
}
argument
::
data_t
argument
::
data_t
::
share
()
const
{
...
...
src/auto_contiguous.cpp
View file @
e2eb6036
...
...
@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void
auto_contiguous
::
apply
(
module
&
p
)
const
{
std
::
string
key
=
"require_std_shape"
;
for
(
auto
ins
:
reverse_iterator_for
(
p
))
{
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
if
((
attr
.
get
(
key
,
false
)))
{
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
if
(
in
->
name
()
==
"contiguous"
)
{
return
in
;
}
return
p
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
if
(
new_args
!=
args
)
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
// for last instruction that is NOT a return
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
...
...
src/compile_src.cpp
View file @
e2eb6036
...
...
@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
params
+=
" -o "
+
out
;
td
.
execute
(
compiler
,
params
);
if
(
not
launcher
.
empty
())
{
td
.
execute
(
launcher
,
compiler
+
" "
+
params
);
}
else
{
td
.
execute
(
compiler
,
params
);
}
auto
out_path
=
td
.
path
/
out
;
if
(
not
fs
::
exists
(
out_path
))
...
...
src/cpp_generator.cpp
View file @
e2eb6036
...
...
@@ -88,6 +88,7 @@ struct cpp_generator_impl
std
::
stringstream
fs
{};
std
::
size_t
function_count
=
0
;
std
::
function
<
std
::
string
(
std
::
string
)
>
fmap
=
nullptr
;
std
::
function
<
std
::
string
(
shape
)
>
fresult
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
point_op_map
=
{};
};
cpp_generator
::
cpp_generator
()
:
impl
(
std
::
make_unique
<
cpp_generator_impl
>
())
{}
...
...
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void
cpp_generator
::
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
)
{
impl
->
fmap
=
f
;
}
void
cpp_generator
::
fresult
(
const
std
::
function
<
std
::
string
(
shape
)
>&
f
)
{
impl
->
fresult
=
f
;
}
void
cpp_generator
::
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
)
{
impl
->
point_op_map
[
op_name
]
=
code
;
...
...
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
[
&
](
auto
i
)
{
return
names
.
at
(
i
);
});
return
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
auto
s
=
this
->
generate_point_op
(
ins
->
get_operator
(),
args
);
if
(
impl
->
fresult
)
return
impl
->
fresult
(
ins
->
get_shape
())
+
'('
+
s
+
')'
;
else
return
s
;
});
return
f
;
}
...
...
src/driver/alexnet.cpp
View file @
e2eb6036
...
...
@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu19
;
auto
mx19
=
mm
->
add_instruction
(
relu19
,
mx18
);
migraphx
::
op
::
pooling
pooling20
;
pooling20
.
mode
=
"max"
;
pooling20
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling20
.
padding
=
{
0
,
0
};
pooling20
.
stride
=
{
2
,
2
};
pooling20
.
lengths
=
{
3
,
3
};
...
...
@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu24
;
auto
mx24
=
mm
->
add_instruction
(
relu24
,
mx23
);
migraphx
::
op
::
pooling
pooling25
;
pooling25
.
mode
=
"max"
;
pooling25
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling25
.
padding
=
{
0
,
0
};
pooling25
.
stride
=
{
2
,
2
};
pooling25
.
lengths
=
{
3
,
3
};
...
...
@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu37
;
auto
mx37
=
mm
->
add_instruction
(
relu37
,
mx36
);
migraphx
::
op
::
pooling
pooling38
;
pooling38
.
mode
=
"max"
;
pooling38
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling38
.
padding
=
{
0
,
0
};
pooling38
.
stride
=
{
2
,
2
};
pooling38
.
lengths
=
{
3
,
3
};
...
...
src/driver/inceptionv3.cpp
View file @
e2eb6036
...
...
@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu492
;
auto
mx492
=
mm
->
add_instruction
(
relu492
,
mx491
);
migraphx
::
op
::
pooling
pooling493
;
pooling493
.
mode
=
"max"
;
pooling493
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling493
.
padding
=
{
0
,
0
};
pooling493
.
stride
=
{
2
,
2
};
pooling493
.
lengths
=
{
3
,
3
};
...
...
@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu499
;
auto
mx499
=
mm
->
add_instruction
(
relu499
,
mx498
);
migraphx
::
op
::
pooling
pooling500
;
pooling500
.
mode
=
"max"
;
pooling500
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling500
.
padding
=
{
0
,
0
};
pooling500
.
stride
=
{
2
,
2
};
pooling500
.
lengths
=
{
3
,
3
};
...
...
@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu518
;
auto
mx518
=
mm
->
add_instruction
(
relu518
,
mx517
);
migraphx
::
op
::
pooling
pooling519
;
pooling519
.
mode
=
"
average
"
;
pooling519
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling519
.
padding
=
{
1
,
1
};
pooling519
.
stride
=
{
1
,
1
};
pooling519
.
lengths
=
{
3
,
3
};
...
...
@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu541
;
auto
mx541
=
mm
->
add_instruction
(
relu541
,
mx540
);
migraphx
::
op
::
pooling
pooling542
;
pooling542
.
mode
=
"
average
"
;
pooling542
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling542
.
padding
=
{
1
,
1
};
pooling542
.
stride
=
{
1
,
1
};
pooling542
.
lengths
=
{
3
,
3
};
...
...
@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu564
;
auto
mx564
=
mm
->
add_instruction
(
relu564
,
mx563
);
migraphx
::
op
::
pooling
pooling565
;
pooling565
.
mode
=
"
average
"
;
pooling565
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling565
.
padding
=
{
1
,
1
};
pooling565
.
stride
=
{
1
,
1
};
pooling565
.
lengths
=
{
3
,
3
};
...
...
@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu581
;
auto
mx581
=
mm
->
add_instruction
(
relu581
,
mx580
);
migraphx
::
op
::
pooling
pooling582
;
pooling582
.
mode
=
"max"
;
pooling582
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling582
.
padding
=
{
0
,
0
};
pooling582
.
stride
=
{
2
,
2
};
pooling582
.
lengths
=
{
3
,
3
};
...
...
@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu610
;
auto
mx610
=
mm
->
add_instruction
(
relu610
,
mx609
);
migraphx
::
op
::
pooling
pooling611
;
pooling611
.
mode
=
"
average
"
;
pooling611
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling611
.
padding
=
{
1
,
1
};
pooling611
.
stride
=
{
1
,
1
};
pooling611
.
lengths
=
{
3
,
3
};
...
...
@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu642
;
auto
mx642
=
mm
->
add_instruction
(
relu642
,
mx641
);
migraphx
::
op
::
pooling
pooling643
;
pooling643
.
mode
=
"
average
"
;
pooling643
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling643
.
padding
=
{
1
,
1
};
pooling643
.
stride
=
{
1
,
1
};
pooling643
.
lengths
=
{
3
,
3
};
...
...
@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu674
;
auto
mx674
=
mm
->
add_instruction
(
relu674
,
mx673
);
migraphx
::
op
::
pooling
pooling675
;
pooling675
.
mode
=
"
average
"
;
pooling675
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling675
.
padding
=
{
1
,
1
};
pooling675
.
stride
=
{
1
,
1
};
pooling675
.
lengths
=
{
3
,
3
};
...
...
@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu706
;
auto
mx706
=
mm
->
add_instruction
(
relu706
,
mx705
);
migraphx
::
op
::
pooling
pooling707
;
pooling707
.
mode
=
"
average
"
;
pooling707
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling707
.
padding
=
{
1
,
1
};
pooling707
.
stride
=
{
1
,
1
};
pooling707
.
lengths
=
{
3
,
3
};
...
...
@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx
::
op
::
relu
relu729
;
auto
mx729
=
mm
->
add_instruction
(
relu729
,
mx728
);
migraphx
::
op
::
pooling
pooling730
;
pooling730
.
mode
=
"max"
;
pooling730
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling730
.
padding
=
{
0
,
0
};
pooling730
.
stride
=
{
2
,
2
};
pooling730
.
lengths
=
{
3
,
3
};
...
...
@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757
.
axis
=
1
;
auto
mx757
=
mm
->
add_instruction
(
concat757
,
mx753
,
mx756
);
migraphx
::
op
::
pooling
pooling758
;
pooling758
.
mode
=
"
average
"
;
pooling758
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling758
.
padding
=
{
1
,
1
};
pooling758
.
stride
=
{
1
,
1
};
pooling758
.
lengths
=
{
3
,
3
};
...
...
@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788
.
axis
=
1
;
auto
mx788
=
mm
->
add_instruction
(
concat788
,
mx784
,
mx787
);
migraphx
::
op
::
pooling
pooling789
;
pooling789
.
mode
=
"
average
"
;
pooling789
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling789
.
padding
=
{
1
,
1
};
pooling789
.
stride
=
{
1
,
1
};
pooling789
.
lengths
=
{
3
,
3
};
...
...
@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793
.
axis
=
1
;
auto
mx793
=
mm
->
add_instruction
(
concat793
,
mx765
,
mx775
,
mx788
,
mx792
);
migraphx
::
op
::
pooling
pooling794
;
pooling794
.
mode
=
"
average
"
;
pooling794
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling794
.
padding
=
{
0
,
0
};
pooling794
.
stride
=
{
8
,
8
};
pooling794
.
lengths
=
{
8
,
8
};
...
...
src/driver/main.cpp
View file @
e2eb6036
...
...
@@ -505,8 +505,10 @@ struct roctx : command<roctx>
struct
op
:
command
<
op
>
{
bool
show_ops
=
false
;
std
::
string
op_name
{};
void
parse
(
argument_parser
&
ap
)
{
ap
(
op_name
,
{},
ap
.
metavar
(
"<MIGraphX operator name>"
));
ap
(
show_ops
,
{
"--list"
,
"-l"
},
ap
.
help
(
"List all the operators of MIGraphX"
),
...
...
@@ -519,6 +521,12 @@ struct op : command<op>
for
(
const
auto
&
name
:
get_operators
())
std
::
cout
<<
name
<<
std
::
endl
;
}
else
{
auto
op
=
load_op
(
op_name
);
std
::
cout
<<
op_name
<<
": "
<<
std
::
endl
;
std
::
cout
<<
to_pretty_json_string
(
op
.
to_value
())
<<
std
::
endl
;
}
}
};
...
...
src/driver/perf.cpp
View file @
e2eb6036
...
...
@@ -87,6 +87,6 @@ target get_target(bool gpu)
void
compile_program
(
program
&
p
,
bool
gpu
)
{
p
.
compile
(
get_target
(
gpu
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace
MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace migraphx
src/driver/resnet50.cpp
View file @
e2eb6036
...
...
@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu269
;
auto
mx269
=
mm
->
add_instruction
(
relu269
,
mx268
);
migraphx
::
op
::
pooling
pooling270
;
pooling270
.
mode
=
"max"
;
pooling270
.
mode
=
migraphx
::
op
::
pooling_mode
::
max
;
pooling270
.
padding
=
{
1
,
1
};
pooling270
.
stride
=
{
2
,
2
};
pooling270
.
lengths
=
{
3
,
3
};
...
...
@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
op
::
relu
relu438
;
auto
mx438
=
mm
->
add_instruction
(
relu438
,
mx437
);
migraphx
::
op
::
pooling
pooling439
;
pooling439
.
mode
=
"
average
"
;
pooling439
.
mode
=
migraphx
::
op
::
pooling_mode
::
average
;
pooling439
.
padding
=
{
0
,
0
};
pooling439
.
stride
=
{
1
,
1
};
pooling439
.
lengths
=
{
7
,
7
};
...
...
src/eliminate_common_subexpression.cpp
View file @
e2eb6036
...
...
@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
continue
;
p
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
auto
outputs
=
eq
->
outputs
();
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
p
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
...
...
src/eliminate_contiguous.cpp
View file @
e2eb6036
...
...
@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue
;
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
{
if
(
arg
->
name
()
==
op_name
)
{
auto
new_args
=
args
;
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
ins
->
module_inputs
()
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
...
...
src/eliminate_data_type.cpp
View file @
e2eb6036
...
...
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
};
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
"get_tuple_elem"
,
"if"
,
"loop"
,
"roialign"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()[
0
]
==
'@'
)
...
...
Prev
1
2
3
4
5
6
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment