Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9dc0b779
Unverified
Commit
9dc0b779
authored
Sep 12, 2022
by
Paul Fultz II
Committed by
GitHub
Sep 12, 2022
Browse files
Merge branch 'develop' into jit-concat
parents
30cde6df
d78bcdfb
Changes
146
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
59 additions
and
53 deletions
+59
-53
CMakeLists.txt
CMakeLists.txt
+1
-1
cppcheck.rules
cppcheck.rules
+2
-2
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
+2
-2
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
...phx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
+1
-1
examples/vision/cpp_mnist/mnist_inference.cpp
examples/vision/cpp_mnist/mnist_inference.cpp
+6
-6
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+4
-4
src/apply_alpha_beta.cpp
src/apply_alpha_beta.cpp
+1
-1
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+1
-1
src/file_buffer.cpp
src/file_buffer.cpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
src/include/migraphx/iterator.hpp
src/include/migraphx/iterator.hpp
+2
-2
src/include/migraphx/marker.hpp
src/include/migraphx/marker.hpp
+2
-2
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+5
-5
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+6
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-1
No files found.
CMakeLists.txt
View file @
9dc0b779
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include
(
ROCMSetupVersion
)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
3
)
rocm_setup_version
(
VERSION 2.
4
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
cppcheck.rules
View file @
9dc0b779
...
@@ -107,7 +107,7 @@
...
@@ -107,7 +107,7 @@
<summary>
Use make_shared or make_unique instead of new
</summary>
<summary>
Use make_shared or make_unique instead of new
</summary>
</message>
</message>
</rule>
</rule>
<!--
<rule>
<rule>
<tokenlist>
raw
</tokenlist>
<tokenlist>
raw
</tokenlist>
<pattern>
<![CDATA[ \|\| ]]>
</pattern>
<pattern>
<![CDATA[ \|\| ]]>
</pattern>
<message>
<message>
...
@@ -124,7 +124,7 @@
...
@@ -124,7 +124,7 @@
<severity>
style
</severity>
<severity>
style
</severity>
<summary>
Use 'not' instead of !
</summary>
<summary>
Use 'not' instead of !
</summary>
</message>
</message>
</rule>
-->
</rule>
<!-- <rule>
<!-- <rule>
<tokenlist>raw</tokenlist>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
...
...
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
View file @
9dc0b779
...
@@ -53,8 +53,8 @@ int main(int argc, char** argv)
...
@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx
::
program
p
;
migraphx
::
program
p
;
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
||
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
or
!
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
not
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
{
{
std
::
cout
<<
"Parsing ONNX File"
<<
std
::
endl
;
std
::
cout
<<
"Parsing ONNX File"
<<
std
::
endl
;
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
...
...
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
View file @
9dc0b779
...
@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
...
@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{
{
throw
std
::
runtime_error
(
"sscal_custom_op must have 2 input arguments"
);
throw
std
::
runtime_error
(
"sscal_custom_op must have 2 input arguments"
);
}
}
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
||
inputs
[
0
].
lengths
()[
0
]
!=
1
)
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
or
inputs
[
0
].
lengths
()[
0
]
!=
1
)
{
{
throw
std
::
runtime_error
(
"first input argument to sscal_custom_op must be a scalar"
);
throw
std
::
runtime_error
(
"first input argument to sscal_custom_op must be a scalar"
);
}
}
...
...
examples/vision/cpp_mnist/mnist_inference.cpp
View file @
9dc0b779
...
@@ -51,16 +51,16 @@ int main(int argc, char** argv)
...
@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char
**
begin
=
argv
+
1
;
char
**
begin
=
argv
+
1
;
char
**
end
=
argv
+
argc
;
char
**
end
=
argv
+
argc
;
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
||
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--cpu"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--cpu"
))
!=
end
;
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
||
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--gpu"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--gpu"
))
!=
end
;
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
||
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--fp16"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--fp16"
))
!=
end
;
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
||
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--int8"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--int8"
))
!=
end
;
const
bool
CALIB
=
std
::
find
(
begin
,
end
,
std
::
string
(
"--cal"
))
!=
end
;
const
bool
CALIB
=
std
::
find
(
begin
,
end
,
std
::
string
(
"--cal"
))
!=
end
;
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
||
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--print"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--print"
))
!=
end
;
migraphx
::
program
prog
;
migraphx
::
program
prog
;
...
@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
...
@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const
int
HEIGHT
=
28
;
const
int
HEIGHT
=
28
;
const
int
WIDTH
=
28
;
const
int
WIDTH
=
28
;
if
(
!
file
.
is_open
())
if
(
not
file
.
is_open
())
{
{
return
;
return
;
}
}
...
...
src/CMakeLists.txt
View file @
9dc0b779
...
@@ -82,6 +82,7 @@ add_library(migraphx
...
@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp
simplify_qdq.cpp
sqlite.cpp
sqlite.cpp
rewrite_batchnorm.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
...
...
src/api/include/migraphx/migraphx.hpp
View file @
9dc0b779
...
@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
pout
;
return
pout
;
}
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
not
(
px
==
py
);
}
};
};
/**
/**
...
@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return
pout
;
return
pout
;
}
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
not
(
px
==
py
);
}
};
};
/// A target for compilation
/// A target for compilation
...
@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std
::
vector
<
const
char
*>
names
()
const
std
::
vector
<
const
char
*>
names
()
const
{
{
std
::
vector
<
const
char
*>
result
(
this
->
size
());
std
::
vector
<
const
char
*>
result
(
this
->
size
());
if
(
!
result
.
empty
())
if
(
not
result
.
empty
())
{
{
call
(
&
migraphx_program_parameter_shapes_names
,
result
.
data
(),
this
->
get_handle_ptr
());
call
(
&
migraphx_program_parameter_shapes_names
,
result
.
data
(),
this
->
get_handle_ptr
());
}
}
...
@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
,
this
->
share_handle
()};
return
module
{
p_modu
,
this
->
share_handle
()};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
not
(
px
==
py
);
}
};
};
// options for migraphx file format options
// options for migraphx file format options
...
...
src/apply_alpha_beta.cpp
View file @
9dc0b779
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto
a
=
args
[
0
];
auto
a
=
args
[
0
];
auto
b
=
args
[
1
];
auto
b
=
args
[
1
];
auto
input_type
=
a
->
get_shape
().
type
();
auto
input_type
=
a
->
get_shape
().
type
();
if
(
!
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
if
(
not
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
{
{
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
...
...
src/eliminate_concat.cpp
View file @
9dc0b779
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
or
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
{
// Last input should be an allocation
// Last input should be an allocation
...
...
src/eliminate_contiguous.cpp
View file @
9dc0b779
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
!
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
{
{
return
false
;
return
false
;
}
}
...
...
src/file_buffer.cpp
View file @
9dc0b779
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
size
,
0
);
if
(
!
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
...
...
src/include/migraphx/allocation_model.hpp
View file @
9dc0b779
...
@@ -205,7 +205,7 @@ struct allocation_model
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -267,7 +267,7 @@ struct allocation_model
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/check_shapes.hpp
View file @
9dc0b779
...
@@ -101,7 +101,7 @@ struct check_shapes
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -164,7 +164,7 @@ struct check_shapes
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_shape
()
const
const
check_shapes
&
same_shape
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -174,7 +174,7 @@ struct check_shapes
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_type
()
const
const
check_shapes
&
same_type
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -184,10 +184,10 @@ struct check_shapes
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_dims
()
const
const
check_shapes
&
same_dims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -197,7 +197,7 @@ struct check_shapes
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -207,7 +207,7 @@ struct check_shapes
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard
()
const
const
check_shapes
&
standard
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -217,7 +217,7 @@ struct check_shapes
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard_or_scalar
()
const
const
check_shapes
&
standard_or_scalar
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -227,7 +227,7 @@ struct check_shapes
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -237,7 +237,7 @@ struct check_shapes
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed_or_broadcasted
()
const
const
check_shapes
&
packed_or_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -247,7 +247,7 @@ struct check_shapes
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
tuple_type
()
const
const
check_shapes
&
tuple_type
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -257,7 +257,7 @@ struct check_shapes
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_transposed
()
const
const
check_shapes
&
not_transposed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -267,7 +267,7 @@ struct check_shapes
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_broadcasted
()
const
const
check_shapes
&
not_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -278,7 +278,7 @@ struct check_shapes
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -288,7 +288,8 @@ struct check_shapes
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
*/
const
check_shapes
&
batch_not_transposed
()
const
const
check_shapes
&
batch_not_transposed
()
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
return
*
this
;
}
}
...
...
src/include/migraphx/concat_opt.hpp
View file @
9dc0b779
...
@@ -183,7 +183,7 @@ struct concat_optimization
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -233,7 +233,7 @@ struct concat_optimization
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/context.hpp
View file @
9dc0b779
...
@@ -246,7 +246,7 @@ struct context
...
@@ -246,7 +246,7 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -306,7 +306,7 @@ struct context
...
@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/iterator.hpp
View file @
9dc0b779
...
@@ -31,9 +31,9 @@ namespace migraphx {
...
@@ -31,9 +31,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
EndIterator
>
template
<
class
Iterator
,
class
EndIterator
>
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
!
it
.
_M_dereferenceable
())
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
not
it
.
_M_dereferenceable
())
{
{
return
!
it
.
_M_dereferenceable
();
return
not
it
.
_M_dereferenceable
();
}
}
template
<
class
Iterator
,
class
EndIterator
>
template
<
class
Iterator
,
class
EndIterator
>
...
...
src/include/migraphx/marker.hpp
View file @
9dc0b779
...
@@ -181,7 +181,7 @@ struct marker
...
@@ -181,7 +181,7 @@ struct marker
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -233,7 +233,7 @@ struct marker
...
@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/match/gelu_erf.hpp
View file @
9dc0b779
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F
f
;
F
f
;
auto
erf_fn
()
const
auto
erf_fn
()
const
{
{
return
f
(
"erf"
)(
auto
mul_1_sqrt_2
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
none_of
(
has_value
(
M_SQRT1_2
,
1e-3
)).
bind
(
"x"
),
used_once
(),
has_value
(
M_SQRT1_2
,
1e-3
)));
arg
(
0
)(
used_once
(),
auto
div_sqrt_2
=
f
(
"mul"
)(
either_arg
(
0
,
1
)
(
none_of
(
has_value
(
M_SQRT
1_
2
,
1e-3
)).
bind
(
"x"
),
f
(
"div"
)(
args
(
none_of
(
has_value
(
M_SQRT2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT2
,
1e-3
)));
has_value
(
M_SQRT1_2
,
1e-3
))
)));
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
any_of
(
mul_1_sqrt_2
,
div_sqrt_2
)));
}
}
auto
add_erf
()
const
auto
add_erf
()
const
...
...
src/include/migraphx/matcher.hpp
View file @
9dc0b779
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return
nullopt
;
return
nullopt
;
}
}
MIGRAPHX_PRED_MATCHER
(
broadcast
,
instruction_ref
ins
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
ins
->
name
());
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
auto
skip
(
Ms
...
ms
)
{
{
...
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
...
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
pointwise
(
Ms
...
ms
)
auto
pointwise
(
Ms
...
ms
)
{
{
return
match
::
has_attribute
(
"pointwise"
)(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
)),
return
match
::
has_attribute
(
"pointwise"
)(
ms
...);
ms
...);
}
}
}
// namespace match
}
// namespace match
...
...
src/include/migraphx/module.hpp
View file @
9dc0b779
...
@@ -219,7 +219,7 @@ struct module
...
@@ -219,7 +219,7 @@ struct module
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
private:
private:
void
assign
(
const
module
&
m
);
void
assign
(
const
module
&
m
);
...
...
Prev
1
2
3
4
5
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment