Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13d14c66
Commit
13d14c66
authored
Oct 24, 2023
by
Brian Pickrell
Browse files
Merge branch 'develop' into dyn_resize_gather
parents
f4e7d9d9
d1abf06f
Changes
420
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
266 additions
and
46 deletions
+266
-46
src/api/api.cpp
src/api/api.cpp
+8
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+1
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+0
-1
src/compile_src.cpp
src/compile_src.cpp
+1
-1
src/cpp_generator.cpp
src/cpp_generator.cpp
+2
-2
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+9
-1
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+7
-0
src/driver/main.cpp
src/driver/main.cpp
+35
-9
src/driver/verify.cpp
src/driver/verify.cpp
+74
-14
src/driver/verify.hpp
src/driver/verify.hpp
+10
-3
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+58
-0
src/include/migraphx/argument.hpp
src/include/migraphx/argument.hpp
+1
-1
src/include/migraphx/auto_register.hpp
src/include/migraphx/auto_register.hpp
+3
-4
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+30
-4
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+12
-2
src/include/migraphx/config.hpp
src/include/migraphx/config.hpp
+1
-0
src/include/migraphx/dynamic_loader.hpp
src/include/migraphx/dynamic_loader.hpp
+2
-0
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+11
-0
src/include/migraphx/float_equal.hpp
src/include/migraphx/float_equal.hpp
+0
-3
No files found.
src/api/api.cpp
View file @
13d14c66
...
...
@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <cstdarg>
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
disable_exception_catch
=
b
;
}
#endif
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
{
f
();
}
else
{
#endif
try
{
f
();
...
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return
migraphx_status_unknown_error
;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return
migraphx_status_success
;
}
...
...
src/api/include/migraphx/migraphx.h
View file @
13d14c66
...
...
@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
...
...
src/api/include/migraphx/migraphx.hpp
View file @
13d14c66
...
...
@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
std
::
string
compute_type_name
()
{
std
::
string
name
;
#ifdef
_MSC_VER
#if
def
ined(
_MSC_VER
) && !defined(__clang__)
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
#else
...
...
src/auto_contiguous.cpp
View file @
13d14c66
...
...
@@ -25,7 +25,6 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
...
...
src/compile_src.cpp
View file @
13d14c66
...
...
@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
create_directories
(
parent_path
);
write_buffer
(
full_path
.
string
(),
src
.
content
.
first
,
src
.
len
());
write_buffer
(
full_path
.
string
(),
src
.
content
.
data
(),
src
.
content
.
size
());
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
params
+=
" "
+
src
.
path
.
filename
().
string
();
...
...
src/cpp_generator.cpp
View file @
13d14c66
...
...
@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
assert
(
v
.
size
()
==
1
);
auto
x
=
v
.
front
();
if
(
std
::
isinf
(
x
))
if
(
std
::
isinf
(
static_cast
<
double
>
(
x
)
))
{
string_literal
=
"__builtin_huge_val()"
;
if
(
x
<
0
)
string_literal
=
"-__builtin_huge_val()"
;
}
else
if
(
std
::
isnan
(
x
))
else
if
(
std
::
isnan
(
static_cast
<
double
>
(
x
)
))
string_literal
=
"__builtin_nan()"
;
else
string_literal
=
ins
->
get_literal
().
to_string
();
...
...
src/driver/CMakeLists.txt
View file @
13d14c66
...
...
@@ -45,7 +45,15 @@ if(NOT WIN32)
endif
()
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
target_link_libraries
(
driver migraphx_py
)
target_compile_definitions
(
driver PRIVATE MIGRAPHX_ENABLE_PYTHON
)
endif
()
rocm_install_targets
(
TARGETS driver
...
...
src/driver/argument_parser.hpp
View file @
13d14c66
...
...
@@ -187,6 +187,13 @@ struct value_parser
}
};
// version for std::optional object
template
<
class
T
>
struct
value_parser
<
std
::
optional
<
T
>>
{
static
T
apply
(
const
std
::
string
&
x
)
{
return
value_parser
<
T
>::
apply
(
x
);
}
};
struct
argument_parser
{
struct
argument
...
...
src/driver/main.cpp
View file @
13d14c66
...
...
@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
...
...
@@ -281,10 +283,12 @@ struct loader
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
}
#ifdef MIGRAPHX_ENABLE_PYTHON
else
if
(
file_type
==
"py"
)
{
p
=
migraphx
::
load_py
(
file
);
}
#endif
else
if
(
file_type
==
"migraphx"
)
{
p
=
migraphx
::
load
(
file
);
...
...
@@ -475,13 +479,15 @@ struct compiler
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
std
::
cout
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
}
else
if
(
co
.
offload_copy
)
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
...
@@ -534,13 +540,17 @@ struct params : command<params>
struct
verify
:
command
<
verify
>
{
compiler
c
;
double
tolerance
=
80
;
std
::
optional
<
double
>
rms_tol
;
std
::
optional
<
double
>
atol
;
std
::
optional
<
double
>
rtol
;
bool
per_instruction
=
false
;
bool
reduce
=
false
;
void
parse
(
argument_parser
&
ap
)
{
c
.
parse
(
ap
);
ap
(
tolerance
,
{
"--tolerance"
},
ap
.
help
(
"Tolerance for errors"
));
ap
(
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error"
));
ap
(
atol
,
{
"--atol"
},
ap
.
help
(
"Tolerance for the elementwise absolute difference"
));
ap
(
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference"
));
ap
(
per_instruction
,
{
"-i"
,
"--per-instruction"
},
ap
.
help
(
"Verify each instruction"
),
...
...
@@ -559,21 +569,30 @@ struct verify : command<verify>
auto
quantize
=
precision
::
fp32
;
if
(
c
.
to_fp16
)
{
quantize
=
precision
::
fp16
;
}
if
(
c
.
to_int8
)
{
quantize
=
precision
::
int8
;
}
auto
tols
=
get_tolerances
(
p
,
quantize
,
rms_tol
,
atol
,
rtol
);
std
::
cout
<<
"rms_tol: "
<<
tols
.
rms_tol
<<
std
::
endl
;
std
::
cout
<<
"atol: "
<<
tols
.
atol
<<
std
::
endl
;
std
::
cout
<<
"rtol: "
<<
tols
.
rtol
<<
std
::
endl
;
if
(
per_instruction
)
{
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
erance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
s
);
}
else
if
(
reduce
)
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
else
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
};
...
...
@@ -802,6 +821,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/driver/verify.cpp
View file @
13d14c66
...
...
@@ -30,11 +30,48 @@
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify
::
tolerance
get_tolerances
(
const
program
&
p
,
precision
quantize
,
std
::
optional
<
double
>
rms_tol
,
std
::
optional
<
double
>
atol
,
std
::
optional
<
double
>
rtol
)
{
bool
has_fp16
=
any_of
(
p
.
get_modules
(),
[](
auto
&&
m
)
{
return
any_of
(
*
m
,
[](
auto
&&
ins
)
{
return
(
ins
.
get_shape
().
type
()
==
shape
::
half_type
);
});
});
migraphx
::
verify
::
tolerance
result
{};
if
(
has_fp16
or
quantize
==
precision
::
fp16
)
{
result
.
rms_tol
=
8e-2
;
result
.
atol
=
4e-2
;
result
.
rtol
=
4e-2
;
}
if
(
rms_tol
)
{
result
.
rms_tol
=
*
rms_tol
;
}
if
(
atol
)
{
result
.
atol
=
*
atol
;
}
if
(
rtol
)
{
result
.
rtol
=
*
rtol
;
}
return
result
;
}
std
::
vector
<
argument
>
run_ref
(
program
p
,
const
parameter_map
&
inputs
)
{
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
...
...
@@ -76,15 +113,25 @@ void verify_program(const std::string& name,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
auto
x
=
run_ref
(
p
,
inputs
);
auto
y
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
auto
ref_outs
=
run_ref
(
p
,
inputs
);
auto
target_outs
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
std
::
size_t
output_num
=
x
.
size
();
std
::
size_t
output_num
=
ref_outs
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
{
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
if
(
ref_outs
[
i
].
get_shape
().
type
()
!=
target_outs
[
i
].
get_shape
().
type
()
or
ref_outs
[
i
].
get_shape
().
lens
()
!=
target_outs
[
i
].
get_shape
().
lens
())
{
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"Shape mismatch {"
<<
ref_outs
[
i
].
get_shape
()
<<
"} != {"
<<
target_outs
[
i
].
get_shape
()
<<
"}"
<<
std
::
endl
;
}
else
{
verify_args
(
name
,
target_outs
[
i
],
verify
::
expected
{
ref_outs
[
i
]},
tols
);
}
}
}
...
...
@@ -92,7 +139,7 @@ void verify_instructions(const program& prog,
const
target
&
t
,
compile_options
options
,
precision
quantize
,
double
tolerance
)
verify
::
tolerance
tols
)
{
const
auto
*
mm_prog
=
prog
.
get_main_module
();
for
(
auto
&&
ins
:
(
*
mm_prog
))
...
...
@@ -123,8 +170,7 @@ void verify_instructions(const program& prog,
{
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tolerance
);
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tols
);
}
catch
(...)
{
...
...
@@ -140,14 +186,22 @@ void verify_reduced(program p,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
try
{
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cout
<<
"FAILED: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
}
void
verify_reduced_program
(
const
program
&
p
,
...
...
@@ -155,14 +209,20 @@ void verify_reduced_program(const program& p,
compile_options
options
,
precision
quantize
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
n
;
i
++
)
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
auto
last
=
std
::
prev
(
mm
->
end
(),
i
+
1
);
if
(
contains
({
"@literal"
,
"@param"
},
last
->
name
()))
{
std
::
cout
<<
"Skip: "
<<
i
<<
std
::
endl
;
continue
;
}
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
}
...
...
src/driver/verify.hpp
View file @
13d14c66
...
...
@@ -26,29 +26,36 @@
#include "precision.hpp"
#include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
verify
::
tolerance
get_tolerances
(
const
program
&
p
,
precision
quantize
,
std
::
optional
<
double
>
rms_tol
,
std
::
optional
<
double
>
atol
,
std
::
optional
<
double
>
rtol
);
void
verify_program
(
const
std
::
string
&
name
,
const
program
&
p
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
double
tolerance
=
100
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_instructions
(
const
program
&
prog
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_reduced_program
(
const
program
&
p
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
...
...
src/dynamic_loader.cpp
View file @
13d14c66
...
...
@@ -27,11 +27,20 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <dlfcn.h>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifndef _WIN32
void
check_load_error
(
bool
flush
=
false
)
{
char
*
error_msg
=
dlerror
();
...
...
@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
return
p
;
}
#else
struct
dynamic_loader_impl
{
dynamic_loader_impl
()
=
default
;
dynamic_loader_impl
(
const
fs
::
path
&
p
,
tmp_dir
t
=
{})
:
handle
{
LoadLibrary
(
p
.
string
().
c_str
())},
temp
{
std
::
move
(
t
)}
{
if
(
handle
==
nullptr
)
{
MIGRAPHX_THROW
(
"Error loading DLL: "
+
p
.
string
()
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
}
}
dynamic_loader_impl
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
&
operator
=
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
(
dynamic_loader_impl
&&
)
=
default
;
~
dynamic_loader_impl
()
{
if
(
handle
!=
nullptr
)
{
FreeLibrary
(
handle
);
}
}
static
std
::
shared_ptr
<
dynamic_loader_impl
>
from_buffer
(
const
char
*
image
,
std
::
size_t
size
)
{
auto
t
=
tmp_dir
{
"migx-dynload"
};
auto
f
=
t
.
path
/
"tmp.dll"
;
write_buffer
(
f
.
string
(),
image
,
size
);
return
std
::
make_shared
<
dynamic_loader_impl
>
(
f
,
std
::
move
(
t
));
}
HMODULE
handle
=
nullptr
;
tmp_dir
temp
;
};
#endif
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
try
...
...
@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std
::
shared_ptr
<
void
>
dynamic_loader
::
get_symbol
(
const
std
::
string
&
name
)
const
{
#ifndef _WIN32
// flush any previous error messages
check_load_error
(
true
);
void
*
symbol
=
dlsym
(
impl
->
handle
.
get
(),
name
.
c_str
());
if
(
symbol
==
nullptr
)
check_load_error
();
return
{
impl
,
symbol
};
#else
FARPROC
addr
=
GetProcAddress
(
impl
->
handle
,
name
.
c_str
());
if
(
addr
==
nullptr
)
MIGRAPHX_THROW
(
"Symbol not found: "
+
name
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
return
{
impl
,
reinterpret_cast
<
void
*>
(
addr
)};
#endif
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/argument.hpp
View file @
13d14c66
...
...
@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
{
argument
()
=
default
;
argument
(
const
shape
&
s
);
explicit
argument
(
const
shape
&
s
);
template
<
class
F
,
MIGRAPHX_REQUIRES
(
std
::
is_pointer
<
decltype
(
std
::
declval
<
F
>()())
>
{})
>
argument
(
shape
s
,
F
d
)
...
...
src/include/migraphx/auto_register.hpp
View file @
13d14c66
...
...
@@ -62,10 +62,9 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
// NOLINTNEXTLINE
#define MIGRAPHX_AUTO_REGISTER(...) \
void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \
migraphx::auto_register<__VA_ARGS__>{}) \
__attribute__((unused));
#define MIGRAPHX_AUTO_REGISTER(...) \
[[maybe_unused]] void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)( \
migraphx::auto_register<__VA_ARGS__> x = migraphx::auto_register<__VA_ARGS__>{});
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/check_shapes.hpp
View file @
13d14c66
...
...
@@ -70,13 +70,19 @@ struct check_shapes
check_dynamic
();
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
void
check_dynamic
()
const
{
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
...
...
@@ -147,7 +153,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
...
...
@@ -162,7 +168,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -178,7 +184,7 @@ struct check_shapes
{
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
}
...
...
@@ -228,6 +234,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
* Check all shapes are standard.
*/
...
...
@@ -238,6 +254,16 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
* Check all shapes are standard or scalar.
*/
...
...
src/include/migraphx/compile_src.hpp
View file @
13d14c66
...
...
@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
src_file
{
fs
::
path
path
;
std
::
pair
<
const
char
*
,
const
char
*>
content
;
std
::
size_t
len
()
const
{
return
content
.
second
-
content
.
first
;
}
std
::
string_view
content
;
src_file
()
=
default
;
src_file
(
fs
::
path
file_path
,
std
::
string_view
file_content
)
:
path
{
std
::
move
(
file_path
)},
content
{
file_content
}
{
}
explicit
src_file
(
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
pair
)
:
path
{
pair
.
first
},
content
{
pair
.
second
}
{
}
};
struct
MIGRAPHX_EXPORT
src_compiler
...
...
src/include/migraphx/config.hpp
View file @
13d14c66
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_CONFIG_HPP
#include <migraphx/export.h>
#include <ciso646>
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
...
...
src/include/migraphx/dynamic_loader.hpp
View file @
13d14c66
...
...
@@ -38,12 +38,14 @@ struct dynamic_loader_impl;
struct
MIGRAPHX_EXPORT
dynamic_loader
{
#ifndef _WIN32
template
<
class
T
>
static
fs
::
path
path
(
T
*
address
)
{
return
path
(
reinterpret_cast
<
void
*>
(
address
));
}
static
fs
::
path
path
(
void
*
address
);
#endif
static
optional
<
dynamic_loader
>
try_load
(
const
fs
::
path
&
p
);
...
...
src/include/migraphx/filesystem.hpp
View file @
13d14c66
...
...
@@ -29,6 +29,17 @@
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#else
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
...
...
src/include/migraphx/float_equal.hpp
View file @
13d14c66
...
...
@@ -27,9 +27,6 @@
#include <algorithm>
#include <cmath>
#include <numeric>
#ifdef _MSC_VER
#include <iso646.h>
#endif
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
...
...
Prev
1
2
3
4
5
6
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment