Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d32c6b8
Commit
8d32c6b8
authored
Oct 17, 2023
by
Paul
Browse files
Merge branch 'develop' into blas_tuning
parents
23cb7917
f25606f9
Changes
386
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
522 additions
and
45 deletions
+522
-45
src/CMakeLists.txt
src/CMakeLists.txt
+10
-1
src/api/CMakeLists.txt
src/api/CMakeLists.txt
+4
-0
src/api/api.cpp
src/api/api.cpp
+8
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+1
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+0
-1
src/common_dims.cpp
src/common_dims.cpp
+156
-0
src/compile_src.cpp
src/compile_src.cpp
+1
-1
src/cpp_generator.cpp
src/cpp_generator.cpp
+2
-2
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+9
-1
src/driver/main.cpp
src/driver/main.cpp
+62
-9
src/driver/verify.cpp
src/driver/verify.cpp
+38
-14
src/driver/verify.hpp
src/driver/verify.hpp
+4
-3
src/dynamic_loader.cpp
src/dynamic_loader.cpp
+58
-0
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+20
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+53
-0
src/include/migraphx/auto_register.hpp
src/include/migraphx/auto_register.hpp
+3
-4
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+31
-5
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+12
-2
No files found.
src/CMakeLists.txt
View file @
8d32c6b8
#####################################################################################
#####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -29,6 +29,7 @@ include(ROCMPackageConfigHelpers)
...
@@ -29,6 +29,7 @@ include(ROCMPackageConfigHelpers)
include
(
RegisterOp
)
include
(
RegisterOp
)
include
(
CheckCXXLinkerFlag
)
include
(
CheckCXXLinkerFlag
)
add_library
(
migraphx
add_library
(
migraphx
adjust_allocation.cpp
adjust_allocation.cpp
analyze_streams.cpp
analyze_streams.cpp
...
@@ -36,6 +37,7 @@ add_library(migraphx
...
@@ -36,6 +37,7 @@ add_library(migraphx
argument.cpp
argument.cpp
auto_contiguous.cpp
auto_contiguous.cpp
common.cpp
common.cpp
common_dims.cpp
compile_src.cpp
compile_src.cpp
convert_to_json.cpp
convert_to_json.cpp
cpp_generator.cpp
cpp_generator.cpp
...
@@ -94,6 +96,7 @@ add_library(migraphx
...
@@ -94,6 +96,7 @@ add_library(migraphx
serialize.cpp
serialize.cpp
shape.cpp
shape.cpp
simplify_algebra.cpp
simplify_algebra.cpp
simplify_dyn_ops.cpp
simplify_reshapes.cpp
simplify_reshapes.cpp
split_single_dyn_dim.cpp
split_single_dyn_dim.cpp
target.cpp
target.cpp
...
@@ -140,6 +143,7 @@ register_migraphx_ops(
...
@@ -140,6 +143,7 @@ register_migraphx_ops(
equal
equal
erf
erf
exp
exp
fill
flatten
flatten
floor
floor
fmod
fmod
...
@@ -183,6 +187,8 @@ register_migraphx_ops(
...
@@ -183,6 +187,8 @@ register_migraphx_ops(
quant_convolution
quant_convolution
quant_dot
quant_dot
quantizelinear
quantizelinear
random_uniform
random_seed
recip
recip
reduce_max
reduce_max
reduce_mean
reduce_mean
...
@@ -191,6 +197,7 @@ register_migraphx_ops(
...
@@ -191,6 +197,7 @@ register_migraphx_ops(
reduce_sum
reduce_sum
relu
relu
reshape
reshape
reshape_lazy
reverse
reverse
rnn
rnn
rnn_last_cell_output
rnn_last_cell_output
...
@@ -275,7 +282,9 @@ add_subdirectory(driver)
...
@@ -275,7 +282,9 @@ add_subdirectory(driver)
add_subdirectory
(
onnx
)
add_subdirectory
(
onnx
)
add_subdirectory
(
tf
)
add_subdirectory
(
tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
add_subdirectory
(
py
)
add_subdirectory
(
py
)
endif
()
add_subdirectory
(
targets/ref
)
add_subdirectory
(
targets/ref
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
target_link_libraries
(
migraphx_all_targets INTERFACE migraphx_ref
)
if
(
MIGRAPHX_ENABLE_CPU
)
if
(
MIGRAPHX_ENABLE_CPU
)
...
...
src/api/CMakeLists.txt
View file @
8d32c6b8
...
@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
...
@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# bumped when binary compatibility is broken.
# bumped when binary compatibility is broken.
rocm_set_soversion
(
migraphx_c 3.0
)
rocm_set_soversion
(
migraphx_c 3.0
)
if
(
BUILD_TESTING
)
target_compile_definitions
(
migraphx_c PRIVATE MIGRAPHX_BUILD_TESTING
)
endif
()
rocm_clang_tidy_check
(
migraphx_c
)
rocm_clang_tidy_check
(
migraphx_c
)
target_link_libraries
(
migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx
)
target_link_libraries
(
migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx
)
...
...
src/api/api.cpp
View file @
8d32c6b8
...
@@ -38,26 +38,32 @@
...
@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <algorithm>
#include <cstdarg>
#include <cstdarg>
namespace
migraphx
{
namespace
migraphx
{
#ifdef MIGRAPHX_BUILD_TESTING
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
static
thread_local
bool
disable_exception_catch
=
false
;
// NOLINT
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
extern
"C"
MIGRAPHX_C_EXPORT
void
migraphx_test_private_disable_exception_catch
(
bool
b
)
{
{
disable_exception_catch
=
b
;
disable_exception_catch
=
b
;
}
}
#endif
template
<
class
F
>
template
<
class
F
>
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
migraphx_status
try_
(
F
f
,
bool
output
=
true
)
// NOLINT
{
{
#ifdef MIGRAPHX_BUILD_TESTING
if
(
disable_exception_catch
)
if
(
disable_exception_catch
)
{
{
f
();
f
();
}
}
else
else
{
{
#endif
try
try
{
{
f
();
f
();
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
...
@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
{
return
migraphx_status_unknown_error
;
return
migraphx_status_unknown_error
;
}
}
#ifdef MIGRAPHX_BUILD_TESTING
}
}
#endif
return
migraphx_status_success
;
return
migraphx_status_success
;
}
}
...
...
src/api/include/migraphx/migraphx.h
View file @
8d32c6b8
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
#include <migraphx/api/export.h>
...
...
src/api/include/migraphx/migraphx.hpp
View file @
8d32c6b8
...
@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
...
@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
std
::
string
compute_type_name
()
std
::
string
compute_type_name
()
{
{
std
::
string
name
;
std
::
string
name
;
#ifdef
_MSC_VER
#if
def
ined(
_MSC_VER
) && !defined(__clang__)
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
name
=
name
.
substr
(
7
);
#else
#else
...
...
src/auto_contiguous.cpp
View file @
8d32c6b8
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/common_dims.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
dim
;
});
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
{
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
static
bool
compute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
if
(
naxes
==
0
)
return
false
;
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
false
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
==
1
?
naxes
:
naxes
+
1
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
true
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
>
0
);
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
state1
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/compile_src.cpp
View file @
8d32c6b8
...
@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
...
@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
create_directories
(
parent_path
);
fs
::
create_directories
(
parent_path
);
write_buffer
(
full_path
.
string
(),
src
.
content
.
first
,
src
.
len
());
write_buffer
(
full_path
.
string
(),
src
.
content
.
data
(),
src
.
content
.
size
());
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
{
params
+=
" "
+
src
.
path
.
filename
().
string
();
params
+=
" "
+
src
.
path
.
filename
().
string
();
...
...
src/cpp_generator.cpp
View file @
8d32c6b8
...
@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
...
@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
assert
(
v
.
size
()
==
1
);
assert
(
v
.
size
()
==
1
);
auto
x
=
v
.
front
();
auto
x
=
v
.
front
();
if
(
std
::
isinf
(
x
))
if
(
std
::
isinf
(
static_cast
<
double
>
(
x
)
))
{
{
string_literal
=
"__builtin_huge_val()"
;
string_literal
=
"__builtin_huge_val()"
;
if
(
x
<
0
)
if
(
x
<
0
)
string_literal
=
"-__builtin_huge_val()"
;
string_literal
=
"-__builtin_huge_val()"
;
}
}
else
if
(
std
::
isnan
(
x
))
else
if
(
std
::
isnan
(
static_cast
<
double
>
(
x
)
))
string_literal
=
"__builtin_nan()"
;
string_literal
=
"__builtin_nan()"
;
else
else
string_literal
=
ins
->
get_literal
().
to_string
();
string_literal
=
ins
->
get_literal
().
to_string
();
...
...
src/driver/CMakeLists.txt
View file @
8d32c6b8
...
@@ -45,7 +45,15 @@ if(NOT WIN32)
...
@@ -45,7 +45,15 @@ if(NOT WIN32)
endif
()
endif
()
rocm_clang_tidy_check
(
driver
)
rocm_clang_tidy_check
(
driver
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py
)
file
(
STRINGS
"
${
CMAKE_SOURCE_DIR
}
/test/onnx/.onnxrt-commit"
String_output
)
target_compile_definitions
(
driver PUBLIC MIGRAPHX_ORT_SHA1=
"
${
String_output
}
"
)
target_link_libraries
(
driver migraphx_all_targets migraphx_onnx migraphx_tf
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
target_link_libraries
(
driver migraphx_py
)
target_compile_definitions
(
driver PRIVATE MIGRAPHX_ENABLE_PYTHON
)
endif
()
rocm_install_targets
(
rocm_install_targets
(
TARGETS driver
TARGETS driver
...
...
src/driver/main.cpp
View file @
8d32c6b8
...
@@ -32,7 +32,9 @@
...
@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp>
#include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
...
@@ -281,10 +283,12 @@ struct loader
...
@@ -281,10 +283,12 @@ struct loader
options
.
format
=
"json"
;
options
.
format
=
"json"
;
p
=
migraphx
::
load
(
file
,
options
);
p
=
migraphx
::
load
(
file
,
options
);
}
}
#ifdef MIGRAPHX_ENABLE_PYTHON
else
if
(
file_type
==
"py"
)
else
if
(
file_type
==
"py"
)
{
{
p
=
migraphx
::
load_py
(
file
);
p
=
migraphx
::
load_py
(
file
);
}
}
#endif
else
if
(
file_type
==
"migraphx"
)
else
if
(
file_type
==
"migraphx"
)
{
{
p
=
migraphx
::
load
(
file
);
p
=
migraphx
::
load
(
file
);
...
@@ -475,13 +479,15 @@ struct compiler
...
@@ -475,13 +479,15 @@ struct compiler
{
{
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
if
(
is_offload_copy_set
(
p
)
and
not
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled with offload_copy set, Try "
std
::
cout
<<
"[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"passing "
"`--enable-offload-copy` if program run fails.
\n
"
;
"`--enable-offload-copy` if program run fails.
\n
"
;
}
}
else
if
(
co
.
offload_copy
)
else
if
(
co
.
offload_copy
)
{
{
std
::
cout
<<
"MIGraphX program was likely compiled without "
std
::
cout
<<
"
[WARNING]:
MIGraphX program was likely compiled without "
"offload_copy set, Try "
"offload_copy set, Try "
"removing "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"`--enable-offload-copy` flag if passed to driver, if program run "
...
@@ -534,13 +540,22 @@ struct params : command<params>
...
@@ -534,13 +540,22 @@ struct params : command<params>
struct
verify
:
command
<
verify
>
struct
verify
:
command
<
verify
>
{
{
compiler
c
;
compiler
c
;
double
tolerance
=
80
;
// Set to -1. as nonsense initial value
double
rms_tol
=
-
1.0
;
double
atol
=
-
1.0
;
double
rtol
=
-
1.0
;
bool
per_instruction
=
false
;
bool
per_instruction
=
false
;
bool
reduce
=
false
;
bool
reduce
=
false
;
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
c
.
parse
(
ap
);
c
.
parse
(
ap
);
ap
(
tolerance
,
{
"--tolerance"
},
ap
.
help
(
"Tolerance for errors"
));
ap
(
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error (Default: 0.001)"
));
ap
(
atol
,
{
"--atol"
},
ap
.
help
(
"Tolerance for the elementwise absolute difference (Default: 0.001)"
));
ap
(
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference (Default: 0.001)"
));
ap
(
per_instruction
,
ap
(
per_instruction
,
{
"-i"
,
"--per-instruction"
},
{
"-i"
,
"--per-instruction"
},
ap
.
help
(
"Verify each instruction"
),
ap
.
help
(
"Verify each instruction"
),
...
@@ -557,23 +572,54 @@ struct verify : command<verify>
...
@@ -557,23 +572,54 @@ struct verify : command<verify>
auto
t
=
c
.
ct
.
get_target
();
auto
t
=
c
.
ct
.
get_target
();
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
// TODO remove this and make the driver able to figure out datatype most used in the model
// then set the tolerances appropriately. Need to check here because c.to_fp16 only set
// after argument_parser.parse() is run. This code is complicated because there's not a
// good way to change the default tolerances after reading `--fp16` but before reading
// `--rms-tol`, `--atol`, and `--rtol`.
migraphx
::
verify
::
tolerance
tols
{};
if
(
c
.
to_fp16
)
{
tols
=
migraphx
::
verify
::
tolerance
{
8e-2
,
4e-2
,
4e-2
};
}
if
(
not
float_equal
(
this
->
rms_tol
,
-
1.0
))
{
tols
.
rms_tol
=
this
->
rms_tol
;
}
if
(
not
float_equal
(
this
->
atol
,
-
1.0
))
{
tols
.
atol
=
this
->
atol
;
}
if
(
not
float_equal
(
this
->
rtol
,
-
1.0
))
{
tols
.
rtol
=
this
->
rtol
;
}
std
::
cout
<<
"rms_tol: "
<<
tols
.
rms_tol
<<
std
::
endl
;
std
::
cout
<<
"atol: "
<<
tols
.
atol
<<
std
::
endl
;
std
::
cout
<<
"rtol: "
<<
tols
.
rtol
<<
std
::
endl
;
auto
quantize
=
precision
::
fp32
;
auto
quantize
=
precision
::
fp32
;
if
(
c
.
to_fp16
)
if
(
c
.
to_fp16
)
{
quantize
=
precision
::
fp16
;
quantize
=
precision
::
fp16
;
}
if
(
c
.
to_int8
)
if
(
c
.
to_int8
)
{
quantize
=
precision
::
int8
;
quantize
=
precision
::
int8
;
}
if
(
per_instruction
)
if
(
per_instruction
)
{
{
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
erance
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tol
s
);
}
}
else
if
(
reduce
)
else
if
(
reduce
)
{
{
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_reduced_program
(
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
else
else
{
{
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
erance
);
verify_program
(
c
.
l
.
file
,
p
,
t
,
c
.
co
,
quantize
,
m
,
tol
s
);
}
}
}
}
};
};
...
@@ -802,6 +848,13 @@ int main(int argc, const char* argv[])
...
@@ -802,6 +848,13 @@ int main(int argc, const char* argv[])
auto
&&
m
=
get_commands
();
auto
&&
m
=
get_commands
();
auto
cmd
=
args
.
front
();
auto
cmd
=
args
.
front
();
if
(
cmd
==
"ort-sha"
)
{
std
::
cout
<<
MIGRAPHX_ORT_SHA1
<<
std
::
endl
;
return
0
;
}
if
(
m
.
count
(
cmd
)
>
0
)
if
(
m
.
count
(
cmd
)
>
0
)
{
{
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
...
...
src/driver/verify.cpp
View file @
8d32c6b8
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
...
@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
...
@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
auto
x
=
run_ref
(
p
,
inputs
);
auto
ref_outs
=
run_ref
(
p
,
inputs
);
auto
y
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
auto
target_outs
=
run_target
(
p
,
t
,
options
,
quantize
,
inputs
);
std
::
size_t
output_num
=
x
.
size
();
std
::
size_t
output_num
=
ref_outs
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
output_num
;
++
i
)
{
{
verify_args
(
name
,
x
[
i
],
y
[
i
],
tolerance
);
if
(
ref_outs
[
i
].
get_shape
().
type
()
!=
target_outs
[
i
].
get_shape
().
type
()
or
ref_outs
[
i
].
get_shape
().
lens
()
!=
target_outs
[
i
].
get_shape
().
lens
())
{
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"Shape mismatch {"
<<
ref_outs
[
i
].
get_shape
()
<<
"} != {"
<<
target_outs
[
i
].
get_shape
()
<<
"}"
<<
std
::
endl
;
}
else
{
verify_args
(
name
,
target_outs
[
i
],
verify
::
expected
{
ref_outs
[
i
]},
tols
);
}
}
}
}
}
...
@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
...
@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
const
target
&
t
,
const
target
&
t
,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
const
auto
*
mm_prog
=
prog
.
get_main_module
();
const
auto
*
mm_prog
=
prog
.
get_main_module
();
for
(
auto
&&
ins
:
(
*
mm_prog
))
for
(
auto
&&
ins
:
(
*
mm_prog
))
...
@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
...
@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
{
{
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
verify_program
(
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tols
);
ins
.
name
(),
p
,
t
,
options
,
quantize
,
create_param_map
(
p
,
false
),
tolerance
);
}
}
catch
(...)
catch
(...)
{
{
...
@@ -140,14 +150,22 @@ void verify_reduced(program p,
...
@@ -140,14 +150,22 @@ void verify_reduced(program p,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
try
{
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cout
<<
"FAILED: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
}
}
void
verify_reduced_program
(
const
program
&
p
,
void
verify_reduced_program
(
const
program
&
p
,
...
@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
...
@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
compile_options
options
,
compile_options
options
,
precision
quantize
,
precision
quantize
,
const
parameter_map
&
inputs
,
const
parameter_map
&
inputs
,
double
tolerance
)
verify
::
tolerance
tols
)
{
{
const
auto
*
mm
=
p
.
get_main_module
();
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
n
;
i
++
)
{
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
auto
last
=
std
::
prev
(
mm
->
end
(),
i
+
1
);
if
(
contains
({
"@literal"
,
"@param"
},
last
->
name
()))
{
std
::
cout
<<
"Skip: "
<<
i
<<
std
::
endl
;
continue
;
}
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tols
);
}
}
}
}
...
...
src/driver/verify.hpp
View file @
8d32c6b8
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "precision.hpp"
#include "precision.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
...
@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
...
@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
const
parameter_map
&
inputs
=
{},
double
tolerance
=
100
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_instructions
(
const
program
&
prog
,
void
verify_instructions
(
const
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
void
verify_reduced_program
(
const
program
&
p
,
void
verify_reduced_program
(
const
program
&
p
,
const
target
&
t
,
const
target
&
t
,
compile_options
options
=
compile_options
{},
compile_options
options
=
compile_options
{},
precision
quantize
=
precision
::
fp32
,
precision
quantize
=
precision
::
fp32
,
const
parameter_map
&
inputs
=
{},
const
parameter_map
&
inputs
=
{},
double
tolerance
=
80
);
verify
::
tolerance
tols
=
verify
::
tolerance
{}
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace driver
...
...
src/dynamic_loader.cpp
View file @
8d32c6b8
...
@@ -27,11 +27,20 @@
...
@@ -27,11 +27,20 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#include <utility>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <dlfcn.h>
#include <dlfcn.h>
#endif
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifndef _WIN32
void
check_load_error
(
bool
flush
=
false
)
void
check_load_error
(
bool
flush
=
false
)
{
{
char
*
error_msg
=
dlerror
();
char
*
error_msg
=
dlerror
();
...
@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
...
@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
return
p
;
return
p
;
}
}
#else
struct
dynamic_loader_impl
{
dynamic_loader_impl
()
=
default
;
dynamic_loader_impl
(
const
fs
::
path
&
p
,
tmp_dir
t
=
{})
:
handle
{
LoadLibrary
(
p
.
string
().
c_str
())},
temp
{
std
::
move
(
t
)}
{
if
(
handle
==
nullptr
)
{
MIGRAPHX_THROW
(
"Error loading DLL: "
+
p
.
string
()
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
}
}
dynamic_loader_impl
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
&
operator
=
(
const
dynamic_loader_impl
&
)
=
delete
;
dynamic_loader_impl
(
dynamic_loader_impl
&&
)
=
default
;
~
dynamic_loader_impl
()
{
if
(
handle
!=
nullptr
)
{
FreeLibrary
(
handle
);
}
}
static
std
::
shared_ptr
<
dynamic_loader_impl
>
from_buffer
(
const
char
*
image
,
std
::
size_t
size
)
{
auto
t
=
tmp_dir
{
"migx-dynload"
};
auto
f
=
t
.
path
/
"tmp.dll"
;
write_buffer
(
f
.
string
(),
image
,
size
);
return
std
::
make_shared
<
dynamic_loader_impl
>
(
f
,
std
::
move
(
t
));
}
HMODULE
handle
=
nullptr
;
tmp_dir
temp
;
};
#endif
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
optional
<
dynamic_loader
>
dynamic_loader
::
try_load
(
const
fs
::
path
&
p
)
{
{
try
try
...
@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
...
@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std
::
shared_ptr
<
void
>
dynamic_loader
::
get_symbol
(
const
std
::
string
&
name
)
const
std
::
shared_ptr
<
void
>
dynamic_loader
::
get_symbol
(
const
std
::
string
&
name
)
const
{
{
#ifndef _WIN32
// flush any previous error messages
// flush any previous error messages
check_load_error
(
true
);
check_load_error
(
true
);
void
*
symbol
=
dlsym
(
impl
->
handle
.
get
(),
name
.
c_str
());
void
*
symbol
=
dlsym
(
impl
->
handle
.
get
(),
name
.
c_str
());
if
(
symbol
==
nullptr
)
if
(
symbol
==
nullptr
)
check_load_error
();
check_load_error
();
return
{
impl
,
symbol
};
return
{
impl
,
symbol
};
#else
FARPROC
addr
=
GetProcAddress
(
impl
->
handle
,
name
.
c_str
());
if
(
addr
==
nullptr
)
MIGRAPHX_THROW
(
"Symbol not found: "
+
name
+
" ("
+
std
::
to_string
(
GetLastError
())
+
")"
);
return
{
impl
,
reinterpret_cast
<
void
*>
(
addr
)};
#endif
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/eliminate_contiguous.cpp
View file @
8d32c6b8
...
@@ -35,6 +35,8 @@
...
@@ -35,6 +35,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
)
static
bool
try_compute_shape
(
instruction_ref
ins
,
static
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
std
::
vector
<
module_ref
>&
mods
)
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
output
->
module_inputs
()
))
{
{
return
false
;
return
false
;
}
}
}
}
}
}
catch
(
const
std
::
exception
&
e
)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
return
false
;
}
catch
(...)
catch
(...)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Unknown exception"
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{
{
if
(
arg
->
name
()
!=
op_name
)
if
(
arg
->
name
()
!=
op_name
)
continue
;
continue
;
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"eliminate_contiguous: "
;
m
.
debug_print
(
ins
);
}
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
...
src/fuse_pointwise.cpp
View file @
8d32c6b8
...
@@ -24,11 +24,14 @@
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
}
return
changed
;
return
changed
;
}
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/include/migraphx/auto_register.hpp
View file @
8d32c6b8
...
@@ -63,9 +63,8 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
...
@@ -63,9 +63,8 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_AUTO_REGISTER(...) \
#define MIGRAPHX_AUTO_REGISTER(...) \
void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \
[[maybe_unused]] void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)( \
migraphx::auto_register<__VA_ARGS__>{}) \
migraphx::auto_register<__VA_ARGS__> x = migraphx::auto_register<__VA_ARGS__>{});
__attribute__((unused));
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/check_shapes.hpp
View file @
8d32c6b8
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -70,13 +70,19 @@ struct check_shapes
...
@@ -70,13 +70,19 @@ struct check_shapes
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
,
MIGRAPHX_REQUIRES
(
not
std
::
is_convertible
<
Op
,
std
::
string
>{})
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
void
check_dynamic
()
const
void
check_dynamic
()
const
{
{
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
...
@@ -147,7 +153,7 @@ struct check_shapes
...
@@ -147,7 +153,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
ndim
()
!=
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
MIGRAPHX_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
}
return
*
this
;
return
*
this
;
...
@@ -162,7 +168,7 @@ struct check_shapes
...
@@ -162,7 +168,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
ndim
()
>
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at most "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -178,7 +184,7 @@ struct check_shapes
...
@@ -178,7 +184,7 @@ struct check_shapes
{
{
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
ndim
()
<
n
)
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
MIGRAPHX_THROW
(
prefix
()
+
"Shape must have at least "
+
std
::
to_string
(
n
)
+
" dimensions"
);
" dimensions"
);
}
}
...
@@ -228,6 +234,16 @@ struct check_shapes
...
@@ -228,6 +234,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same layout.
*/
const
check_shapes
&
same_layout
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
find_permutation
(
s
);
}))
MIGRAPHX_THROW
(
prefix
()
+
"Layouts do not match"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard.
* Check all shapes are standard.
*/
*/
...
@@ -238,6 +254,16 @@ struct check_shapes
...
@@ -238,6 +254,16 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are scalar.
*/
const
check_shapes
&
scalar
()
const
{
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar"
);
return
*
this
;
}
/*!
/*!
* Check all shapes are standard or scalar.
* Check all shapes are standard or scalar.
*/
*/
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/compile_src.hpp
View file @
8d32c6b8
...
@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
src_file
struct
src_file
{
{
fs
::
path
path
;
fs
::
path
path
;
std
::
pair
<
const
char
*
,
const
char
*>
content
;
std
::
string_view
content
;
std
::
size_t
len
()
const
{
return
content
.
second
-
content
.
first
;
}
src_file
()
=
default
;
src_file
(
fs
::
path
file_path
,
std
::
string_view
file_content
)
:
path
{
std
::
move
(
file_path
)},
content
{
file_content
}
{
}
explicit
src_file
(
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
pair
)
:
path
{
pair
.
first
},
content
{
pair
.
second
}
{
}
};
};
struct
MIGRAPHX_EXPORT
src_compiler
struct
MIGRAPHX_EXPORT
src_compiler
...
...
Prev
1
2
3
4
5
6
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment