Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
36bb977b
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "8930d23d80b250b46f822473cf6d4e9e3af8c4de"
Commit
36bb977b
authored
Aug 07, 2023
by
Brian Pickrell
Browse files
Merge branch 'develop' into rand_uniform
parents
d626a09e
c65ab678
Changes
52
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
492 additions
and
330 deletions
+492
-330
CMakeLists.txt
CMakeLists.txt
+1
-0
Jenkinsfile
Jenkinsfile
+1
-1
cmake/Embed.cmake
cmake/Embed.cmake
+32
-16
cmake/PythonModules.cmake
cmake/PythonModules.cmake
+10
-0
docs/driver/read.rst
docs/driver/read.rst
+4
-0
src/CMakeLists.txt
src/CMakeLists.txt
+0
-2
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+1
-0
src/api/api.cpp
src/api/api.cpp
+1
-1
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+297
-265
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+1
-1
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+14
-3
src/driver/main.cpp
src/driver/main.cpp
+22
-9
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+13
-1
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+37
-0
src/include/migraphx/builtin.hpp
src/include/migraphx/builtin.hpp
+11
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+27
-22
src/include/migraphx/dynamic_loader.hpp
src/include/migraphx/dynamic_loader.hpp
+4
-0
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+10
-0
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+4
-0
src/instruction.cpp
src/instruction.cpp
+2
-8
No files found.
CMakeLists.txt
View file @
36bb977b
...
@@ -51,6 +51,7 @@ project(migraphx LANGUAGES C CXX)
...
@@ -51,6 +51,7 @@ project(migraphx LANGUAGES C CXX)
include
(
CTest
)
include
(
CTest
)
find_package
(
ROCM REQUIRED
)
find_package
(
ROCM REQUIRED
)
find_package
(
Threads REQUIRED
)
find_path
(
HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half
)
find_path
(
HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half
)
if
(
NOT HALF_INCLUDE_DIR
)
if
(
NOT HALF_INCLUDE_DIR
)
...
...
Jenkinsfile
View file @
36bb977b
...
@@ -26,7 +26,7 @@ def rocmtestnode(Map conf) {
...
@@ -26,7 +26,7 @@ def rocmtestnode(Map conf) {
rm -rf build
rm -rf build
mkdir build
mkdir build
cd build
cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On ${flags} ..
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On
-DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT
${flags} ..
git diff
git diff
git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
make -j\$(nproc) generate VERBOSE=1
make -j\$(nproc) generate VERBOSE=1
...
...
cmake/Embed.cmake
View file @
36bb977b
...
@@ -24,11 +24,7 @@
...
@@ -24,11 +24,7 @@
find_program
(
EMBED_LD ld
)
find_program
(
EMBED_LD ld
)
find_program
(
EMBED_OBJCOPY objcopy
)
find_program
(
EMBED_OBJCOPY objcopy
)
if
(
LINUX
)
option
(
EMBED_USE_LD
"Use ld to embed data files"
OFF
)
option
(
EMBED_USE_LD
"Use ld to embed data files"
ON
)
else
()
option
(
EMBED_USE_LD
"Use ld to embed data files"
OFF
)
endif
()
function
(
wrap_string
)
function
(
wrap_string
)
set
(
options
)
set
(
options
)
...
@@ -60,8 +56,8 @@ endfunction()
...
@@ -60,8 +56,8 @@ endfunction()
function
(
generate_embed_source EMBED_NAME
)
function
(
generate_embed_source EMBED_NAME
)
set
(
options
)
set
(
options
)
set
(
oneValueArgs SRC HEADER
)
set
(
oneValueArgs SRC HEADER
RELATIVE
)
set
(
multiValueArgs OBJECTS SYMBOLS
)
set
(
multiValueArgs OBJECTS SYMBOLS
FILES
)
cmake_parse_arguments
(
PARSE
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
cmake_parse_arguments
(
PARSE
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
...
@@ -78,6 +74,8 @@ function(generate_embed_source EMBED_NAME)
...
@@ -78,6 +74,8 @@ function(generate_embed_source EMBED_NAME)
foreach
(
idx RANGE
${
LEN
}
)
foreach
(
idx RANGE
${
LEN
}
)
list
(
GET PARSE_SYMBOLS
${
idx
}
SYMBOL
)
list
(
GET PARSE_SYMBOLS
${
idx
}
SYMBOL
)
list
(
GET PARSE_OBJECTS
${
idx
}
OBJECT
)
list
(
GET PARSE_OBJECTS
${
idx
}
OBJECT
)
list
(
GET PARSE_FILES
${
idx
}
FILE
)
set
(
START_SYMBOL
"_binary_
${
SYMBOL
}
_start"
)
set
(
START_SYMBOL
"_binary_
${
SYMBOL
}
_start"
)
set
(
END_SYMBOL
"_binary_
${
SYMBOL
}
_end"
)
set
(
END_SYMBOL
"_binary_
${
SYMBOL
}
_end"
)
if
(
EMBED_USE_LD
)
if
(
EMBED_USE_LD
)
...
@@ -92,9 +90,11 @@ function(generate_embed_source EMBED_NAME)
...
@@ -92,9 +90,11 @@ function(generate_embed_source EMBED_NAME)
"
)
"
)
endif
()
endif
()
# TODO: Should use NAME_WLE
if
(
PARSE_RELATIVE
)
get_filename_component
(
BASE_NAME
"
${
OBJECT
}
"
NAME
)
file
(
RELATIVE_PATH BASE_NAME
${
PARSE_RELATIVE
}
"
${
FILE
}
"
)
string
(
REGEX REPLACE
".[A-Za-z0-9_]+$"
""
BASE_NAME
${
BASE_NAME
}
)
else
()
get_filename_component
(
BASE_NAME
"
${
FILE
}
"
NAME
)
endif
()
string
(
APPEND INIT_KERNELS
"
string
(
APPEND INIT_KERNELS
"
{
\"
${
BASE_NAME
}
\"
, {
${
START_SYMBOL
}
,
${
END_SYMBOL
}
} },
{
\"
${
BASE_NAME
}
\"
, {
${
START_SYMBOL
}
,
${
END_SYMBOL
}
} },
...
@@ -162,6 +162,11 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
...
@@ -162,6 +162,11 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
endfunction
()
endfunction
()
function
(
add_embed_library EMBED_NAME
)
function
(
add_embed_library EMBED_NAME
)
set
(
options
)
set
(
oneValueArgs RELATIVE
)
set
(
multiValueArgs
)
cmake_parse_arguments
(
PARSE
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
file
(
MAKE_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/embed
)
file
(
MAKE_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/embed
)
file
(
MAKE_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/
${
EMBED_NAME
}
)
file
(
MAKE_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/
${
EMBED_NAME
}
)
set
(
EMBED_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/
${
EMBED_NAME
}
)
set
(
EMBED_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/
${
EMBED_NAME
}
)
...
@@ -171,15 +176,26 @@ function(add_embed_library EMBED_NAME)
...
@@ -171,15 +176,26 @@ function(add_embed_library EMBED_NAME)
set
(
OUTPUT_FILES
)
set
(
OUTPUT_FILES
)
set
(
SYMBOLS
)
set
(
SYMBOLS
)
message
(
STATUS
"Embedding files"
)
message
(
STATUS
"Embedding files"
)
foreach
(
FILE
${
ARGN
}
)
foreach
(
FILE
${
PARSE_UNPARSED_ARGUMENTS
}
)
embed_file
(
OUTPUT_FILE OUTPUT_SYMBOL
${
FILE
}
)
embed_file
(
OUTPUT_FILE OUTPUT_SYMBOL
${
FILE
}
)
list
(
APPEND OUTPUT_FILES
${
OUTPUT_FILE
}
)
list
(
APPEND OUTPUT_FILES
${
OUTPUT_FILE
}
)
list
(
APPEND SYMBOLS
${
OUTPUT_SYMBOL
}
)
list
(
APPEND SYMBOLS
${
OUTPUT_SYMBOL
}
)
endforeach
()
endforeach
()
message
(
STATUS
"Generating embedding library
${
EMBED_NAME
}
"
)
message
(
STATUS
"Generating embedding library
${
EMBED_NAME
}
"
)
generate_embed_source
(
${
EMBED_NAME
}
SRC
${
SRC_FILE
}
HEADER
${
HEADER_FILE
}
OBJECTS
${
OUTPUT_FILES
}
SYMBOLS
${
SYMBOLS
}
)
generate_embed_source
(
${
EMBED_NAME
}
SRC
${
SRC_FILE
}
HEADER
${
HEADER_FILE
}
OBJECTS
${
OUTPUT_FILES
}
SYMBOLS
${
SYMBOLS
}
RELATIVE
${
PARSE_RELATIVE
}
FILES
${
PARSE_UNPARSED_ARGUMENTS
}
)
add_library
(
${
EMBED_NAME
}
STATIC
${
OUTPUT_FILES
}
"
${
SRC_FILE
}
"
)
target_include_directories
(
${
EMBED_NAME
}
PUBLIC
"
${
EMBED_DIR
}
/include"
)
set
(
INTERNAL_EMBED_LIB embed_lib_
${
EMBED_NAME
}
)
target_compile_options
(
${
EMBED_NAME
}
PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations
)
add_library
(
${
INTERNAL_EMBED_LIB
}
OBJECT
"
${
SRC_FILE
}
"
)
set_target_properties
(
${
EMBED_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE On
)
target_include_directories
(
${
INTERNAL_EMBED_LIB
}
PRIVATE
"
${
EMBED_DIR
}
/include"
)
target_compile_options
(
${
INTERNAL_EMBED_LIB
}
PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations
)
set_target_properties
(
${
INTERNAL_EMBED_LIB
}
PROPERTIES POSITION_INDEPENDENT_CODE On
)
add_library
(
${
EMBED_NAME
}
INTERFACE
)
if
(
EMBED_USE_LD
)
target_sources
(
${
EMBED_NAME
}
INTERFACE
${
OUTPUT_FILES
}
)
else
()
target_sources
(
${
INTERNAL_EMBED_LIB
}
PRIVATE
${
OUTPUT_FILES
}
)
endif
()
target_sources
(
${
EMBED_NAME
}
INTERFACE $<TARGET_OBJECTS:
${
INTERNAL_EMBED_LIB
}
>
)
target_include_directories
(
${
EMBED_NAME
}
INTERFACE
"
${
EMBED_DIR
}
/include"
)
endfunction
()
endfunction
()
cmake/PythonModules.cmake
View file @
36bb977b
...
@@ -38,12 +38,22 @@ macro(find_python version)
...
@@ -38,12 +38,22 @@ macro(find_python version)
find_program
(
PYTHON_CONFIG_
${
version
}
python
${
version
}
-config
)
find_program
(
PYTHON_CONFIG_
${
version
}
python
${
version
}
-config
)
if
(
EXISTS
${
PYTHON_CONFIG_
${
version
}}
)
if
(
EXISTS
${
PYTHON_CONFIG_
${
version
}}
)
py_exec
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--includes OUTPUT_VARIABLE _python_include_args
)
py_exec
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--includes OUTPUT_VARIABLE _python_include_args
)
execute_process
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--ldflags --embed OUTPUT_VARIABLE _python_ldflags_args RESULT_VARIABLE _python_ldflags_result
)
if
(
NOT _python_ldflags_result EQUAL 0
)
py_exec
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--ldflags OUTPUT_VARIABLE _python_ldflags_args
)
endif
()
separate_arguments
(
_python_includes UNIX_COMMAND
"
${
_python_include_args
}
"
)
separate_arguments
(
_python_includes UNIX_COMMAND
"
${
_python_include_args
}
"
)
separate_arguments
(
_python_ldflags UNIX_COMMAND
"
${
_python_ldflags_args
}
"
)
string
(
REPLACE
"-I"
""
_python_includes
"
${
_python_includes
}
"
)
string
(
REPLACE
"-I"
""
_python_includes
"
${
_python_includes
}
"
)
add_library
(
python
${
version
}
::headers INTERFACE IMPORTED GLOBAL
)
add_library
(
python
${
version
}
::headers INTERFACE IMPORTED GLOBAL
)
set_target_properties
(
python
${
version
}
::headers PROPERTIES
set_target_properties
(
python
${
version
}
::headers PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES
"
${
_python_includes
}
"
INTERFACE_INCLUDE_DIRECTORIES
"
${
_python_includes
}
"
)
)
add_library
(
python
${
version
}
::runtime INTERFACE IMPORTED GLOBAL
)
set_target_properties
(
python
${
version
}
::runtime PROPERTIES
INTERFACE_LINK_OPTIONS
"
${
_python_ldflags
}
"
INTERFACE_LINK_LIBRARIES python
${
version
}
::headers
)
py_exec
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--prefix OUTPUT_VARIABLE _python_prefix
)
py_exec
(
COMMAND
${
PYTHON_CONFIG_
${
version
}}
--prefix OUTPUT_VARIABLE _python_prefix
)
string
(
STRIP
"
${
_python_prefix
}
"
_python_prefix
)
string
(
STRIP
"
${
_python_prefix
}
"
_python_prefix
)
set
(
PYTHON_
${
version
}
_EXECUTABLE
"
${
_python_prefix
}
/bin/python
${
version
}
"
CACHE PATH
""
)
set
(
PYTHON_
${
version
}
_EXECUTABLE
"
${
_python_prefix
}
/bin/python
${
version
}
"
CACHE PATH
""
)
...
...
docs/driver/read.rst
View file @
36bb977b
...
@@ -82,6 +82,10 @@ Print out program in text format.
...
@@ -82,6 +82,10 @@ Print out program in text format.
Print out program in binary format.
Print out program in binary format.
.. option:: --py
Print out program using python API.
.. option:: --output, -o [std::string]
.. option:: --output, -o [std::string]
Output to file.
Output to file.
...
...
src/CMakeLists.txt
View file @
36bb977b
...
@@ -249,8 +249,6 @@ endif()
...
@@ -249,8 +249,6 @@ endif()
target_link_libraries
(
migraphx PRIVATE -ldl
)
target_link_libraries
(
migraphx PRIVATE -ldl
)
target_include_directories
(
migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_include_directories
(
migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
find_package
(
Threads
)
target_link_libraries
(
migraphx PUBLIC Threads::Threads
)
target_link_libraries
(
migraphx PUBLIC Threads::Threads
)
find_package
(
nlohmann_json 3.8.0 REQUIRED
)
find_package
(
nlohmann_json 3.8.0 REQUIRED
)
...
...
src/api/CMakeLists.txt
View file @
36bb977b
...
@@ -26,6 +26,7 @@ add_library(migraphx_c
...
@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp
api.cpp
)
)
set_target_properties
(
migraphx_c PROPERTIES EXPORT_NAME c
)
set_target_properties
(
migraphx_c PROPERTIES EXPORT_NAME c
)
migraphx_generate_export_header
(
migraphx_c DIRECTORY migraphx/api
)
# migraphx_c is stable API interface library. SO version of this should be
# migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken.
# bumped when binary compatibility is broken.
...
...
src/api/api.cpp
View file @
36bb977b
...
@@ -44,7 +44,7 @@ namespace migraphx {
...
@@ -44,7 +44,7 @@ namespace migraphx {
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
...
...
src/api/include/migraphx/migraphx.h
View file @
36bb977b
This diff is collapsed.
Click to expand it.
src/driver/CMakeLists.txt
View file @
36bb977b
...
@@ -45,7 +45,7 @@ if(NOT WIN32)
...
@@ -45,7 +45,7 @@ if(NOT WIN32)
endif
()
endif
()
rocm_clang_tidy_check
(
driver
)
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
migraphx_py
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS driver
TARGETS driver
...
...
src/driver/argument_parser.hpp
View file @
36bb977b
...
@@ -342,7 +342,19 @@ struct argument_parser
...
@@ -342,7 +342,19 @@ struct argument_parser
if
(
params
.
empty
())
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
throw
std
::
runtime_error
(
"Path does not exist: "
+
params
.
back
());
});
}
MIGRAPHX_DRIVER_STATIC
auto
matches
(
const
std
::
unordered_set
<
std
::
string
>&
names
)
{
return
validate
([
=
](
auto
&
,
auto
&
,
auto
&
params
)
{
for
(
const
auto
&
p
:
params
)
{
if
(
names
.
count
(
p
)
==
0
)
throw
std
::
runtime_error
(
"Invalid argument: "
+
p
+
". Valid arguments are {"
+
to_string_range
(
names
)
+
"}"
);
}
});
});
}
}
...
@@ -570,8 +582,7 @@ struct argument_parser
...
@@ -570,8 +582,7 @@ struct argument_parser
continue
;
continue
;
if
(
flag
[
0
]
!=
'-'
)
if
(
flag
[
0
]
!=
'-'
)
continue
;
continue
;
auto
d
=
std
::
ptrdiff_t
d
=
levenshtein_distance
(
flag
,
input
);
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
if
(
d
<
result
.
distance
)
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
}
...
...
src/driver/main.cpp
View file @
36bb977b
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
...
@@ -81,6 +82,7 @@ struct loader
...
@@ -81,6 +82,7 @@ struct loader
{
"--model"
},
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
matches
({
"resnet50"
,
"inceptionv3"
,
"alexnet"
}),
ap
.
group
(
"input"
));
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
...
@@ -241,6 +243,20 @@ struct loader
...
@@ -241,6 +243,20 @@ struct loader
return
options
;
return
options
;
}
}
static
std
::
string
get_file_type
(
const
std
::
string
&
file
)
{
if
(
ends_with
(
file
,
".onnx"
))
return
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
return
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
return
"json"
;
else
if
(
ends_with
(
file
,
".py"
))
return
"py"
;
else
return
"migraphx"
;
}
program
load
()
program
load
()
{
{
program
p
;
program
p
;
...
@@ -248,14 +264,7 @@ struct loader
...
@@ -248,14 +264,7 @@ struct loader
{
{
if
(
file_type
.
empty
())
if
(
file_type
.
empty
())
{
{
if
(
ends_with
(
file
,
".onnx"
))
file_type
=
get_file_type
(
file
);
file_type
=
"onnx"
;
else
if
(
ends_with
(
file
,
".pb"
))
file_type
=
"tf"
;
else
if
(
ends_with
(
file
,
".json"
))
file_type
=
"json"
;
else
file_type
=
"migraphx"
;
}
}
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
std
::
cout
<<
"Reading: "
<<
file
<<
std
::
endl
;
if
(
file_type
==
"onnx"
)
if
(
file_type
==
"onnx"
)
...
@@ -272,6 +281,10 @@ struct loader
...
@@ -272,6 +281,10 @@ struct loader
options
.
format
=
"json"
;
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
p
=
migraphx
::
load
(
file
,
options
);
}
}
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
else
if
(
file_type
==
"migraphx"
)
else
if
(
file_type
==
"migraphx"
)
{
{
p
=
migraphx
::
load
(
file
);
p
=
migraphx
::
load
(
file
);
...
@@ -757,7 +770,7 @@ struct main_command
...
@@ -757,7 +770,7 @@ struct main_command
{
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
);
}
}
else
else
{
{
...
...
src/dynamic_loader.cpp
View file @
36bb977b
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
...
@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#endif
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
dynamic_loader_impl
(
const
fs
::
path
&
p
,
std
::
shared_ptr
<
tmp_dir
>
t
=
nullptr
)
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
LAZY
),
:
handle
(
dlopen
(
p
.
string
().
c_str
(),
RTLD_
GLOBAL
|
RTLD_NOW
),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
manage_deleter
<
decltype
(
&
dlclose
),
&
dlclose
>
{}),
temp
(
std
::
move
(
t
))
temp
(
std
::
move
(
t
))
{
{
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
...
@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return
p
;
return
p
;
}
}
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
try
{
return
dynamic_loader
{
p
};
}
catch
(
const
std
::
exception
&
)
{
return
nullopt
;
}
}
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
dynamic_loader
::
dynamic_loader
(
const
fs
::
path
&
p
)
:
impl
(
std
::
make_shared
<
dynamic_loader_impl
>
(
p
))
{
{
}
}
...
...
src/include/migraphx/algorithm.hpp
View file @
36bb977b
...
@@ -90,6 +90,43 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
...
@@ -90,6 +90,43 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
}
inline
size_t
levenshtein_distance
(
const
std
::
string
&
s1
,
const
std
::
string
&
s2
)
{
const
size_t
l1
=
s1
.
length
();
const
size_t
l2
=
s2
.
length
();
if
(
l1
<
l2
)
levenshtein_distance
(
s2
,
s1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
d
[
j
]
=
j
;
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
size_t
prev_cost
=
d
[
0
];
d
[
0
]
=
i
;
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
{
if
(
s1
[
i
-
1
]
==
s2
[
j
-
1
])
{
d
[
j
]
=
prev_cost
;
}
else
{
size_t
cost_insert_or_delete
=
std
::
min
(
d
[
j
-
1
],
d
[
j
]);
size_t
cost_substitute
=
prev_cost
;
prev_cost
=
d
[
j
];
d
[
j
]
=
std
::
min
(
cost_substitute
,
cost_insert_or_delete
)
+
1
;
}
}
}
return
d
[
l2
];
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/builtin.hpp
View file @
36bb977b
...
@@ -90,7 +90,17 @@ struct param
...
@@ -90,7 +90,17 @@ struct param
struct
returns
struct
returns
{
{
std
::
string
name
()
const
{
return
"@return"
;
}
std
::
string
name
()
const
{
return
"@return"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
{};
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
arg
)
const
{
if
(
arg
.
empty
())
return
{};
else
if
(
arg
.
size
()
==
1
)
return
arg
[
0
];
else
return
arg
;
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
MIGRAPHX_THROW
(
"builtin"
);
MIGRAPHX_THROW
(
"builtin"
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
36bb977b
...
@@ -34,21 +34,37 @@
...
@@ -34,21 +34,37 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
struct
check_shapes
{
{
const
shape
*
begin
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
const
shape
*
end
;
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
std
::
string
name
;
bool
dynamic_allowed
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
...
@@ -56,7 +72,7 @@ struct check_shapes
...
@@ -56,7 +72,7 @@ struct check_shapes
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
...
@@ -81,8 +97,6 @@ struct check_shapes
...
@@ -81,8 +97,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
0
;
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
return
end
-
begin
;
}
}
...
@@ -131,8 +145,6 @@ struct check_shapes
...
@@ -131,8 +145,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
max_lens
().
size
()
!=
n
)
...
@@ -148,8 +160,6 @@ struct check_shapes
...
@@ -148,8 +160,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
max_lens
().
size
()
>
n
)
...
@@ -166,8 +176,6 @@ struct check_shapes
...
@@ -166,8 +176,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
max_lens
().
size
()
<
n
)
...
@@ -330,8 +338,6 @@ struct check_shapes
...
@@ -330,8 +338,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
}
...
@@ -341,8 +347,6 @@ struct check_shapes
...
@@ -341,8 +347,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
}
...
@@ -351,17 +355,13 @@ struct check_shapes
...
@@ -351,17 +355,13 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
false
;
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
{
if
(
i
>=
size
())
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
if
(
i
<
0
)
return
end
-
i
;
return
end
-
i
;
return
begin
+
i
;
return
begin
+
i
;
...
@@ -394,6 +394,11 @@ struct check_shapes
...
@@ -394,6 +394,11 @@ struct check_shapes
}
}
};
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/dynamic_loader.hpp
View file @
36bb977b
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/optional.hpp>
#include <functional>
#include <functional>
#include <memory>
#include <memory>
#include <vector>
#include <vector>
...
@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader
...
@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader
return
path
(
reinterpret_cast
<
void
*>
(
address
));
return
path
(
reinterpret_cast
<
void
*>
(
address
));
}
}
static
fs
::
path
path
(
void
*
address
);
static
fs
::
path
path
(
void
*
address
);
static
optional
<
dynamic_loader
>
try_load
(
const
fs
::
path
&
p
);
dynamic_loader
()
=
default
;
dynamic_loader
()
=
default
;
dynamic_loader
(
const
fs
::
path
&
p
);
dynamic_loader
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/module.hpp
View file @
36bb977b
...
@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module
...
@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module
void
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
;
void
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
;
std
::
vector
<
module_ref
>
get_sub_modules
(
bool
shallow
=
false
)
const
;
std
::
vector
<
module_ref
>
get_sub_modules
(
bool
shallow
=
false
)
const
;
/* sorts the module in topological order aka reverse-post order (RPO) DFS order
it takes last instruction or @return as the root and walks back the graph and moves inputs
of the each instruction such that it appears before the instruction itself.
*/
module
&
sort
();
module
&
sort
();
/* Any instruction "X" can have module arguments and those modules inside them can use any other
* instruction "Y" from predecessor modules of the instruction "X". Such instruction "Y" inside
* module args are not listed as input instructions to "X". But those instructions "Y" must be
* evaluted before the instruction "X" can. Therefore such "Y" instructions are considered
* implicit dependency to "X".
*/
ins_dep_map
calc_implicit_deps
()
const
;
ins_dep_map
calc_implicit_deps
()
const
;
MIGRAPHX_EXPORT
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
MIGRAPHX_EXPORT
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
...
...
src/include/migraphx/permutation.hpp
View file @
36bb977b
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
/// Normalize the shapes so the order of dimensions will be in the order it is
/// in memory as much as possible.
MIGRAPHX_EXPORT
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/instruction.cpp
View file @
36bb977b
...
@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
...
@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result
=
r
;
result
=
r
;
for
(
auto
&&
ins
:
output
)
for
(
auto
&&
ins
:
output
)
{
{
if
(
ins
->
name
()
==
"@return"
)
assert
(
ins
->
name
()
==
"@return"
or
ins
->
name
().
front
()
!=
'@'
);
continue
;
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
ins
->
recompute_shape
();
}
}
}
}
...
@@ -122,10 +119,6 @@ bool instruction::valid() const
...
@@ -122,10 +119,6 @@ bool instruction::valid() const
{
{
computed
=
result
;
computed
=
result
;
}
}
else
if
(
op
.
name
()
==
"@return"
)
{
computed
=
{};
}
else
else
{
{
try
try
...
@@ -145,6 +138,7 @@ bool instruction::valid() const
...
@@ -145,6 +138,7 @@ bool instruction::valid() const
}
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
const
literal
&
instruction
::
get_literal
()
const
const
literal
&
instruction
::
get_literal
()
const
{
{
assert
(
op
.
name
()
==
"@literal"
);
assert
(
op
.
name
()
==
"@literal"
);
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment