Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
124f7d55
Commit
124f7d55
authored
Oct 12, 2023
by
Umang Yadav
Browse files
Merge branch 'develop' into resnet50_partition
parents
350bbea2
34b68ee4
Changes
288
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
182 additions
and
42 deletions
+182
-42
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+4
-0
src/api/api.cpp
src/api/api.cpp
+8
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+1
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+0
-1
src/compile_src.cpp
src/compile_src.cpp
+1
-1
src/cpp_generator.cpp
src/cpp_generator.cpp
+2
-2
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+6
-1
src/driver/main.cpp
src/driver/main.cpp
+15
-5
src/driver/verify.cpp
src/driver/verify.cpp
+38
-14
src/driver/verify.hpp
src/driver/verify.hpp
+4
-3
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+58
-0
src/include/migraphx/auto_register.hpp
src/include/migraphx/auto_register.hpp
+3
-4
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+13
-3
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+12
-2
src/include/migraphx/config.hpp
src/include/migraphx/config.hpp
+1
-0
src/include/migraphx/dynamic_loader.hpp
src/include/migraphx/dynamic_loader.hpp
+2
-0
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+11
-0
src/include/migraphx/float_equal.hpp
src/include/migraphx/float_equal.hpp
+0
-3
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+2
-2
No files found.
src/api/CMakeLists.txt
View file @
124f7d55
...
@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
...
@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# bumped when binary compatibility is broken.
# bumped when binary compatibility is broken.
rocm_set_soversion
(
migraphx_c 3.0
)
rocm_set_soversion
(
migraphx_c 3.0
)
if
(
BUILD_TESTING
)
target_compile_definitions
(
migraphx_c PRIVATE MIGRAPHX_BUILD_TESTING
)
endif
()
rocm_clang_tidy_check
(
migraphx_c
)
rocm_clang_tidy_check
(
migraphx_c
)
target_link_libraries
(
migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx
)
target_link_libraries
(
migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx
)
...
...
src/api/api.cpp
View file @
124f7d55
...
@@ -38,26 +38,32 @@
...
@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <algorithm>
#include <cstdarg>
#include <cstdarg>
namespace
migraphx
{
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
#endif
template
<
class
F
>
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
if
(
disable_exception_catch
)
{
{
f
();
f
();
}
}
else
else
{
{
#endif
try
try
{
{
f
();
f
();
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
{
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
}
}
#ifdef MIGRAPHX_BUILD_TESTING
}
}
#endif
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
...
...
src/api/include/migraphx/migraphx.h
View file @
124f7d55
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
#include <migraphx/api/export.h>
...
...
src/api/include/migraphx/migraphx.hpp
View file @
124f7d55
...
@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
...
@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
std
::
string
compute_type_name
()
std
::
string
compute_type_name
()
{
{
std
::
string
name
;
std
::
string
name
;
#ifdef
_MSC_VER
#if
def
ined(
_MSC_VER
) && !defined(__clang__)
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
name
=
name
.
substr
(
7
);
#else
#else
...
...
src/auto_contiguous.cpp
View file @
124f7d55
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/compile_src.cpp
View file @
124f7d55
...
@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
...
@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
create_directories
(
parent_path
);
fs
::
create_directories
(
parent_path
);
write_buffer
(
full_path
.
string
(),
src
.
content
.
first
,
src
.
len
());
write_buffer
(
full_path
.
string
(),
src
.
content
.
data
(),
src
.
content
.
size
());
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
{
params
+=
" "
+
src
.
path
.
filename
().
string
();
params
+=
" "
+
src
.
path
.
filename
().
string
();
...
...
src/cpp_generator.cpp
View file @
124f7d55
...
@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
...
@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
assert
(
v
.
size
()
==
1
);
assert
(
v
.
size
()
==
1
);
auto
x
=
v
.
front
();
auto
x
=
v
.
front
();
if
(
std
::
isinf
(
x
))
if
(
std
::
isinf
(
static_cast
<
double
>
(
x
)
))
{
{
string_literal
=
"__builtin_huge_val()"
;
string_literal
=
"__builtin_huge_val()"
;
if
(
x
<
0
)
if
(
x
<
0
)
string_literal
=
"-__builtin_huge_val()"
;
string_literal
=
"-__builtin_huge_val()"
;
}
}
else
if
(
std
::
isnan
(
x
))
else
if
(
std
::
isnan
(
static_cast
<
double
>
(
x
)
))
string_literal
=
"__builtin_nan()"
;
string_literal
=
"__builtin_nan()"
;
else
else
string_literal
=
ins
->
get_literal
().
to_string
();
string_literal
=
ins
->
get_literal
().
to_string
();
...
...
src/driver/CMakeLists.txt
View file @
124f7d55
...
@@ -48,7 +48,12 @@ rocm_clang_tidy_check(driver)
...
@@ -48,7 +48,12 @@ rocm_clang_tidy_check(driver)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
target_link_libraries
(
driver migraphx_py
)
target_compile_definitions
(
driver PRIVATE MIGRAPHX_ENABLE_PYTHON
)
endif
()
rocm_install_targets
(
rocm_install_targets
(
TARGETS driver
TARGETS driver
...
...
src/driver/main.cpp
View file @
124f7d55
...
@@ -32,7 +32,9 @@
...
@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp>
#include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
...
@@ -281,10 +283,12 @@ struct loader
...
@@ -281,10 +283,12 @@ struct loader
options
.
format
=
"json"
;
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
p
=
migraphx
::
load
(
file
,
options
);
}
}
#ifdef MIGRAPHX_ENABLE_PYTHON
else
if
(
file_type
==
"py"
)
else
if
(
file_type
==
"py"
)
{
{
p
=
migraphx
::
load_py
(
file
);
p
=
migraphx
::
load_py
(
file
);
}
}
#endif
else
if
(
file_type
==
"migraphx"
)
else
if
(
file_type
==
"migraphx"
)
{
{
p
=
migraphx
::
load
(
file
);
p
=
migraphx
::
load
(
file
);
...
@@ -536,13 +540,19 @@ struct params : command<params>
...
@@ -536,13 +540,19 @@ struct params : command<params>
struct
verify
:
command
<
verify
>
struct
verify
:
command
<
verify
>
{
{
compiler
c
;
compiler
c
;
double
tolerance
=
80
;
migraphx
::
verify
::
tolerance
tols
;
bool
per_instruction
=
false
;
bool
per_instruction
=
false
;
bool
reduce
=
false
;
bool
reduce
=
false
;
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
c
.
parse
(
ap
);
c
.
parse
(
ap
);
ap
(
tolerance
,
{
"--tolerance"
},
ap
.
help
(
"Tolerance for errors"
));
ap
(
tols
.
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error (Default: 0.001)"
));
ap
(
tols
.
atol
,
{
"--atol"
},
ap
.
help
(
"Tolerance for the elementwise absolute difference (Default: 0.001)"
));
ap
(
tols
.
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference (Default: 0.001)"
));
ap
(
per_instruction
,
ap
(
per_instruction
,
{
"-i"
,
"--per-instruction"
},
{
"-i"
,
"--per-instruction"
},
ap
.
help
(
"Verify each instruction"
),
ap
.
help
(
"Verify each instruction"
),
...
@@ -567,15 +577,15 @@ struct verify : command<verify>
...
@@ -567,15 +577,15 @@ struct verify : command<verify>
if
(
per_instruction
)
if
(
per_instruction
)
{
{
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
erance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
s
);
}
}
else
if
(
reduce
)
else
if
(
reduce
)
{
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
else
else
{
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
}
}
};
};
...
...
src/driver/verify.cpp
View file @
124f7d55
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
...
@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
...
@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
auto
x
=
run_ref
(
p
,
inputs
);
auto
ref_outs
=
run_ref
(
p
,
inputs
);
auto
y
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
auto
target_outs
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
std
::
size_t
output_num
=
x
.
size
();
std
::
size_t
output_num
=
ref_outs
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
{
{
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
if
(
ref_outs
[
i
].
get_shape
().
type
()
!=
target_outs
[
i
].
get_shape
().
type
()
or
ref_outs
[
i
].
get_shape
().
lens
()
!=
target_outs
[
i
].
get_shape
().
lens
())
{
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"Shape mismatch {"
<<
ref_outs
[
i
].
get_shape
()
<<
"} != {"
<<
target_outs
[
i
].
get_shape
()
<<
"}"
<<
std
::
endl
;
}
else
{
verify_args
(
name
,
target_outs
[
i
],
verify
::
expected
{
ref_outs
[
i
]},
tols
);
}
}
}
}
}
...
@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
...
@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
const
target
&
t
,
const
target
&
t
,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
const
auto
*
mm_prog
=
prog
.
get_main_module
();
const
auto
*
mm_prog
=
prog
.
get_main_module
();
for
(
auto
&&
ins
:
(
*
mm_prog
))
for
(
auto
&&
ins
:
(
*
mm_prog
))
...
@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
...
@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
{
{
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tols
);
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tolerance
);
}
}
catch
(...)
catch
(...)
{
{
...
@@ -140,14 +150,22 @@ void verify_reduced(program p,
...
@@ -140,14 +150,22 @@ void verify_reduced(program p,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
try
{
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cout
<<
"FAILED: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
}
}
void
verify_reduced_program
(
const
program
&
p
,
void
verify_reduced_program
(
const
program
&
p
,
...
@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
...
@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
const
auto
*
mm
=
p
.
get_main_module
();
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
n
;
i
++
)
{
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
auto
last
=
std
::
prev
(
mm
->
end
(),
i
+
1
);
if
(
contains
({
"@literal"
,
"@param"
},
last
->
name
()))
{
std
::
cout
<<
"Skip: "
<<
i
<<
std
::
endl
;
continue
;
}
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
}
}
}
...
...
src/driver/verify.hpp
View file @
124f7d55
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "precision.hpp"
#include "precision.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
...
@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
...
@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
const
parameter_map
&
inputs
=
{},
double
tolerance
=
100
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_instructions
(
const
program
&
prog
,
void
verify_instructions
(
const
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_reduced_program
(
const
program
&
p
,
void
verify_reduced_program
(
const
program
&
p
,
const
target
&
t
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
const
parameter_map
&
inputs
=
{},
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
...
...
src/dynamic_loader.cpp
View file @
124f7d55
...
@@ -27,11 +27,20 @@
...
@@ -27,11 +27,20 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#include <utility>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <dlfcn.h>
#include <dlfcn.h>
#endif
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifndef _WIN32
void
check_load_error
(
bool
flush
=
false
)
void
check_load_error
(
bool
flush
=
false
)
{
{
char
*
error_msg
=
dlerror
();
char
*
error_msg
=
dlerror
();
...
@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
...
@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
return
p
;
return
p
;
}
}
#else
struct
dynamic_loader_impl
{
dynamic_loader_impl
()
=
default
;
dynamic_loader_impl
(
const
fs
::
path
&
p
,
tmp_dir
t
=
{})
:
handle
{
LoadLibrary
(
p
.
string
().
c_str
())},
temp
{
std
::
move
(
t
)}
{
if
(
handle
==
nullptr
)
{
MIGRAPHX_THROW
(
"Error loading DLL: "
+
p
.
string
()
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
}
}
dynamic_loader_impl
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
&
operator
=
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
(
dynamic_loader_impl
&&
)
=
default
;
~
dynamic_loader_impl
()
{
if
(
handle
!=
nullptr
)
{
FreeLibrary
(
handle
);
}
}
static
std
::
shared_ptr
<
dynamic_loader_impl
>
from_buffer
(
const
char
*
image
,
std
::
size_t
size
)
{
auto
t
=
tmp_dir
{
"migx-dynload"
};
auto
f
=
t
.
path
/
"tmp.dll"
;
write_buffer
(
f
.
string
(),
image
,
size
);
return
std
::
make_shared
<
dynamic_loader_impl
>
(
f
,
std
::
move
(
t
));
}
HMODULE
handle
=
nullptr
;
tmp_dir
temp
;
};
#endif
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
{
try
try
...
@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
...
@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std
::
shared_ptr
<
void
>
dynamic_loader
::
get_symbol
(
const
std
::
string
&
name
)
const
std
::
shared_ptr
<
void
>
dynamic_loader
::
get_symbol
(
const
std
::
string
&
name
)
const
{
{
#ifndef _WIN32
// flush any previous error messages
// flush any previous error messages
check_load_error
(
true
);
check_load_error
(
true
);
void
*
symbol
=
dlsym
(
impl
->
handle
.
get
(),
name
.
c_str
());
void
*
symbol
=
dlsym
(
impl
->
handle
.
get
(),
name
.
c_str
());
if
(
symbol
==
nullptr
)
if
(
symbol
==
nullptr
)
check_load_error
();
check_load_error
();
return
{
impl
,
symbol
};
return
{
impl
,
symbol
};
#else
FARPROC
addr
=
GetProcAddress
(
impl
->
handle
,
name
.
c_str
());
if
(
addr
==
nullptr
)
MIGRAPHX_THROW
(
"Symbol not found: "
+
name
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
return
{
impl
,
reinterpret_cast
<
void
*>
(
addr
)};
#endif
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/auto_register.hpp
View file @
124f7d55
...
@@ -62,10 +62,9 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
...
@@ -62,10 +62,9 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_AUTO_REGISTER(...) \
#define MIGRAPHX_AUTO_REGISTER(...) \
void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \
[[maybe_unused]] void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)( \
migraphx::auto_register<__VA_ARGS__>{}) \
migraphx::auto_register<__VA_ARGS__> x = migraphx::auto_register<__VA_ARGS__>{});
__attribute__((unused));
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/check_shapes.hpp
View file @
124f7d55
...
@@ -153,7 +153,7 @@ struct check_shapes
...
@@ -153,7 +153,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
}
return
*
this
;
return
*
this
;
...
@@ -168,7 +168,7 @@ struct check_shapes
...
@@ -168,7 +168,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -184,7 +184,7 @@ struct check_shapes
...
@@ -184,7 +184,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -254,6 +254,16 @@ struct check_shapes
...
@@ -254,6 +254,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard or scalar.
* Check all shapes are standard or scalar.
*/
*/
...
...
src/include/migraphx/compile_src.hpp
View file @
124f7d55
...
@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
src_file
struct
src_file
{
{
fs
::
path
path
;
fs
::
path
path
;
std
::
pair
<
const
char
*
,
const
char
*>
content
;
std
::
string_view
content
;
std
::
size_t
len
()
const
{
return
content
.
second
-
content
.
first
;
}
src_file
()
=
default
;
src_file
(
fs
::
path
file_path
,
std
::
string_view
file_content
)
:
path
{
std
::
move
(
file_path
)},
content
{
file_content
}
{
}
explicit
src_file
(
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
pair
)
:
path
{
pair
.
first
},
content
{
pair
.
second
}
{
}
};
};
struct
MIGRAPHX_EXPORT
src_compiler
struct
MIGRAPHX_EXPORT
src_compiler
...
...
src/include/migraphx/config.hpp
View file @
124f7d55
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_CONFIG_HPP
#define MIGRAPHX_GUARD_CONFIG_HPP
#include <migraphx/export.h>
#include <migraphx/export.h>
#include <ciso646>
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
...
...
src/include/migraphx/dynamic_loader.hpp
View file @
124f7d55
...
@@ -38,12 +38,14 @@ struct dynamic_loader_impl;
...
@@ -38,12 +38,14 @@ struct dynamic_loader_impl;
struct
MIGRAPHX_EXPORT
dynamic_loader
struct
MIGRAPHX_EXPORT
dynamic_loader
{
{
#ifndef _WIN32
template
<
class
T
>
template
<
class
T
>
static
fs
::
path
path
(
T
*
address
)
static
fs
::
path
path
(
T
*
address
)
{
{
return
path
(
reinterpret_cast
<
void
*>
(
address
));
return
path
(
reinterpret_cast
<
void
*>
(
address
));
}
}
static
fs
::
path
path
(
void
*
address
);
static
fs
::
path
path
(
void
*
address
);
#endif
static
optional
<
dynamic_loader
>
try_load
(
const
fs
::
path
&
p
);
static
optional
<
dynamic_loader
>
try_load
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/filesystem.hpp
View file @
124f7d55
...
@@ -29,6 +29,17 @@
...
@@ -29,6 +29,17 @@
#if defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#else
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM 1
...
...
src/include/migraphx/float_equal.hpp
View file @
124f7d55
...
@@ -27,9 +27,6 @@
...
@@ -27,9 +27,6 @@
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include <numeric>
#include <numeric>
#ifdef _MSC_VER
#include <iso646.h>
#endif
#include <migraphx/requires.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
...
src/include/migraphx/generate.hpp
View file @
124f7d55
...
@@ -48,7 +48,7 @@ constexpr T normalize(unsigned long z)
...
@@ -48,7 +48,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
1UL
<<
(
sizeof
(
T
)
*
5
);
const
auto
max
=
1UL
L
<<
(
sizeof
(
T
)
*
5
);
const
auto
half_max
=
max
/
2
;
const
auto
half_max
=
max
/
2
;
return
half_max
-
(
z
%
max
);
return
half_max
-
(
z
%
max
);
}
}
...
@@ -58,7 +58,7 @@ template <class T,
...
@@ -58,7 +58,7 @@ template <class T,
not
std
::
is_same
<
T
,
bool
>
{})
>
not
std
::
is_same
<
T
,
bool
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
1UL
<<
(
sizeof
(
T
)
*
5
);
const
auto
max
=
1UL
L
<<
(
sizeof
(
T
)
*
5
);
return
z
%
max
;
return
z
%
max
;
}
}
...
...
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment