gaoqiong / MIGraphX / Commits

Commit baac1dab, authored May 24, 2023 by Alan Turner
Merge remote-tracking branch 'origin/develop' into ck-host-lib
Parents: 830dff7a 77042e30
Changes: 299

Showing 20 changed files with 1227 additions and 163 deletions (+1227 -163)
src/CMakeLists.txt                       +12  -1
src/api/CMakeLists.txt                   +5   -2
src/api/api.cpp                          +302 -2
src/api/include/migraphx/migraphx.h      +83  -0
src/api/include/migraphx/migraphx.hpp    +135 -1
src/api/migraphx.py                      +59  -43
src/common.cpp                           +18  -7
src/compile_src.cpp                      +0   -3
src/cpp_generator.cpp                    +37  -8
src/driver/CMakeLists.txt                +1   -0
src/driver/argument_parser.hpp           +15  -7
src/driver/main.cpp                      +150 -55
src/driver/perf.cpp                      +8   -21
src/driver/perf.hpp                      +4   -3
src/driver/verify.cpp                    +2   -2
src/dynamic_loader.cpp                   +34  -4
src/eliminate_allocation.cpp             +0   -3
src/eliminate_data_type.cpp              +1   -0
src/fuse_pointwise.cpp                   +11  -1
src/fuse_reduce.cpp                      +350 -0
src/CMakeLists.txt

@@ -50,6 +50,7 @@ add_library(migraphx
    env.cpp
    file_buffer.cpp
    fuse_pointwise.cpp
    fuse_reduce.cpp
    generate.cpp
    inline_module.cpp
    insert_pad.cpp
@@ -73,6 +74,7 @@ add_library(migraphx
    process.cpp
    program.cpp
    propagate_constant.cpp
    promote_literals.cpp
    quantization.cpp
    quantize_fp16.cpp
    quantize_int8.cpp
@@ -91,6 +93,7 @@ add_library(migraphx
    shape.cpp
    simplify_algebra.cpp
    simplify_reshapes.cpp
    split_single_dyn_dim.cpp
    tmp_dir.cpp
    value.cpp
    verify_args.cpp
@@ -199,6 +202,7 @@ register_migraphx_ops(
    scatternd_add
    scatternd_mul
    scatternd_none
    select_module
    sigmoid
    sign
    sinh
@@ -250,7 +254,10 @@ find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(msgpack REQUIRED)
find_package(msgpackc-cxx QUIET)
if(NOT msgpackc-cxx_FOUND)
    find_package(msgpack REQUIRED)
endif()
target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests
target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>)
@@ -288,6 +295,10 @@ if(HAVE_HALF_EXPR)
    target_compile_definitions(migraphx PUBLIC -DHAS_HALF_V1)
endif()
if(BUILD_DEV)
    target_compile_definitions(migraphx PUBLIC -DBUILD_DEV)
endif()
rocm_export_targets(
    TARGETS migraphx::migraphx_c
    NAMESPACE migraphx::
src/api/CMakeLists.txt

@@ -26,10 +26,13 @@ add_library(migraphx_c
    api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 3.0)
# migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken.
rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
rocm_install_targets(
    TARGETS migraphx_c
src/api/api.cpp

@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
@@ -32,7 +33,6 @@
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
@@ -134,6 +134,11 @@ void set_offload_copy(compile_options& options, bool value) { options.offload_co
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_exhaustive_tune_flag(compile_options& options, bool value) { options.exhaustive_tune = value; }
void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value)
@@ -141,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
    options.default_dim_value = value;
}
void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
    options.default_dyn_dim_value = dd;
}
void set_default_loop_iterations(onnx_options& options, int64_t value)
{
    options.max_loop_iterations = value;
@@ -157,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
    options.map_input_dims[std::string(name)] = std::move(dims);
}
void set_dyn_input_parameter_shape(onnx_options& options,
                                   const char* name,
                                   std::vector<shape::dynamic_dimension> dyn_dims)
{
    options.map_dyn_input_dims[std::string(name)] = std::move(dyn_dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{
    options.map_input_dims[std::string(name)] = std::move(dims);
@@ -183,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
    return result;
}
template <class T>
std::set<T> make_set(const T* x, std::size_t n)
{
    return {x, x + n};
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{
    if(names.empty())
@@ -342,7 +365,10 @@ const Target* object_cast(const U* x)
template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs)
{
    return new Target(std::forward<Ts>(xs)...); // NOLINT
    if constexpr(std::is_aggregate<Target>{})
        return new Target{std::forward<Ts>(xs)...}; // NOLINT
    else
        return new Target(std::forward<Ts>(xs)...); // NOLINT
}
template <class T>
@@ -405,6 +431,39 @@ struct manage_generic_ptr
    D deleter = nullptr;
};

extern "C" struct migraphx_optimals;
struct migraphx_optimals
{
    template <class... Ts>
    migraphx_optimals(Ts&&... xs) : object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
    {
    }
    std::set<size_t> object;
};

extern "C" struct migraphx_dynamic_dimension;
struct migraphx_dynamic_dimension
{
    template <class... Ts>
    migraphx_dynamic_dimension(Ts&&... xs) : object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
    {
    }
    migraphx::shape::dynamic_dimension object;
};

extern "C" struct migraphx_dynamic_dimensions;
struct migraphx_dynamic_dimensions
{
    template <class... Ts>
    migraphx_dynamic_dimensions(Ts&&... xs) : object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
    {
    }
    std::vector<migraphx::shape::dynamic_dimension> object;
};

extern "C" struct migraphx_shape;
struct migraphx_shape
{
@@ -732,6 +791,152 @@ struct migraphx_experimental_custom_op
    }
};

extern "C" migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals)
{
    auto api_error_result = migraphx::try_([&] { destroy((optimals)); });
    return api_error_result;
}

extern "C" migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output, const_migraphx_optimals_t input)
{
    auto api_error_result = migraphx::try_([&] { *output = *input; });
    return api_error_result;
}

extern "C" migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size)
{
    auto api_error_result = migraphx::try_([&] {
        *optimals = object_cast<migraphx_optimals_t>(allocate<std::set<size_t>>(migraphx::make_set<size_t>((ptr), (size))));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension)
{
    auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimension)); });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input)
{
    auto api_error_result = migraphx::try_([&] { *output = *input; });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_create_min_max(migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max)
{
    auto api_error_result = migraphx::try_([&] {
        *dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(allocate<migraphx::shape::dynamic_dimension>((min), (max)));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max, migraphx_optimals_t optimals)
{
    auto api_error_result = migraphx::try_([&] {
        if(optimals == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter optimals: Null pointer");
        *dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(allocate<migraphx::shape::dynamic_dimension>((min), (max), (optimals->object)));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_is_fixed(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension)
{
    auto api_error_result = migraphx::try_([&] {
        if(dynamic_dimension == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dynamic_dimension: Null pointer");
        *out = (dynamic_dimension->object).is_fixed();
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimension_equal(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension, const_migraphx_dynamic_dimension_t x)
{
    auto api_error_result = migraphx::try_([&] {
        if(dynamic_dimension == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dynamic_dimension: Null pointer");
        if(x == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
        *out = migraphx::equal((dynamic_dimension->object), (x->object));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions)
{
    auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimensions)); });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input)
{
    auto api_error_result = migraphx::try_([&] { *output = *input; });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions, const_migraphx_dynamic_dimension_t* ptr, size_t size)
{
    auto api_error_result = migraphx::try_([&] {
        *dynamic_dimensions = object_cast<migraphx_dynamic_dimensions_t>(allocate<std::vector<migraphx::shape::dynamic_dimension>>(
            migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>((ptr), (size))));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions)
{
    auto api_error_result = migraphx::try_([&] {
        if(dynamic_dimensions == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dynamic_dimensions: Null pointer");
        *out = (dynamic_dimensions->object).size();
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions, size_t idx)
{
    auto api_error_result = migraphx::try_([&] {
        if(dynamic_dimensions == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dynamic_dimensions: Null pointer");
        *out = object_cast<const_migraphx_dynamic_dimension_t>(&((dynamic_dimensions->object).at((idx))));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
    auto api_error_result = migraphx::try_([&] { destroy((shape)); });
@@ -790,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape, migraphx_shape_datatype_t type, migraphx_dynamic_dimensions_t dims)
{
    auto api_error_result = migraphx::try_([&] {
        if(dims == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
        *shape = object_cast<migraphx_shape_t>(allocate<migraphx::shape>((migraphx::to_shape_type(type)), (dims->object)));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{
@@ -820,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out, const_migraphx_shape_t shape)
{
    auto api_error_result = migraphx::try_([&] {
        if(shape == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *out = allocate<migraphx_dynamic_dimensions_t>((shape->object).dyn_dims());
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape)
{
@@ -853,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape)
{
    auto api_error_result = migraphx::try_([&] {
        if(shape == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *out = (shape->object).ndim();
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{
@@ -876,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape)
{
    auto api_error_result = migraphx::try_([&] {
        if(shape == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *out = (shape->object).dynamic();
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{
    auto api_error_result = migraphx::try_([&] {
@@ -911,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
    return api_error_result;
}

extern "C" migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument, const_migraphx_shape_t shape)
{
    auto api_error_result = migraphx::try_([&] {
        if(shape == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *argument = object_cast<migraphx_argument_t>(allocate<migraphx::argument>((shape->object)));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, const_migraphx_argument_t argument)
{
@@ -1586,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
    return api_error_result;
}

extern "C" migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims)
{
    auto api_error_result = migraphx::try_([&] {
        if(onnx_options == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
        if(dims == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
        migraphx::set_dyn_input_parameter_shape((onnx_options->object), (name), (dims->object));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
@@ -1597,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
    return api_error_result;
}

extern "C" migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd)
{
    auto api_error_result = migraphx::try_([&] {
        if(onnx_options == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
        if(dd == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dd: Null pointer");
        migraphx::set_default_dyn_dim_value((onnx_options->object), (dd->object));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, int64_t value)
@@ -1690,6 +1977,19 @@ migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_option
    return api_error_result;
}

extern "C" migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options, bool value)
{
    auto api_error_result = migraphx::try_([&] {
        if(compile_options == nullptr) MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter compile_options: Null pointer");
        migraphx::set_exhaustive_tune_flag((compile_options->object), (value));
    });
    return api_error_result;
}

extern "C" migraphx_status migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
src/api/include/migraphx/migraphx.h

@@ -66,6 +66,15 @@ typedef enum
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES

typedef struct migraphx_optimals* migraphx_optimals_t;
typedef const struct migraphx_optimals* const_migraphx_optimals_t;

typedef struct migraphx_dynamic_dimension* migraphx_dynamic_dimension_t;
typedef const struct migraphx_dynamic_dimension* const_migraphx_dynamic_dimension_t;

typedef struct migraphx_dynamic_dimensions* migraphx_dynamic_dimensions_t;
typedef const struct migraphx_dynamic_dimensions* const_migraphx_dynamic_dimensions_t;

typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t;
@@ -157,6 +166,55 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);

migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output, const_migraphx_optimals_t input);
migraphx_status migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size);

migraphx_status migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output, const_migraphx_dynamic_dimension_t input);
migraphx_status migraphx_dynamic_dimension_create_min_max(migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
migraphx_status migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max, migraphx_optimals_t optimals);
migraphx_status migraphx_dynamic_dimension_is_fixed(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status migraphx_dynamic_dimension_equal(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension, const_migraphx_dynamic_dimension_t x);

migraphx_status migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output, const_migraphx_dynamic_dimensions_t input);
migraphx_status migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions, const_migraphx_dynamic_dimension_t* ptr, size_t size);
migraphx_status migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions, size_t idx);

migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
@@ -176,23 +234,34 @@ migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, migraphx_shape_datatype_t type);
migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape, migraphx_shape_datatype_t type, migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
@@ -203,6 +272,9 @@ migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
migraphx_status migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, const_migraphx_argument_t argument);
@@ -397,9 +469,16 @@ migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_optio
migraphx_status migraphx_onnx_options_set_input_parameter_shape(migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value);
migraphx_status migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options, const_migraphx_dynamic_dimension_t dd);
migraphx_status migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, int64_t value);
@@ -427,6 +506,10 @@ migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_opt
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
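The declarations above add a small family of handle types for describing dynamic tensor shapes through the stable C API. The sketch below is not part of this commit; it shows one way the new entry points might be combined to build a shape with a dynamic batch dimension, assuming the existing migraphx_shape_float_type enumerator and the usual rule that every handle returned by a _create call is released with its matching _destroy call.

// Hypothetical usage sketch of the new dynamic-shape C API (illustrative only).
#include <migraphx/migraphx.h>

int make_dynamic_shape(migraphx_shape_t* out)
{
    const size_t opts[] = {1, 8};
    migraphx_optimals_t optimals;
    migraphx_optimals_create(&optimals, opts, 2);

    migraphx_dynamic_dimension_t batch;
    migraphx_dynamic_dimension_t chan;
    migraphx_dynamic_dimension_create_min_max_optimals(&batch, 1, 64, optimals);
    migraphx_dynamic_dimension_create_min_max(&chan, 3, 3); // fixed dims use min == max

    const_migraphx_dynamic_dimension_t dims[] = {batch, chan};
    migraphx_dynamic_dimensions_t dyn_dims;
    migraphx_dynamic_dimensions_create(&dyn_dims, dims, 2);

    migraphx_status st = migraphx_shape_create_dynamic(out, migraphx_shape_float_type, dyn_dims);

    // The created shape copies the dimension data, so the intermediate handles can be freed.
    migraphx_dynamic_dimensions_destroy(dyn_dims);
    migraphx_dynamic_dimension_destroy(chan);
    migraphx_dynamic_dimension_destroy(batch);
    migraphx_optimals_destroy(optimals);
    return st == migraphx_status_success ? 0 : 1;
}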
src/api/include/migraphx/migraphx.hpp

@@ -571,10 +571,90 @@ using require_interface =
// NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)

/**
 * Container to hold optimal dynamic dimension values.
 */
struct optimals : MIGRAPHX_HANDLE_BASE(optimals)
{
    MIGRAPHX_HANDLE_CONSTRUCTOR(optimals)
    optimals(std::initializer_list<size_t> init_list)
    {
        this->make_handle(&migraphx_optimals_create, init_list.begin(), init_list.size());
    }
};

/**
 * @brief Dynamic dimension object.
 * @details minimum, maximum, and optimal dimensions
 */
struct dynamic_dimension : MIGRAPHX_CONST_HANDLE_BASE(dynamic_dimension)
{
    MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimension)
    dynamic_dimension(size_t min, size_t max)
    {
        this->make_handle(&migraphx_dynamic_dimension_create_min_max, min, max);
    }
    dynamic_dimension(size_t min, size_t max, const optimals& opts)
    {
        this->make_handle(&migraphx_dynamic_dimension_create_min_max_optimals, min, max, opts.get_handle_ptr());
    }
    bool is_fixed() const
    {
        bool result = false;
        call(&migraphx_dynamic_dimension_is_fixed, &result, this->get_handle_ptr());
        return result;
    }
    friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
    {
        bool pout;
        call(&migraphx_dynamic_dimension_equal, &pout, x.get_handle_ptr(), y.get_handle_ptr());
        return pout;
    }
    friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y) { return not(x == y); }
};

/**
 * Container to hold dynamic_dimension objects.
 */
struct dynamic_dimensions : MIGRAPHX_HANDLE_BASE(dynamic_dimensions)
{
    MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimensions)
    template <class... Ts>
    dynamic_dimensions(Ts... xs)
    {
        std::array<const_migraphx_dynamic_dimension_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
        this->make_handle(&migraphx_dynamic_dimensions_create, a.data(), a.size());
    }
    size_t size() const
    {
        size_t pout;
        call(&migraphx_dynamic_dimensions_size, &pout, this->get_handle_ptr());
        return pout;
    }
    dynamic_dimension operator[](size_t pidx) const
    {
        const_migraphx_dynamic_dimension_t pout;
        call(&migraphx_dynamic_dimensions_get, &pout, this->get_handle_ptr(), pidx);
        return {pout, this->share_handle()};
    }
};

/**
 * @brief Describe shape of tensor
 * @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
 *
 */
struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
        this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
    }

    // Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
    // shape(type_t, std::vector<std::size_t> l)
    shape(migraphx_shape_datatype_t t, std::initializer_list<std::size_t> d)
        : shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
    {
    }

    shape(migraphx_shape_datatype_t type, std::vector<size_t> plengths, std::vector<size_t> pstrides)
@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
                          pstrides.size());
    }

    shape(migraphx_shape_datatype_t type, const dynamic_dimensions& dyn_dims)
    {
        this->make_handle(&migraphx_shape_create_dynamic, type, dyn_dims.get_handle_ptr());
    }

    std::vector<size_t> lengths() const
    {
        const size_t* pout;
@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
        return {pout, pout + pout_size};
    }

    /// Get the dynamic dimensions of the shape
    dynamic_dimensions dyn_dims() const
    {
        migraphx_dynamic_dimensions_t pout;
        call(&migraphx_shape_dyn_dims, &pout, this->get_handle_ptr());
        return {pout, own{}};
    }

    migraphx_shape_datatype_t type() const
    {
        migraphx_shape_datatype_t pout;
@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
        return result;
    }

    /// Is the shape dynamic
    bool dynamic() const
    {
        bool result = false;
        call(&migraphx_shape_dynamic, &result, this->get_handle_ptr());
        return result;
    }

    // map element index to space index
    size_t index(size_t i) const
    {
@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
    MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
    argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }

    argument(shape pshape)
    {
        this->make_handle(&migraphx_argument_create_empty, pshape.get_handle_ptr());
    }

    argument(shape pshape, void* pbuffer)
    {
        this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer);
@@ -1015,6 +1128,12 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
    {
        call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value);
    }

    /// Set or un-set exhaustive search to find fastest kernel
    void set_exhaustive_tune_flag(bool value = true)
    {
        call(&migraphx_compile_options_set_exhaustive_tune_flag, this->get_handle_ptr(), value);
    }
};

/// A program represents the all computation graphs to be compiled and executed
@@ -1176,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
             dim.size());
    }

    void set_dyn_input_parameter_shape(const std::string& name, const dynamic_dimensions& dyn_dims)
    {
        call(&migraphx_onnx_options_set_dyn_input_parameter_shape,
             this->get_handle_ptr(),
             name.c_str(),
             dyn_dims.get_handle_ptr());
    }

    /// When there is a dimension parameter, then use this default value
    void set_default_dim_value(unsigned int value)
    {
        call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
    }

    void set_default_dyn_dim_value(const dynamic_dimension& dd)
    {
        call(&migraphx_onnx_options_set_default_dyn_dim_value, this->get_handle_ptr(), dd.get_handle_ptr());
    }

    /// Set default max iteration number for the loop operator
    void set_default_loop_iterations(int64_t value)
    {
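The C++ wrappers above mirror those C entry points with RAII handles. The short sketch below is not part of the diff; it shows how the new types might feed the ONNX parser, assuming the existing migraphx::parse_onnx helper, and the file name and input-parameter name are placeholders.

#include <migraphx/migraphx.hpp>

int main()
{
    // Dynamic batch of 1..64 with preferred (optimal) sizes 1 and 8; remaining dims fixed.
    migraphx::dynamic_dimensions dims{migraphx::dynamic_dimension{1, 64, migraphx::optimals{1, 8}},
                                      migraphx::dynamic_dimension{3, 3},
                                      migraphx::dynamic_dimension{224, 224},
                                      migraphx::dynamic_dimension{224, 224}};

    migraphx::onnx_options options;
    options.set_dyn_input_parameter_shape("data", dims); // "data" is a placeholder input name
    options.set_default_dyn_dim_value(migraphx::dynamic_dimension{1, 1});

    auto p = migraphx::parse_onnx("model.onnx", options); // placeholder path
    (void)p;
    return 0;
}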
src/api/migraphx.py

@@ -45,56 +45,48 @@ def shape_type_wrap(p):
        p.read = 'migraphx::to_shape_type(${name})'


@api.cwrap('migraphx::compile_options')
def compile_options_type_wrap(p):
    if p.returns:
        p.add_param('migraphx_compile_options *')
        p.bad_param('${name} == nullptr', 'Null pointer')
        p.write = ['*${name} = migraphx::to_compile_options(${result})']
    else:
        p.add_param('migraphx_compile_options *')
        p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'


@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
    if p.returns:
        p.add_param('migraphx_file_options *')
        p.bad_param('${name} == nullptr', 'Null pointer')
        p.write = ['*${name} = migraphx::to_file_options(${result})']
    else:
        p.add_param('migraphx_file_options *')
        p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'


def auto_handle(*args, **kwargs):
    def with_handle(f):
        return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__, *args, **kwargs)(f)

    return with_handle


@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
    if p.returns:
        p.add_param('migraphx_onnx_options *')
        p.bad_param('${name} == nullptr', 'Null pointer')
        p.write = ['*${name} = migraphx::to_onnx_options(${result})']
    else:
        p.add_param('migraphx_onnx_options *')
        p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'


@api.handle('migraphx_optimals', 'std::set<size_t>')
def optimals(h):
    h.constructor('create',
                  api.params(ptr='const size_t*', size='size_t'),
                  fname='migraphx::make_set<size_t>')


@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
    if p.returns:
        p.add_param('migraphx_tf_options *')
        p.bad_param('${name} == nullptr', 'Null pointer')
        p.write = ['*${name} = migraphx::to_tf_options(${result})']
    else:
        p.add_param('migraphx_tf_options *')
        p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'


@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
def dynamic_dimension(h):
    h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
    h.constructor('create_min_max_optimals',
                  api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
    h.method('is_fixed', returns='bool', const=True)
    h.method('equal',
             api.params(x='const migraphx::shape::dynamic_dimension&'),
             invoke='migraphx::equal($@)',
             returns='bool',
             const=True)


def auto_handle(*args, **kwargs):
    def with_handle(f):
        return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__, *args, **kwargs)(f)

    return with_handle


@api.handle('migraphx_dynamic_dimensions', 'std::vector<migraphx::shape::dynamic_dimension>')
def dynamic_dimensions(h):
    h.constructor('create',
                  api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'),
                  fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
    h.method('size', returns='size_t')
    h.method('get',
             api.params(idx='size_t'),
             fname='at',
             cpp_name='operator[]',
             returns='const migraphx::shape::dynamic_dimension&')


@auto_handle()
@@ -109,20 +101,29 @@ def shape(h):
                  lengths='std::vector<size_t>',
                  strides='std::vector<size_t>'))
    h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
    h.constructor('create_dynamic',
                  api.params(type='migraphx::shape::type_t',
                             dims='std::vector<migraphx::shape::dynamic_dimension>'))
    h.method('lengths', fname='lens', returns='const std::vector<size_t>&', const=True)
    h.method('strides', returns='const std::vector<size_t>&', const=True)
    h.method('dyn_dims', returns='std::vector<migraphx::shape::dynamic_dimension>', const=True)
    h.method('type', returns='migraphx::shape::type_t', const=True)
    h.method('elements', returns='size_t', const=True)
    h.method('bytes', returns='size_t', const=True)
    h.method('ndim', returns='size_t', const=True)
    h.method('equal',
             api.params(x='const migraphx::shape&'),
             invoke='migraphx::equal($@)',
             returns='bool',
             const=True)
    h.method('standard', returns='bool', const=True)
    h.method('dynamic', returns='bool', const=True)
    h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@@ -130,6 +131,7 @@ def shape(h):
def argument(h):
    h.constructor('create', api.params(shape='const migraphx::shape&', buffer='void*'))
    h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
    h.method('shape',
             fname='get_shape',
             cpp_name='get_shape',
@@ -325,11 +327,22 @@ def onnx_options(h):
             api.params(name='const char*', dims='std::vector<size_t>'),
             invoke='migraphx::set_input_parameter_shape($@)')
    h.method('set_dyn_input_parameter_shape',
             api.params(name='const char*',
                        dims='std::vector<migraphx::shape::dynamic_dimension>'),
             invoke='migraphx::set_dyn_input_parameter_shape($@)')
    h.method('set_default_dim_value',
             api.params(value='size_t'),
             invoke='migraphx::set_default_dim_value($@)')
    h.method('set_default_dyn_dim_value',
             api.params(dd='const migraphx::shape::dynamic_dimension&'),
             invoke='migraphx::set_default_dyn_dim_value($@)')
    h.method('set_default_loop_iterations',
             api.params(value='int64_t'),
@@ -354,6 +367,9 @@ def compile_options(h):
    h.method('set_fast_math', api.params(value='bool'), invoke='migraphx::set_fast_math($@)')
    h.method('set_exhaustive_tune_flag',
             api.params(value='bool'),
             invoke='migraphx::set_exhaustive_tune_flag($@)')


api.add_function('migraphx_parse_onnx',
src/common.cpp

@@ -89,8 +89,8 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
            }
            else if(a == 1 or b == 1)
            {
                // setting opt to 0, may need to be changed
                return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
                // setting optimals to empty, may need to be changed
                return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max)};
            }
            else
            {
@@ -148,10 +148,8 @@ shape common_shape(const std::vector<shape>& shapes)
    return {compute_common_types(shapes), compute_common_lens(shapes)};
}

instruction_ref insert_common_op(module& m, instruction_ref ins, const operation& op, std::vector<instruction_ref> inputs)
std::vector<instruction_ref> insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs)
{
    if(std::any_of(inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
@@ -210,7 +208,20 @@ instruction_ref insert_common_op(module& m,
            return input;
        });
    }
    return m.insert_instruction(ins, op, inputs);
    return inputs;
}

std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs)
{
    return insert_common_args(m, m.end(), std::move(inputs));
}

instruction_ref insert_common_op(module& m, instruction_ref ins, const operation& op, std::vector<instruction_ref> inputs)
{
    return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
}

instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
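With the broadcasting and conversion logic factored into insert_common_args, insert_common_op becomes a thin composition of that helper with insert_instruction, and the new add_common_args gives callers access to just the prepared arguments. The rough sketch below is not from the commit; the parameter shapes are placeholders and the equivalence in the comment is approximate.

#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

void sketch()
{
    migraphx::module m;
    auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3}});
    auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1}});
    // Broadcasts/converts the inputs as needed, then appends the op on the adjusted arguments;
    // roughly equivalent to m.add_instruction(make_op("add"), add_common_args(m, {x, y})).
    auto sum = migraphx::add_common_op(m, migraphx::make_op("add"), {x, y});
    (void)sum;
}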
src/compile_src.cpp

@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
    if(not fs::exists(out_path))
        MIGRAPHX_THROW("Output file missing: " + out);
    if(process)
        out_path = process(out_path);
    return read_buffer(out_path.string());
}
src/cpp_generator.cpp

@@ -106,6 +106,18 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
    return *this;
}

cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
    body.insert(0, "(void)" + pname + ";\n");
    return *this;
}

cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{
    params.push_back({pname, "T" + pname});
    tparams.push_back("class T" + pname);
    return *this;
}

struct cpp_generator_impl
{
    std::stringstream fs{};
@@ -167,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
        else if(with_char(::isdigit)(key[0]))
        {
            auto i = std::stoul(key);
            if(i >= args.size())
                MIGRAPHX_THROW("Invalid argument index: " + key);
            return args.at(i);
        }
        else if(v.contains(key))
@@ -182,7 +196,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
std::string cpp_generator::str() const { return impl->fs.str(); }

cpp_generator::function cpp_generator::generate_module(const module& m)
cpp_generator::function cpp_generator::generate_module(const module& m, const generate_module_callback& g)
{
    function f;
    auto name = transform_string(m.name(), [](char c) {
@@ -195,13 +210,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
        if(ins->name() == "@literal")
            return shape::cpp_type(ins->get_shape().type()) + "(" + ins->get_literal().to_string() + ")";
        std::vector<std::string> args;
        std::transform(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(args), [&](auto i) { return names.at(i); });
        auto s = this->generate_point_op(ins->get_operator(), args);
        auto s = g(ins, names);
        if(impl->fresult)
            return impl->fresult(ins->get_shape()) + '(' + s + ')';
        else
@@ -210,6 +219,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
    return f;
}

std::vector<std::string> cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
                                                const std::unordered_map<instruction_ref, std::string>& names)
{
    std::vector<std::string> args;
    std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) { return names.at(i); });
    return args;
}

cpp_generator::function cpp_generator::generate_module(const module& m)
{
    return this->generate_module(m, [&](auto ins, const auto& names) {
        return this->generate_point_op(ins->get_operator(), to_args(ins->inputs(), names));
    });
}

std::string cpp_generator::create_function(const cpp_generator::function& f)
{
    impl->function_count++;
@@ -218,6 +245,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
    std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
    impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
    char delim = '(';
    if(f.params.empty())
        impl->fs << delim;
    for(auto&& p : f.params)
    {
        impl->fs << delim << p.type << " " << p.name;
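The callback overload of generate_module lets a caller decide how each instruction is rendered while reusing the surrounding function scaffolding; the original one-argument overload now just forwards a default callback. The sketch below is not part of the commit and assumes a cpp_generator g and a module m already prepared; it only uses the callback shape visible in the hunk above.

// Render each instruction with the default point-op expansion, but append the
// instruction name as a trailing comment (illustrative only).
auto f = g.generate_module(m, [&](auto ins, const auto& names) {
    std::vector<std::string> args;
    for(auto input : ins->inputs())
        args.push_back(names.at(input));
    return g.generate_point_op(ins->get_operator(), args) + " /* " + ins->name() + " */";
});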
src/driver/CMakeLists.txt

@@ -42,6 +42,7 @@ add_custom_command(
)
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
rocm_install_targets(
src/driver/argument_parser.hpp

@@ -148,13 +148,21 @@ struct value_parser
    template <MIGRAPHX_REQUIRES(not std::is_enum<T>{} and not is_multi_value<T>{})>
    static T apply(const std::string& x)
    {
        T result;
        std::stringstream ss;
        ss.str(x);
        ss >> result;
        if(ss.fail())
            throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
        return result;
        // handle whitespace in string
        if constexpr(std::is_same<T, std::string>{})
        {
            return x;
        }
        else
        {
            T result;
            std::stringstream ss;
            ss.str(x);
            ss >> result;
            if(ss.fail())
                throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
            return result;
        }
    }

    template <MIGRAPHX_REQUIRES(std::is_enum<T>{} and not is_multi_value<T>{})>
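The std::string special case exists because stream extraction stops at the first whitespace, which would truncate option values such as the JSON snippets used by the new dynamic-dimension flags. A minimal stand-alone illustration, not from the commit:

#include <cassert>
#include <sstream>
#include <string>

int main()
{
    std::string x = "{min:1, max:64}";
    std::stringstream ss;
    ss.str(x);
    std::string parsed;
    ss >> parsed;                // extraction stops at the space
    assert(parsed == "{min:1,"); // the old code path would return this truncated value
    return 0;
}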
src/driver/main.cpp

@@ -21,6 +21,7 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
@@ -32,6 +33,7 @@
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/json.hpp>
#include <migraphx/version.h>
@@ -67,7 +69,9 @@ struct loader
    bool brief = false;
    std::string output_type;
    std::string output;
    std::string default_dyn_dim;
    std::vector<std::string> param_dims;
    std::vector<std::string> dyn_param_dims;
    std::vector<std::string> output_names;

    void parse(argument_parser& ap)
@@ -82,7 +86,11 @@ struct loader
        ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
        ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
        ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json"));
        ap(batch, {"--batch"}, ap.help("Set batch size for model"));
        ap(batch,
           {"--batch"},
           ap.help("For a static model, sets default_dim_value size (commonly batch size). For a "
                   "dynamic batch model, sets the batch size at runtime."));
        ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
        ap(skip_unknown_operators,
           {"--skip-unknown-operators"},
@@ -95,7 +103,16 @@ struct loader
           ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
           ap.append(),
           ap.nargs(2));
        ap(dyn_param_dims,
           {"--dyn-input-dim"},
           ap.help("Dynamic dimensions of a parameter (format: \"@name_1\" \"[{min:x, max:y, "
                   "optimals:[o1,o2,...]}, dim2,dim3, ...]\", \"@name_2\", ... You can supply a "
                   "single integer value for a dimension to specify it as fixed."),
           ap.append(),
           ap.nargs(2));
        ap(default_dyn_dim,
           {"--default-dyn-dim"},
           ap.help("Default dynamic dimension (format: \"{min:x, max:y, optimals:[o1,o2]}\")."));
        ap(output_names,
           {"--output-names"},
           ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
@@ -146,6 +163,40 @@ struct loader
        return map_input_dims;
    }

    static auto parse_dyn_dims_json(const std::string& dd_json)
    {
        // expecting a json string like "[{min:1,max:64,optimals:[1,2,4,8]},3,224,224]"
        auto v = from_json_string(convert_to_json(dd_json));
        std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
        std::transform(v.begin(), v.end(), std::back_inserter(dyn_dims), [&](auto x) {
            if(x.is_object())
                return from_value<migraphx::shape::dynamic_dimension>(x);
            auto d = x.template to<std::size_t>();
            return migraphx::shape::dynamic_dimension{d, d};
        });
        return dyn_dims;
    }

    static auto parse_dyn_dims_map(const std::vector<std::string>& param_dyn_dims)
    {
        // expecting vector of strings formatted like
        // {"@param_name_0", "dd_json_0", "@param_name_1", "dd_json_1", ...}
        std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
        std::string name = "";
        for(auto&& x : param_dyn_dims)
        {
            if(x[0] == '@')
            {
                name = x.substr(1);
            }
            else
            {
                map_dyn_input_dims[name] = parse_dyn_dims_json(x);
            }
        }
        return map_dyn_input_dims;
    }

    static auto parse_output_names(const std::vector<std::string>& output_names_info)
    {
        std::vector<std::string> output_node_names;
@@ -157,13 +208,44 @@ struct loader
        return output_node_names;
    }

    tf_options get_tf_options() const
    {
        auto map_input_dims    = parse_param_dims(param_dims);
        auto output_node_names = parse_output_names(output_names);
        tf_options options;
        options.is_nhwc           = is_nhwc;
        options.batch_size        = batch;
        options.map_input_dims    = map_input_dims;
        options.output_node_names = output_node_names;
        return options;
    }

    onnx_options get_onnx_options() const
    {
        auto map_input_dims     = parse_param_dims(param_dims);
        auto map_dyn_input_dims = parse_dyn_dims_map(dyn_param_dims);
        onnx_options options;
        if(default_dyn_dim.empty())
        {
            options.default_dim_value = batch;
        }
        else
        {
            auto v                        = from_json_string(convert_to_json(default_dyn_dim));
            options.default_dyn_dim_value = from_value<migraphx::shape::dynamic_dimension>(v);
        }
        options.skip_unknown_operators = skip_unknown_operators;
        options.print_program_on_error = true;
        options.map_input_dims         = map_input_dims;
        options.map_dyn_input_dims     = map_dyn_input_dims;
        return options;
    }

    program load()
    {
        program p;
        if(model.empty())
        {
            auto map_input_dims    = parse_param_dims(param_dims);
            auto output_node_names = parse_output_names(output_names);
            if(file_type.empty())
            {
                if(ends_with(file, ".onnx"))
@@ -178,16 +260,11 @@ struct loader
            std::cout << "Reading: " << file << std::endl;
            if(file_type == "onnx")
            {
                onnx_options options;
                options.default_dim_value      = batch;
                options.skip_unknown_operators = skip_unknown_operators;
                options.print_program_on_error = true;
                options.map_input_dims         = map_input_dims;
                p = parse_onnx(file, options);
                p = parse_onnx(file, get_onnx_options());
            }
            else if(file_type == "tf")
            {
                p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims, output_node_names});
                p = parse_tf(file, get_tf_options());
            }
            else if(file_type == "json")
            {
@@ -288,14 +365,21 @@ struct program_params
        ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
    }

    auto generate(const program& p, const target& t, bool offload)
    auto generate(const program& p, const target& t, bool offload, unsigned batch)
    {
        parameter_map m;
        auto param_shapes = p.get_parameter_shapes();
        std::unordered_map<std::string, shape> static_param_shapes;
        std::transform(param_shapes.cbegin(),
                       param_shapes.cend(),
                       std::inserter(static_param_shapes, static_param_shapes.end()),
                       [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); });
        for(auto&& s : fill0)
            m[s] = fill_argument(p.get_parameter_shape(s), 0);
            m[s] = fill_argument(static_param_shapes.at(s), 0);
        for(auto&& s : fill1)
            m[s] = fill_argument(p.get_parameter_shape(s), 1);
        fill_param_map(m, p, t, offload);
            m[s] = fill_argument(static_param_shapes.at(s), 1);
        fill_param_map(m, static_param_shapes, t, offload);
        return m;
    }
};
@@ -304,8 +388,12 @@ struct compiler_target
{
#ifdef HAVE_GPU
    std::string target_name = "gpu";
#else
#elif defined(HAVE_CPU)
    std::string target_name = "cpu";
#elif defined(HAVE_FPGA)
    std::string target_name = "fpga";
#else
    std::string target_name = "ref";
#endif

    void parse(argument_parser& ap)
@@ -326,8 +414,7 @@ struct compiler
    loader l;
    program_params parameters;
    compiler_target ct;
    bool offload_copy = false;
    bool fast_math    = true;
    compile_options co;
    precision quantize = precision::fp32;
    std::vector<std::string> fill0;
@@ -337,19 +424,26 @@ struct compiler
        l.parse(ap);
        parameters.parse(ap);
        ct.parse(ap);
        ap(offload_copy,
        ap(co.offload_copy,
           {"--enable-offload-copy"},
           ap.help("Enable implicit offload copying"),
           ap.set_value(true));
        ap(fast_math,
        ap(co.fast_math,
           {"--disable-fast-math"},
           ap.help("Disable fast math optimization"),
           ap.set_value(false));
        ap(co.exhaustive_tune,
           {"--exhaustive-tune"},
           ap.help("Exhastively search for best tuning parameters for kernels"),
           ap.set_value(true));
        ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
        ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
    }

    auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); }
    auto params(const program& p)
    {
        return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
    }

    program compile()
    {
@@ -366,10 +460,7 @@ struct compiler
        {
            quantize_int8(p, t, {params(p)});
        }
        compile_options options;
        options.offload_copy = offload_copy;
        options.fast_math    = fast_math;
        p.compile(t, options);
        p.compile(t, co);
        l.save(p);
        return p;
    }
@@ -402,60 +493,41 @@ struct params : command<params>
struct verify : command<verify>
{
    loader l;
    program_params parameters;
    compiler_target ct;
    compiler c;
    double tolerance     = 80;
    bool per_instruction = false;
    bool reduce          = false;
    bool offload_copy    = false;
    bool fast_math       = true;
    precision quantize   = precision::fp32;

    void parse(argument_parser& ap)
    {
        l.parse(ap);
        parameters.parse(ap);
        ct.parse(ap);
        ap(offload_copy,
           {"--enable-offload-copy"},
           ap.help("Enable implicit offload copying"),
           ap.set_value(true));
        ap(fast_math,
           {"--disable-fast-math"},
           ap.help("Disable fast math optimization"),
           ap.set_value(false));
        c.parse(ap);
        ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
        ap(per_instruction,
           {"-i", "--per-instruction"},
           ap.help("Verify each instruction"),
           ap.set_value(true));
        ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
        ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
    }

    void run()
    {
        auto p = l.load();
        l.save(p);
        auto p = c.l.load();
        c.l.save(p);
        std::cout << p << std::endl;

        compile_options options;
        options.offload_copy = offload_copy;
        options.fast_math    = fast_math;
        auto t = ct.get_target();
        auto m = parameters.generate(p, t, true);
        auto t = c.ct.get_target();
        auto m = c.parameters.generate(p, t, true, c.l.batch);

        if(per_instruction)
        {
            verify_instructions(p, t, options, quantize, tolerance);
            verify_instructions(p, t, c.co, c.quantize, tolerance);
        }
        else if(reduce)
        {
            verify_reduced_program(p, t, options, quantize, m, tolerance);
            verify_reduced_program(p, t, c.co, c.quantize, m, tolerance);
        }
        else
        {
            verify_program(l.file, p, t, options, quantize, m, tolerance);
            verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance);
        }
    }
};
@@ -466,7 +538,8 @@ struct version : command<version>
    void run() const
    {
        std::cout << "MIGraphX Version: " << MIGRAPHX_VERSION_MAJOR << "." << MIGRAPHX_VERSION_MINOR << std::endl;
                  << "." << MIGRAPHX_VERSION_PATCH << "." << MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK) << std::endl;
    }
};
@@ -584,6 +657,26 @@ struct onnx : command<onnx>
    }
};

struct tf : command<tf>
{
    bool show_ops = false;
    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all tf operators supported by MIGraphX"),
           ap.set_value(true));
    }
    void run() const
    {
        if(show_ops)
        {
            for(const auto& name : get_tf_operators())
                std::cout << name << std::endl;
        }
    }
};

struct main_command
{
    static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
@@ -603,7 +696,9 @@ struct main_command
    void parse(argument_parser& ap)
    {
        std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." + std::to_string(MIGRAPHX_VERSION_MINOR);
                                  "." + std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) + "." + MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK);
        ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
        ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
        ap(nullptr,
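For reference, the JSON accepted by the new --dyn-input-dim and --default-dyn-dim driver flags maps directly onto shape::dynamic_dimension: an object supplies min, max, and optional optimals, while a bare integer becomes a fixed dimension with min == max. The sketch below is not part of the commit; it shows the value parse_dyn_dims_json would produce for the string "[{min:1,max:64,optimals:[1,2,4,8]},3,224,224]", assuming dynamic_dimension stays brace-initializable as used elsewhere in this diff.

#include <migraphx/shape.hpp>
#include <vector>

std::vector<migraphx::shape::dynamic_dimension> example_dyn_dims()
{
    return {
        {1, 64, {1, 2, 4, 8}}, // {min:1,max:64,optimals:[1,2,4,8]}
        {3, 3},                // a bare 3 becomes a fixed dimension
        {224, 224},
        {224, 224},
    };
}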
src/driver/perf.cpp

@@ -39,36 +39,25 @@ auto get_hash(const T& x)
    return std::hash<T>{}(x);
}

parameter_map fill_param_map(parameter_map& m, const program& p, const target& t, bool offload)
parameter_map fill_param_map(parameter_map& m,
                             const std::unordered_map<std::string, shape>& param_shapes,
                             const target& t,
                             bool offload)
{
    for(auto&& x : p.get_parameter_shapes())
    for(auto&& x : param_shapes)
    {
        argument& arg = m[x.first];
        if(arg.empty())
        {
            assert(not x.second.dynamic());
            arg = generate_argument(x.second, get_hash(x.first));
        }
        if(not offload)
            arg = t.copy_to(arg);
    }
    return m;
}

parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu)
{
    for(auto&& x : p.get_parameter_shapes())
    {
        argument& arg = m[x.first];
        if(arg.empty())
            arg = generate_argument(x.second, get_hash(x.first));
#ifdef HAVE_GPU
        if(gpu)
            arg = gpu::to_gpu(arg);
#else
        (void)gpu;
#endif
    }
    return m;
}

parameter_map create_param_map(const program& p, const target& t, bool offload)
{
    parameter_map m;
@@ -108,8 +97,6 @@ target get_target(bool gpu)
    return make_target("cpu");
}

void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
src/driver/perf.hpp
View file @
baac1dab
...
...
@@ -30,14 +30,15 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

-parameter_map fill_param_map(parameter_map& m, const program& p, const target& t, bool offload = false);
+parameter_map fill_param_map(parameter_map& m, const std::unordered_map<std::string, shape>& param_shapes,
+                             const target& t, bool offload = false);
parameter_map create_param_map(const program& p, const target& t, bool offload = false);
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu);
void compile_program(program& p, bool gpu = true);

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
...
...
src/driver/verify.cpp
View file @
baac1dab
...
...
@@ -24,7 +24,7 @@
#include "verify.hpp"
#include "perf.hpp"
-#include <migraphx/ref/target.hpp>
+#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::vector<argument> run_ref(program p, const parameter_map& inputs)
{
-    p.compile(ref::target{});
+    p.compile(migraphx::make_target("ref"));
    auto out = p.eval(inputs);
    std::cout << p << std::endl;
    return out;
...
...
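The run_ref change above replaces the hard dependency on ref::target with a lookup through the target registry. A minimal sketch of the same pattern (the function name is illustrative):

#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>

// Compile a program for the reference target by name, as run_ref now does,
// instead of constructing ref::target directly.
void compile_for_ref(migraphx::program& p)
{
    p.compile(migraphx::make_target("ref"));
}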
src/dynamic_loader.cpp
View file @
baac1dab
...
...
@@ -21,25 +21,44 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
+#include <migraphx/manage_ptr.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#include <dlfcn.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

+void check_load_error(bool flush = false)
+{
+    char* error_msg = dlerror();
+    if(not flush and error_msg != nullptr)
+        MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(error_msg));
+}

struct dynamic_loader_impl
{
    dynamic_loader_impl() = default;

+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
    dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
-        : handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t))
+        : handle(dlopen(p.string().c_str(), RTLD_LAZY),
+                 manage_deleter<decltype(&dlclose), &dlclose>{}),
+          temp(std::move(t))
    {
+        check_load_error();
    }
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif

    static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
    {
        auto t = std::make_shared<tmp_dir>("dloader");
...
...
@@ -52,6 +71,16 @@ struct dynamic_loader_impl
    std::shared_ptr<tmp_dir> temp = nullptr;
};

+fs::path dynamic_loader::path(void* address)
+{
+    fs::path p;
+    Dl_info info;
+    // Find the location of .so
+    if(dladdr(address, &info) != 0)
+        p = info.dli_fname;
+    return p;
+}

dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) {}
...
...
@@ -68,10 +97,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
-    dlerror(); // flush any previous error messages
+    check_load_error(true);
    void* symbol = dlsym(impl->handle.get(), name.c_str());
    if(symbol == nullptr)
        MIGRAPHX_THROW("Symbol not found: " + name);
+    check_load_error();
    return {impl, symbol};
}
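With check_load_error in place, dlopen and dlsym failures surface as MIGRAPHX_THROW exceptions instead of being silently dropped. A small usage sketch of the loader, not from this commit (the library path and symbol name are hypothetical):

#include <migraphx/dynamic_loader.hpp>
#include <iostream>
#include <memory>

void load_plugin_example()
{
    // Hypothetical shared library and entry point, for illustration only
    migraphx::dynamic_loader lib{"./libexample_plugin.so"};

    // get_symbol returns a shared_ptr<void> that shares ownership with the
    // loaded library, so the handle stays open while the symbol is in use
    std::shared_ptr<void> sym = lib.get_symbol("example_entry_point");

    auto* fn = reinterpret_cast<int (*)()>(sym.get());
    std::cout << fn() << std::endl;
}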
...
...
src/eliminate_allocation.cpp
View file @
baac1dab
...
...
@@ -28,11 +28,8 @@
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
...
src/eliminate_data_type.cpp
View file @
baac1dab
...
...
@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const
"if"
,
"loop"
,
"roialign"
,
"nonmaxsuppression"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
...
...
src/fuse_pointwise.cpp
View file @
baac1dab
...
...
@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp>
#include <iterator>
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
...
@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
            continue;
        if(ins->get_operator().name() == "layout")
            continue;
        assert(ins->get_operator().attributes().contains("point_op"));
        auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
        pm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map;
        std::vector<instruction_ref> pointwise_inputs;
        std::size_t i = 0;
        for(auto input : ins->inputs())
        {
            if(contains(param_map, input))
...
...
@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
            }
        }
+       // Don't create pointwise module if no inputs are detected
+       if(pointwise_inputs.empty())
+           continue;
        std::vector<instruction_ref> inputs;
        std::transform(ins->inputs().begin(),
                       ins->inputs().end(),
...
...
@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{
    create_pointwise_modules(mpm);
    mpm.run_pass(dead_code_elimination{});
+   if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
+   {
+       return;
+   }
    for(int i = 0; i < 8; i++)
    {
        if(not find_pointwise_modules(mpm.get_module()))
...
...
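The new MIGRAPHX_DISABLE_POINTWISE_FUSION switch follows MIGraphX's usual environment-variable pattern: declare a tag with MIGRAPHX_DECLARE_ENV_VAR and test it with enabled(). A minimal sketch of that pattern with a hypothetical flag name:

#include <migraphx/env.hpp>

// Hypothetical flag, shown only to illustrate the declare/enabled pattern
// used by MIGRAPHX_DISABLE_POINTWISE_FUSION above
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MY_PASS)

void apply_my_pass()
{
    if(migraphx::enabled(MIGRAPHX_DISABLE_MY_PASS{}))
        return; // skip the pass when the variable is set
    // ... pass body would go here ...
}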
src/fuse_reduce.cpp
0 → 100644
View file @
baac1dab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        auto* sm = mods.front();
        if(sm->get_output_shapes().size() != 1)
            MIGRAPHX_THROW("Only one output supported");
        auto names = sm->get_parameter_names();
        check_shapes{inputs, *this}.has(names.size()).same_ndims();
        std::sort(names.begin(), names.end());
        auto shapes = sm->get_parameter_shapes();
        // Check dimension matches for each input
        if(not equal(names, inputs, [&](const auto& name, const auto& input) {
               return shapes.at(name).lens() == input.lens();
           }))
            MIGRAPHX_THROW("Dimension does not match the submodule.");
        const auto& s = inputs.at(0);
        auto lens     = s.lens();
        if(lens != sm->get_output_shapes().front().lens())
        {
            for(const auto& axis : axes)
            {
                lens[axis] = 1;
            }
        }

        return shape::from_permutation(
            sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
    }

    std::string name() const { return "fused_reduce"; }
};
MIGRAPHX_REGISTER_OP(fused_reduce);
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
    std::unordered_map<instruction_ref, instruction_ref> result;
    auto names = sm->get_parameter_names();
    std::sort(names.begin(), names.end());
    assert(names.size() == inputs.size());
    std::transform(names.begin(),
                   names.end(),
                   inputs.begin(),
                   std::inserter(result, result.end()),
                   [&](const auto& name, auto input) {
                       return std::make_pair(input, sm->get_parameter(name));
                   });
    return result;
}
static void insert_params(module_ref sm,
                          instruction_ref ins,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto n = sm->get_parameter_shapes().size();
    for(auto input : ins->inputs())
    {
        if(contains(map_ins, input))
            continue;
        auto s = shape{input->get_shape().type(), input->get_shape().lens()};
        map_ins[input] = sm->add_parameter("x" + std::to_string(n++), s);
    }
}

static auto insert_ins_in_submodule(module_ref sm,
                                    instruction_ref ins,
                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);
    return sm->add_instructions({ins}, map_ins);
}

static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    return insert_ins_in_submodule(sm, ins, map_ins);
}
static auto insert_module_in_submodule(module_ref sm,
                                       instruction_ref ins,
                                       std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);
    auto* m        = ins->module_inputs().front();
    auto param_map = get_ins_param_map(ins->inputs(), m);
    for(auto&& [input, param] : param_map)
    {
        map_ins[param] = map_ins.at(input);
    }
    return sm->add_instructions(m, map_ins);
}
static std::vector<instruction_ref>
find_inputs(module_ref sm,
            const module& parent,
            const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    std::vector<instruction_ref> result;
    std::map<std::string, instruction_ref> names;
    for(auto&& [input, param] : map_ins)
    {
        if(not sm->has_instruction(param))
            continue;
        if(param->name() != "@param")
            continue;
        if(not parent.has_instruction(input))
            continue;
        auto v      = param->get_operator().to_value();
        auto name   = v.at("parameter").to<std::string>();
        names[name] = input;
    }
    std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
        return p.second;
    });
    assert(result.size() == sm->get_parameter_shapes().size());
    return result;
}
static void create_reduce_modules(module_pass_manager& mpm)
{
    std::size_t n = 0;
    for(auto ins : iterator_for(mpm.get_module()))
    {
        if(not ins->get_operator().attributes().get("reduce", false))
            continue;
        if(ins->inputs().size() != 1)
            continue;

        auto* rm =
            mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
        rm->set_bypass();

        rm->add_return(insert_ins_in_submodule(rm, ins));
        auto v = ins->get_operator().to_value();
        mpm.get_module().replace_instruction(
            ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
    }
}
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
    return match::skip(match::name("contiguous"))(
        match::name("multibroadcast")(match::arg(0)(ms...), match::used_once()).bind("broadcast"));
}

template <class... Ms>
static auto any_input(Ms... ms)
{
    return match::any_of[match::inputs()](match::any(ms...).bind("input"));
}

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
    auto match_op                 = match::name(op)(match::used_once()).bind(name);
    auto match_op_input           = any_input(match_op, match::used_once());
    auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
    return match::any_of(match_op_input, broadcast_match_op_input);
}
namespace {
struct find_pointwise_reduce
{
    auto matcher() const
    {
        return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce = r.result;
        auto input  = r.instructions["pointwise"];

        const auto* pm     = input->module_inputs().front();
        const auto* old_rm = reduce->module_inputs().front();
        auto* rm           = mpm.create_module(pm->name() + ":" + old_rm->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Insert pointwise
        auto rins      = insert_ins_in_submodule(rm, input, map_ins).front();
        map_ins[input] = rins;

        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast     = r.instructions["broadcast"];
            map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front();
        }

        // Insert fused_reduce
        rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));

        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
        mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
    }
};
struct find_reduce_pointwise
{
    auto matcher() const
    {
        return match::name("pointwise")(match_broadcastable_input("fused_reduce", "reduce"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto pw     = r.result;
        auto reduce = r.instructions["reduce"];
        auto input  = r.instructions["input"];

        const auto* pm     = pw->module_inputs().front();
        const auto* old_rm = reduce->module_inputs().front();
        auto* rm           = mpm.create_module(old_rm->name() + ":" + pm->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Copy module instructions
        insert_module_in_submodule(rm, reduce, map_ins);
        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = rm->get_returns().front();
            auto bout      = insert_ins_in_submodule(rm, broadcast, map_ins);
            map_ins[input] = bout.front();
        }
        else
        {
            map_ins[input] = rm->get_returns().front();
        }

        auto out = insert_ins_in_submodule(rm, pw, map_ins);
        rm->replace_return(out);

        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
        mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
    }
};
struct find_reduce_reduce
{
    auto matcher() const
    {
        return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce1 = r.result;
        auto reduce2 = r.instructions["reduce"];
        auto input   = r.instructions["input"];

        if(reduce1->get_operator() != reduce2->get_operator())
            return;

        const auto* rm1 = reduce1->module_inputs().front();
        const auto* rm2 = reduce2->module_inputs().front();
        auto* rm        = mpm.create_module(rm1->name() + ":" + rm2->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Copy reduce1 instructions
        insert_module_in_submodule(rm, reduce2, map_ins);
        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = rm->get_returns().front();
            auto bout      = insert_ins_in_submodule(rm, broadcast, map_ins);
            map_ins[input] = bout.front();
        }
        else
        {
            map_ins[input] = rm->get_returns().front();
        }

        auto out = insert_module_in_submodule(rm, reduce1, map_ins);
        rm->replace_return(out);

        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
        mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
    }
};
} // namespace

void fuse_reduce::apply(module_pass_manager& mpm) const
{
    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});
    for(int i = 0; i < 4; i++)
    {
        match::find_matches(
            mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
        mpm.run_pass(dead_code_elimination{});
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
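A short sketch of how the new pass could be exercised directly through the pass manager; the pipeline below is illustrative and is not how the library itself invokes the pass:

#include <migraphx/program.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>

// Run reduce fusion followed by dead-code elimination over a program's modules
void run_reduce_fusion(migraphx::program& p)
{
    migraphx::run_passes(p, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}});
}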