Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5bafb637
Unverified
Commit
5bafb637
authored
Oct 10, 2023
by
Chris Austen
Committed by
GitHub
Oct 10, 2023
Browse files
Merge branch 'develop' into windows_cxx_compilation
parents
761e977a
c58e7d89
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
332 additions
and
107 deletions
+332
-107
CMakeLists.txt
CMakeLists.txt
+6
-2
src/CMakeLists.txt
src/CMakeLists.txt
+2
-0
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+4
-0
src/api/api.cpp
src/api/api.cpp
+7
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+6
-1
src/driver/main.cpp
src/driver/main.cpp
+4
-0
src/include/migraphx/instruction_ref.hpp
src/include/migraphx/instruction_ref.hpp
+31
-1
src/include/migraphx/op/random_uniform.hpp
src/include/migraphx/op/random_uniform.hpp
+20
-5
src/onnx/parse_depthtospace.cpp
src/onnx/parse_depthtospace.cpp
+1
-2
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+1
-2
src/onnx/parse_spacetodepth.cpp
src/onnx/parse_spacetodepth.cpp
+1
-2
src/process.cpp
src/process.cpp
+167
-1
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+14
-17
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+15
-2
src/tf/parse_reshape.cpp
src/tf/parse_reshape.cpp
+1
-2
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+8
-17
tools/api/api.cpp
tools/api/api.cpp
+7
-0
tools/check_stamped.py
tools/check_stamped.py
+37
-53
No files found.
CMakeLists.txt
View file @
5bafb637
...
...
@@ -53,6 +53,12 @@ include(CTest)
find_package
(
ROCM REQUIRED
)
find_package
(
Threads REQUIRED
)
if
(
WIN32
)
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
OFF
)
else
()
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
endif
()
find_path
(
HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half
)
if
(
NOT HALF_INCLUDE_DIR
)
message
(
FATAL_ERROR
"Could not find half.hpp - Please check that the install path of half.hpp has been added to CMAKE_PREFIX_PATH"
)
...
...
@@ -261,8 +267,6 @@ rocm_enable_cppcheck(
MIGRAPHX_USE_CLANG_TIDY
)
enable_testing
()
include
(
ROCMCreatePackage
)
include
(
ROCMTest
)
...
...
src/CMakeLists.txt
View file @
5bafb637
...
...
@@ -282,7 +282,9 @@ add_subdirectory(driver)
add_subdirectory
(
onnx
)
add_subdirectory
(
tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
add_subdirectory
(
py
)
endif
()
add_subdirectory
(
targets/ref
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
if
(
MIGRAPHX_ENABLE_CPU
)
...
...
src/api/CMakeLists.txt
View file @
5bafb637
...
...
@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# bumped when binary compatibility is broken.
rocm_set_soversion
(
migraphx_c 3.0
)
if
(
BUILD_TESTING
)
target_compile_definitions
(
migraphx_c PRIVATE MIGRAPHX_BUILD_TESTING
)
endif
()
rocm_clang_tidy_check
(
migraphx_c
)
target_link_libraries
(
migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx
)
...
...
src/api/api.cpp
View file @
5bafb637
...
...
@@ -41,24 +41,29 @@
#include <array>
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
#endif
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
{
f
();
}
else
{
#endif
try
{
f
();
...
...
@@ -82,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return
migraphx_status_unknown_error
;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return
migraphx_status_success
;
}
...
...
src/driver/CMakeLists.txt
View file @
5bafb637
...
...
@@ -48,7 +48,12 @@ rocm_clang_tidy_check(driver)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
target_link_libraries
(
driver migraphx_py
)
target_compile_definitions
(
driver PRIVATE MIGRAPHX_ENABLE_PYTHON
)
endif
()
rocm_install_targets
(
TARGETS driver
...
...
src/driver/main.cpp
View file @
5bafb637
...
...
@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
...
...
@@ -281,10 +283,12 @@ struct loader
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
}
#ifdef MIGRAPHX_ENABLE_PYTHON
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
#endif
else
if
(
file_type
==
"migraphx"
)
{
p
=
migraphx
::
load
(
file
);
...
...
src/include/migraphx/instruction_ref.hpp
View file @
5bafb637
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -27,12 +27,42 @@
#include <list>
#include <functional>
#include <migraphx/config.hpp>
#include <migraphx/requires.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
instruction
;
#if defined(_WIN32) && !defined(NDEBUG)
struct
instruction_ref
:
std
::
list
<
instruction
>::
iterator
{
using
instruction_iter
=
std
::
list
<
instruction
>::
iterator
;
using
instruction_const_iter
=
std
::
list
<
instruction
>::
const_iterator
;
instruction_ref
()
=
default
;
instruction_ref
(
const
instruction_iter
&
other
)
:
instruction_iter
(
other
)
{}
template
<
class
T
,
class
U
,
MIGRAPHX_REQUIRES
(
std
::
is_same
<
T
,
instruction_ref
>{}
or
std
::
is_same
<
U
,
instruction_ref
>
{})
>
friend
bool
operator
==
(
const
T
&
x
,
const
U
&
y
)
{
return
x
.
_Unwrapped
().
_Ptr
==
y
.
_Unwrapped
().
_Ptr
;
}
template
<
class
T
,
class
U
,
MIGRAPHX_REQUIRES
(
std
::
is_same
<
T
,
instruction_ref
>{}
or
std
::
is_same
<
U
,
instruction_ref
>
{})
>
friend
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
}
};
#else
using
instruction_ref
=
std
::
list
<
instruction
>::
iterator
;
#endif
MIGRAPHX_EXPORT
migraphx
::
instruction
*
as_address
(
const
instruction_ref
&
ins
)
noexcept
;
...
...
src/include/migraphx/op/random_uniform.hpp
View file @
5bafb637
...
...
@@ -77,11 +77,26 @@ struct random_uniform
using
type
=
typename
decltype
(
output
)
::
value_type
;
if
constexpr
(
std
::
is_integral
<
type
>
{})
{
// default range for all integer types is
// (0, std::uniform_int_distribution<type>::max()).
// Todo: enable different ranges
std
::
uniform_int_distribution
<
type
>
dis
;
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
#ifdef _MSC_VER
// According to the C++ specification, the effect is undefined if the result type
// for the generator is not one of short, int, long, long long, unsigned short,
// unsigned int, unsigned long, or unsigned long long. See
// https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
if
constexpr
(
sizeof
(
type
)
==
1
)
{
std
::
uniform_int_distribution
<
int
>
dis
{
std
::
numeric_limits
<
type
>::
min
(),
std
::
numeric_limits
<
type
>::
max
()};
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
}
else
#endif
{
// default range for all integer types is
// (0, std::uniform_int_distribution<type>::max()).
// Todo: enable different ranges
std
::
uniform_int_distribution
<
type
>
dis
;
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
}
}
else
{
...
...
src/onnx/parse_depthtospace.cpp
View file @
5bafb637
...
...
@@ -87,8 +87,7 @@ struct parse_depthtospace : op_parser<parse_depthtospace>
auto
temp1
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
args
[
0
]);
auto
temp2
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
temp1
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens2
}}),
info
.
make_contiguous
(
temp2
));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens2
}}),
temp2
);
}
};
...
...
src/onnx/parse_reshape.cpp
View file @
5bafb637
...
...
@@ -53,8 +53,7 @@ struct parse_reshape : op_parser<parse_reshape>
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
}
auto
cont
=
info
.
add_instruction
(
make_op
(
"contiguous"
),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
cont
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
};
...
...
src/onnx/parse_spacetodepth.cpp
View file @
5bafb637
...
...
@@ -73,8 +73,7 @@ struct parse_spacetodepth : op_parser<parse_spacetodepth>
std
::
vector
<
int64_t
>
perm
=
{
0
,
3
,
5
,
1
,
2
,
4
};
auto
temp1
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
trans_lens
}}),
args
[
0
]);
auto
temp2
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
temp1
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
res_lens
}}),
info
.
make_contiguous
(
temp2
));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
res_lens
}}),
temp2
);
}
};
...
...
src/process.cpp
View file @
5bafb637
...
...
@@ -26,13 +26,23 @@
#include <migraphx/env.hpp>
#include <functional>
#include <iostream>
#include <optional>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <unistd.h>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_CMD_EXECUTE
)
#ifndef _WIN32
std
::
function
<
void
(
const
char
*
)
>
redirect_to
(
std
::
ostream
&
os
)
{
return
[
&
](
const
char
*
x
)
{
os
<<
x
;
};
...
...
@@ -74,6 +84,155 @@ int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
});
}
#else
constexpr
std
::
size_t
MIGRAPHX_PROCESS_BUFSIZE
=
4096
;
class
pipe
{
public:
explicit
pipe
(
bool
inherit_handle
=
true
)
{
SECURITY_ATTRIBUTES
attrs
;
attrs
.
nLength
=
sizeof
(
SECURITY_ATTRIBUTES
);
attrs
.
bInheritHandle
=
inherit_handle
?
TRUE
:
FALSE
;
attrs
.
lpSecurityDescriptor
=
nullptr
;
if
(
CreatePipe
(
&
m_read
,
&
m_write
,
&
attrs
,
0
)
==
FALSE
)
throw
GetLastError
();
if
(
SetHandleInformation
(
&
m_read
,
HANDLE_FLAG_INHERIT
,
0
)
==
FALSE
)
throw
GetLastError
();
}
pipe
(
const
pipe
&
)
=
delete
;
pipe
&
operator
=
(
const
pipe
&
)
=
delete
;
pipe
(
pipe
&&
)
=
default
;
~
pipe
()
{
CloseHandle
(
m_read
);
m_read
=
nullptr
;
CloseHandle
(
m_write
);
m_write
=
nullptr
;
}
std
::
optional
<
std
::
pair
<
bool
,
DWORD
>>
read
(
LPVOID
buffer
,
DWORD
length
)
const
{
DWORD
bytes_read
;
if
(
ReadFile
(
m_read
,
buffer
,
length
,
&
bytes_read
,
nullptr
)
==
FALSE
)
{
DWORD
error
{
GetLastError
()};
if
(
error
!=
ERROR_MORE_DATA
)
{
return
std
::
nullopt
;
}
return
{{
true
,
bytes_read
}};
}
return
{{
false
,
bytes_read
}};
}
HANDLE
get_read_handle
()
const
{
return
m_read
;
}
bool
write
(
LPCVOID
buffer
,
DWORD
length
)
const
{
DWORD
bytes_written
;
return
WriteFile
(
m_write
,
buffer
,
length
,
&
bytes_written
,
nullptr
)
==
TRUE
;
}
HANDLE
get_write_handle
()
const
{
return
m_write
;
}
private:
HANDLE
m_write
=
nullptr
,
m_read
=
nullptr
;
};
template
<
typename
F
>
int
exec
(
const
std
::
string
&
cmd
,
F
f
)
{
try
{
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
std
::
cout
<<
cmd
<<
std
::
endl
;
STARTUPINFO
info
;
PROCESS_INFORMATION
process_info
;
pipe
in
{},
out
{};
ZeroMemory
(
&
info
,
sizeof
(
STARTUPINFO
));
info
.
cb
=
sizeof
(
STARTUPINFO
);
info
.
hStdError
=
out
.
get_write_handle
();
info
.
hStdOutput
=
out
.
get_write_handle
();
info
.
hStdInput
=
in
.
get_read_handle
();
info
.
dwFlags
|=
STARTF_USESTDHANDLES
;
ZeroMemory
(
&
process_info
,
sizeof
(
process_info
));
if
(
CreateProcess
(
nullptr
,
const_cast
<
LPSTR
>
(
cmd
.
c_str
()),
nullptr
,
nullptr
,
TRUE
,
0
,
nullptr
,
nullptr
,
&
info
,
&
process_info
)
==
FALSE
)
{
return
GetLastError
();
}
f
(
in
,
out
);
WaitForSingleObject
(
process_info
.
hProcess
,
INFINITE
);
DWORD
status
{};
GetExitCodeProcess
(
process_info
.
hProcess
,
&
status
);
CloseHandle
(
process_info
.
hProcess
);
CloseHandle
(
process_info
.
hThread
);
return
static_cast
<
int
>
(
status
);
}
// cppcheck-suppress catchExceptionByValue
catch
(
DWORD
last_error
)
{
return
last_error
;
}
}
int
exec
(
const
std
::
string
&
cmd
)
{
TCHAR
buffer
[
MIGRAPHX_PROCESS_BUFSIZE
];
HANDLE
std_out
{
GetStdHandle
(
STD_OUTPUT_HANDLE
)};
return
(
std_out
==
nullptr
or
std_out
==
INVALID_HANDLE_VALUE
)
?
GetLastError
()
:
exec
(
cmd
,
[
&
](
const
pipe
&
,
const
pipe
&
out
)
{
for
(;;)
{
if
(
auto
result
=
out
.
read
(
buffer
,
MIGRAPHX_PROCESS_BUFSIZE
))
{
auto
[
more_data
,
bytes_read
]
=
*
result
;
if
(
not
more_data
or
bytes_read
==
0
)
break
;
DWORD
written
;
if
(
WriteFile
(
std_out
,
buffer
,
bytes_read
,
&
written
,
nullptr
)
==
FALSE
)
break
;
}
}
});
}
int
exec
(
const
std
::
string
&
cmd
,
std
::
function
<
void
(
process
::
writer
)
>
std_in
)
{
return
exec
(
cmd
,
[
&
](
const
pipe
&
in
,
const
pipe
&
)
{
std_in
([
&
](
const
char
*
buffer
,
std
::
size_t
n
)
{
in
.
write
(
buffer
,
n
);
});
});
}
#endif
struct
process_impl
{
std
::
string
command
{};
...
...
@@ -119,7 +278,14 @@ process& process::cwd(const fs::path& p)
return
*
this
;
}
void
process
::
exec
()
{
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
}
void
process
::
exec
()
{
#ifndef _WIN32
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
#else
impl
->
check_exec
(
impl
->
get_command
());
#endif
}
void
process
::
write
(
std
::
function
<
void
(
process
::
writer
)
>
pipe_in
)
{
...
...
src/py/CMakeLists.txt
View file @
5bafb637
...
...
@@ -22,27 +22,24 @@
# THE SOFTWARE.
#####################################################################################
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
include
(
PythonModules
)
include
(
PythonModules
)
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
endif
()
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
5bafb637
...
...
@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template
<
class
F
>
struct
execute_wrapper
{
F
f
;
argument
operator
()(
context
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
f
(
args
);
}
};
template
<
class
F
>
execute_wrapper
<
F
>
make_execute_wrapper
(
F
f
)
{
return
{
std
::
move
(
f
)};
}
template
<
class
Derived
,
class
Primitive
>
struct
dnnl_op
:
auto_register_op
<
Derived
>
{
...
...
@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG
auto
prim_attr
=
get_primitive_attr
(
md
);
#endif
execute
=
[
=
](
context
&
,
const
std
::
vector
<
argument
>&
args
)
{
execute
=
make_execute_wrapper
([
=
](
const
std
::
vector
<
argument
>&
args
)
{
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto
debug_args
=
args
;
...
...
@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m
[
arg_lookup
[
i
]]
=
to_dnnl_memory
(
md
.
at
(
arg_lookup
[
i
]),
args
[
i
]);
prim
.
execute
(
get_dnnl_context
().
stream
,
m
);
return
args
.
back
();
};
}
)
;
}
std
::
vector
<
shape
>
trim_post_op_inputs
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
...
...
src/tf/parse_reshape.cpp
View file @
5bafb637
...
...
@@ -45,8 +45,7 @@ struct parse_reshape : op_parser<parse_reshape>
auto
s
=
args
[
1
]
->
eval
();
std
::
vector
<
int64_t
>
dims
;
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
info
.
make_contiguous
(
args
[
0
]));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
};
...
...
test/onnx/onnx_test.cpp
View file @
5bafb637
...
...
@@ -1772,8 +1772,7 @@ TEST_CASE(depthtospace_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto
tmp3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tmp2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
tmp3
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
auto prog = optimize_onnx("depthtospace_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1787,8 +1786,7 @@ TEST_CASE(depthtospace_crd_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 5, 3}}}), tmp1);
auto
tmp3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tmp2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
tmp3
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
auto prog = optimize_onnx("depthtospace_crd_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1802,8 +1800,7 @@ TEST_CASE(depthtospace_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 2, 3}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto
tmp3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tmp2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
2
,
4
,
6
}}}),
tmp3
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp2);
auto prog = optimize_onnx("depthtospace_simple_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1817,8 +1814,7 @@ TEST_CASE(spacetodepth_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 5, 2, 5, 2}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto
tmp3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tmp2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
8
,
5
,
5
}}}),
tmp3
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp2);
auto prog = optimize_onnx("spacetodepth_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1832,8 +1828,7 @@ TEST_CASE(spacetodepth_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 3, 2}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto
tmp3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tmp2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
1
,
8
,
2
,
3
}}}),
tmp3
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp2);
auto prog = optimize_onnx("spacetodepth_simple_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -5491,12 +5486,9 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
auto
c0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
op
,
c0
);
auto
c1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
op
,
c1
);
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -5509,8 +5501,7 @@ TEST_CASE(reshape_non_standard_test)
auto x = mm->add_parameter("x", s);
auto tran_x =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
auto
cont_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tran_x
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
3
,
2
}}}),
cont_x
);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), tran_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog);
...
...
tools/api/api.cpp
View file @
5bafb637
...
...
@@ -41,24 +41,29 @@
#include <array>
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
#endif
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
{
f
();
}
else
{
#endif
try
{
f
();
...
...
@@ -82,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return
migraphx_status_unknown_error
;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return
migraphx_status_success
;
}
...
...
tools/check_stamped.py
View file @
5bafb637
...
...
@@ -2,7 +2,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
...
...
@@ -27,11 +27,11 @@ import sys
debug
=
False
# The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a li
s
cen
c
e in it otherwise flag it as unstamped
# LICENSE is included here as it SHOULD have a licen
s
e in it otherwise flag it as unstamped
supported_file_types
=
(
".cpp"
,
".hpp"
,
".h"
,
".ipynb"
,
".py"
,
".txt"
,
".sh"
,
".bsh"
,
"LICENSE"
,
".cmake"
)
#add general stuff we shouldn't stamp and any exceptions here
#
add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types
=
[
".onnx"
,
".pb"
,
".rst"
,
".jpg"
,
".jpeg"
,
".proto"
,
".md"
,
".clang"
,
".weight"
,
".ini"
,
".json"
,
".docker"
,
".git"
,
".rules"
,
".yml"
...
...
@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores
=
(
"digits.txt"
,
"Dockerfile"
,
"Jenkinsfile"
,
""
)
def
hasKeySequence
(
inputfile
,
key_message
):
result
=
False
def
hasKeySequence
(
inputfile
:
str
,
key_message
:
str
)
->
bool
:
if
key_message
in
inputfile
:
re
sult
=
True
return
result
re
turn
True
return
False
#Simple just open and write stuff to each file with the license stamp
def
openAndCheckFile
(
filename
):
result
=
False
#open save old contents and append things here
if
debug
is
True
:
print
(
"Open"
,
filename
,
end
=
''
)
# Simple just open and write stuff to each file with the license stamp
def
needStampCheck
(
filename
:
str
)
->
bool
:
# open save old contents and append things here
if
debug
:
print
(
"Open"
,
filename
,
end
=
' '
)
try
:
file
=
open
(
filename
,
'r'
)
except
OSError
as
e
:
if
debug
is
True
:
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
if
debug
:
print
(
str
(
e
)
+
"....Open Error: Skipping file "
)
file
.
close
()
return
return
False
else
:
with
file
as
contents
:
try
:
save
=
contents
.
read
()
hasAmdLic
=
hasKeySequence
(
save
,
"Advanced Micro Devices, Inc. All rights reserved"
)
#Check if we have a licence stamp already
if
hasAmdLic
is
True
:
if
debug
is
True
:
print
(
"....Already Stamped: Skipping file "
)
# Check if we have a license stamp already
if
hasKeySequence
(
save
,
"Advanced Micro Devices, Inc. All rights reserved"
):
if
debug
:
print
(
"....Already Stamped: Skipping file "
)
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
except
UnicodeDecodeError
as
eu
:
if
debug
is
True
:
print
(
str
(
eu
)
+
"...Skipping binary file "
)
if
debug
:
print
(
f
"
{
str
(
eu
)
}
...Skipping binary file "
)
contents
.
close
()
re
sult
=
Tru
e
re
turn
Fals
e
return
result
return
True
# Deterine if filename is desired in the fileTuple past in
def
check_filename
(
filename
,
fileTuple
):
supported
=
False
for
key
in
fileTuple
:
if
key
in
filename
:
supported
=
True
break
return
supported
# Check if any element in fileTuple is in filename
def
check_filename
(
filename
:
str
,
fileTuple
:
tuple
or
list
)
->
bool
:
if
any
([
x
in
filename
for
x
in
fileTuple
]):
return
True
return
False
def
main
():
def
main
()
->
None
:
unsupported_file_types
.
extend
(
specificIgnores
)
#Get a list of all the tracked files in our git repo
#
Get a list of all the tracked files in our git repo
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
if
debug
is
True
:
print
(
"Target file list:
\n
"
+
str
(
fileList
))
if
debug
:
print
(
"Target file list:
\n
"
+
str
(
fileList
))
unsupportedFiles
=
[]
unstampedFiles
=
[]
unknownFiles
=
[]
for
file
in
fileList
:
supported
=
check_filename
(
file
,
supported_file_types
)
if
supported
is
True
:
isStamped
=
openAndCheckFile
(
file
)
if
isStamped
is
False
:
if
check_filename
(
file
,
supported_file_types
):
if
needStampCheck
(
file
):
unstampedFiles
.
append
(
file
)
elif
check_filename
(
file
,
unsupported_file_types
):
unsupportedFiles
.
append
(
file
)
else
:
unsupported
=
check_filename
(
file
,
unsupported_file_types
)
if
unsupported
is
True
:
unsupportedFiles
.
append
(
file
)
else
:
unknownFiles
.
append
(
file
)
unknownFiles
.
append
(
file
)
#Do a bunch of checks based on our file lists
#
Do a bunch of checks based on our file lists
if
len
(
unstampedFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unstampedFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unstampedFiles
))
+
" files are currently without a license:"
)
print
(
str
(
unstampedFiles
))
sys
.
exit
(
1
)
if
len
(
unknownFiles
)
>
0
:
print
(
"Error: The following "
+
str
(
len
(
unknownFiles
))
+
print
(
"
\n
Error: The following "
+
str
(
len
(
unknownFiles
))
+
" files not handled:"
)
print
(
str
(
unknownFiles
))
sys
.
exit
(
2
)
sys
.
exit
(
0
)
if
__name__
==
"__main__"
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment