Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e5f47154
Unverified
Commit
e5f47154
authored
Sep 09, 2022
by
Umang Yadav
Committed by
GitHub
Sep 09, 2022
Browse files
Merge branch 'develop' into workspace_size
parents
4a3afd0f
d78bcdfb
Changes
124
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
52 additions
and
50 deletions
+52
-50
CMakeLists.txt
CMakeLists.txt
+1
-1
cppcheck.rules
cppcheck.rules
+2
-2
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
+2
-2
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
...phx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
+1
-1
examples/vision/cpp_mnist/mnist_inference.cpp
examples/vision/cpp_mnist/mnist_inference.cpp
+6
-6
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+4
-4
src/apply_alpha_beta.cpp
src/apply_alpha_beta.cpp
+1
-1
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+1
-1
src/file_buffer.cpp
src/file_buffer.cpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
src/include/migraphx/iterator.hpp
src/include/migraphx/iterator.hpp
+2
-2
src/include/migraphx/marker.hpp
src/include/migraphx/marker.hpp
+2
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+3
-2
No files found.
CMakeLists.txt
View file @
e5f47154
...
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
3
)
rocm_setup_version
(
VERSION 2.
4
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
cppcheck.rules
View file @
e5f47154
...
...
@@ -107,7 +107,7 @@
<summary>
Use make_shared or make_unique instead of new
</summary>
</message>
</rule>
<!--
<rule>
<rule>
<tokenlist>
raw
</tokenlist>
<pattern>
<![CDATA[ \|\| ]]>
</pattern>
<message>
...
...
@@ -124,7 +124,7 @@
<severity>
style
</severity>
<summary>
Use 'not' instead of !
</summary>
</message>
</rule>
-->
</rule>
<!-- <rule>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
...
...
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
View file @
e5f47154
...
...
@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx
::
program
p
;
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
||
!
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
or
not
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
{
std
::
cout
<<
"Parsing ONNX File"
<<
std
::
endl
;
migraphx
::
onnx_options
options
;
...
...
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
View file @
e5f47154
...
...
@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{
throw
std
::
runtime_error
(
"sscal_custom_op must have 2 input arguments"
);
}
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
||
inputs
[
0
].
lengths
()[
0
]
!=
1
)
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
or
inputs
[
0
].
lengths
()[
0
]
!=
1
)
{
throw
std
::
runtime_error
(
"first input argument to sscal_custom_op must be a scalar"
);
}
...
...
examples/vision/cpp_mnist/mnist_inference.cpp
View file @
e5f47154
...
...
@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char
**
begin
=
argv
+
1
;
char
**
end
=
argv
+
argc
;
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
||
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--cpu"
))
!=
end
;
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
||
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--gpu"
))
!=
end
;
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
||
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--fp16"
))
!=
end
;
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
||
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--int8"
))
!=
end
;
const
bool
CALIB
=
std
::
find
(
begin
,
end
,
std
::
string
(
"--cal"
))
!=
end
;
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
||
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--print"
))
!=
end
;
migraphx
::
program
prog
;
...
...
@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const
int
HEIGHT
=
28
;
const
int
WIDTH
=
28
;
if
(
!
file
.
is_open
())
if
(
not
file
.
is_open
())
{
return
;
}
...
...
src/api/include/migraphx/migraphx.hpp
View file @
e5f47154
...
...
@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
pout
;
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
not
(
px
==
py
);
}
};
/**
...
...
@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return
pout
;
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
not
(
px
==
py
);
}
};
/// A target for compilation
...
...
@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std
::
vector
<
const
char
*>
names
()
const
{
std
::
vector
<
const
char
*>
result
(
this
->
size
());
if
(
!
result
.
empty
())
if
(
not
result
.
empty
())
{
call
(
&
migraphx_program_parameter_shapes_names
,
result
.
data
(),
this
->
get_handle_ptr
());
}
...
...
@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
,
this
->
share_handle
()};
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
not
(
px
==
py
);
}
};
// options for migraphx file format options
...
...
src/apply_alpha_beta.cpp
View file @
e5f47154
...
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto
a
=
args
[
0
];
auto
b
=
args
[
1
];
auto
input_type
=
a
->
get_shape
().
type
();
if
(
!
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
if
(
not
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
{
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
...
...
src/eliminate_concat.cpp
View file @
e5f47154
...
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
or
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
// Last input should be an allocation
...
...
src/eliminate_contiguous.cpp
View file @
e5f47154
...
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
if
(
!
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
{
return
false
;
}
...
...
src/file_buffer.cpp
View file @
e5f47154
...
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is
.
seekg
(
0
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
if
(
!
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
}
...
...
src/include/migraphx/allocation_model.hpp
View file @
e5f47154
...
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/check_shapes.hpp
View file @
e5f47154
...
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
}
...
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
...
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
...
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
}
...
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
const
check_shapes
&
standard
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
}
...
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
...
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
}
...
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
const
check_shapes
&
packed_or_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
}
...
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
const
check_shapes
&
tuple_type
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
}
...
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
const
check_shapes
&
not_transposed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
}
...
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
const
check_shapes
&
not_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
}
...
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
...
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
const
check_shapes
&
batch_not_transposed
()
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
}
...
...
src/include/migraphx/concat_opt.hpp
View file @
e5f47154
...
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/context.hpp
View file @
e5f47154
...
...
@@ -246,7 +246,7 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/iterator.hpp
View file @
e5f47154
...
...
@@ -31,9 +31,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
EndIterator
>
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
!
it
.
_M_dereferenceable
())
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
not
it
.
_M_dereferenceable
())
{
return
!
it
.
_M_dereferenceable
();
return
not
it
.
_M_dereferenceable
();
}
template
<
class
Iterator
,
class
EndIterator
>
...
...
src/include/migraphx/marker.hpp
View file @
e5f47154
...
...
@@ -181,7 +181,7 @@ struct marker
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/module.hpp
View file @
e5f47154
...
...
@@ -219,7 +219,7 @@ struct module
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
private:
void
assign
(
const
module
&
m
);
...
...
src/include/migraphx/op/broadcast.hpp
View file @
e5f47154
...
...
@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
}
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
...
...
src/include/migraphx/op/concat.hpp
View file @
e5f47154
...
...
@@ -86,7 +86,7 @@ struct concat
{
if
(
l
!=
axis
)
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
...
...
src/include/migraphx/op/dot.hpp
View file @
e5f47154
...
...
@@ -43,13 +43,14 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
...
...
Prev
1
2
3
4
5
…
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment