Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f7838bc8
Commit
f7838bc8
authored
Sep 12, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into ck-elementwise
parents
fea58a7b
d78bcdfb
Changes
146
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
78 additions
and
46 deletions
+78
-46
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+3
-2
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-9
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+1
-1
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+0
-9
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+1
-1
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+2
-2
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+3
-3
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+48
-0
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+2
-2
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+2
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+1
-1
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
f7838bc8
...
@@ -70,7 +70,7 @@ struct broadcast
...
@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
}
}
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
}
...
...
src/include/migraphx/op/concat.hpp
View file @
f7838bc8
...
@@ -86,7 +86,7 @@ struct concat
...
@@ -86,7 +86,7 @@ struct concat
{
{
if
(
l
!=
axis
)
if
(
l
!=
axis
)
{
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
}))
{
{
...
...
src/include/migraphx/op/dot.hpp
View file @
f7838bc8
...
@@ -43,13 +43,14 @@ struct dot
...
@@ -43,13 +43,14 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
}
}
// only handle the case that the batch size of a and b are the same
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
...
...
src/include/migraphx/op/fmod.hpp
View file @
f7838bc8
...
@@ -24,17 +24,8 @@
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/gather.hpp
View file @
f7838bc8
...
@@ -65,7 +65,7 @@ struct gather
...
@@ -65,7 +65,7 @@ struct gather
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
auto
type
=
inputs
[
0
].
type
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
!
inputs
[
1
].
scalar
())
if
(
not
inputs
[
1
].
scalar
())
{
{
auto
ind_lens
=
inputs
[
1
].
lens
();
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
...
...
src/include/migraphx/op/mod.hpp
View file @
f7838bc8
...
@@ -24,17 +24,8 @@
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
f7838bc8
...
@@ -266,7 +266,7 @@ struct nonmaxsuppression
...
@@ -266,7 +266,7 @@ struct nonmaxsuppression
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
// Check with existing selected boxes for this class, remove box if it
// Check with existing selected boxes for this class, remove box if it
...
...
src/include/migraphx/op/quant_dot.hpp
View file @
f7838bc8
...
@@ -49,13 +49,14 @@ struct quant_dot
...
@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
}
// only handle the case that the batch size of a and b are the same
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
...
...
src/include/migraphx/op/slice.hpp
View file @
f7838bc8
...
@@ -78,7 +78,7 @@ struct slice
...
@@ -78,7 +78,7 @@ struct slice
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
auto
offset
=
0
;
if
(
!
axes
.
empty
())
if
(
not
axes
.
empty
())
{
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
...
@@ -109,7 +109,7 @@ struct slice
...
@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
}
}
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
{
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
}
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
f7838bc8
...
@@ -59,7 +59,7 @@ struct transpose
...
@@ -59,7 +59,7 @@ struct transpose
}
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
}
...
...
src/include/migraphx/operation.hpp
View file @
f7838bc8
...
@@ -1066,7 +1066,7 @@ struct operation
...
@@ -1066,7 +1066,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -1237,7 +1237,7 @@ struct operation
...
@@ -1237,7 +1237,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
}
}
#endif
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/pass.hpp
View file @
f7838bc8
...
@@ -238,7 +238,7 @@ struct pass
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -292,7 +292,7 @@ struct pass
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/program.hpp
View file @
f7838bc8
...
@@ -124,7 +124,7 @@ struct program
...
@@ -124,7 +124,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
f7838bc8
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
template
<
class
T
>
bool
matches
()
const
bool
matches
()
const
{
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
}
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
{
auto
&&
s
=
x
.
get_shape
();
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
}
...
@@ -241,7 +241,7 @@ template <class T>
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
{
auto
&&
s
=
x
.
front
().
get_shape
();
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
return
[
&
](
auto
v
)
{
...
@@ -281,7 +281,7 @@ template <class T,
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
f7838bc8
...
@@ -129,7 +129,7 @@ template <class T>
...
@@ -129,7 +129,7 @@ template <class T>
struct
reflect_equality
struct
reflect_equality
{
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
f7838bc8
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
{
};
};
...
...
src/include/migraphx/rewrite_gelu.hpp
0 → 100644
View file @
f7838bc8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
/**
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct
rewrite_gelu
{
std
::
string
name
()
const
{
return
"rewrite_gelu"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/schedule_model.hpp
View file @
f7838bc8
...
@@ -208,7 +208,7 @@ struct schedule_model
...
@@ -208,7 +208,7 @@ struct schedule_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -274,7 +274,7 @@ struct schedule_model
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/stream_model.hpp
View file @
f7838bc8
...
@@ -216,7 +216,7 @@ struct stream_model
...
@@ -216,7 +216,7 @@ struct stream_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -288,7 +288,7 @@ struct stream_model
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/streamutils.hpp
View file @
f7838bc8
...
@@ -41,7 +41,7 @@ struct stream_range_container
...
@@ -41,7 +41,7 @@ struct stream_range_container
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
{
{
assert
(
sr
.
r
!=
nullptr
);
assert
(
sr
.
r
!=
nullptr
);
if
(
!
sr
.
r
->
empty
())
if
(
not
sr
.
r
->
empty
())
{
{
os
<<
sr
.
r
->
front
();
os
<<
sr
.
r
->
front
();
std
::
for_each
(
std
::
for_each
(
...
...
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment