Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3a4d36cf
Commit
3a4d36cf
authored
Sep 30, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
parents
6bec381f
e19f78ae
Changes
384
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
136 additions
and
101 deletions
+136
-101
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+2
-2
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+3
-3
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+14
-10
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+2
-2
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+2
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+17
-9
src/include/migraphx/supported_segments.hpp
src/include/migraphx/supported_segments.hpp
+11
-9
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+35
-32
src/include/migraphx/target_assignments.hpp
src/include/migraphx/target_assignments.hpp
+14
-3
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+10
-10
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+1
-1
src/include/migraphx/value.hpp
src/include/migraphx/value.hpp
+6
-0
src/instruction.cpp
src/instruction.cpp
+6
-6
No files found.
src/include/migraphx/op/quant_dot.hpp
View file @
3a4d36cf
...
@@ -49,13 +49,14 @@ struct quant_dot
...
@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
}
// only handle the case that the batch size of a and b are the same
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
...
...
src/include/migraphx/op/slice.hpp
View file @
3a4d36cf
...
@@ -78,7 +78,7 @@ struct slice
...
@@ -78,7 +78,7 @@ struct slice
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
auto
offset
=
0
;
if
(
!
axes
.
empty
())
if
(
not
axes
.
empty
())
{
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
...
@@ -109,7 +109,7 @@ struct slice
...
@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
}
}
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
{
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
}
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
3a4d36cf
...
@@ -59,7 +59,7 @@ struct transpose
...
@@ -59,7 +59,7 @@ struct transpose
}
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
}
...
...
src/include/migraphx/operation.hpp
View file @
3a4d36cf
...
@@ -1066,7 +1066,7 @@ struct operation
...
@@ -1066,7 +1066,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -1237,7 +1237,7 @@ struct operation
...
@@ -1237,7 +1237,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
}
}
#endif
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/pass.hpp
View file @
3a4d36cf
...
@@ -238,7 +238,7 @@ struct pass
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -292,7 +292,7 @@ struct pass
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/program.hpp
View file @
3a4d36cf
...
@@ -124,7 +124,7 @@ struct program
...
@@ -124,7 +124,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
3a4d36cf
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
template
<
class
T
>
bool
matches
()
const
bool
matches
()
const
{
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
}
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
{
auto
&&
s
=
x
.
get_shape
();
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
}
...
@@ -241,7 +241,7 @@ template <class T>
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
{
auto
&&
s
=
x
.
front
().
get_shape
();
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
return
[
&
](
auto
v
)
{
...
@@ -281,7 +281,7 @@ template <class T,
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
3a4d36cf
...
@@ -129,7 +129,7 @@ template <class T>
...
@@ -129,7 +129,7 @@ template <class T>
struct
reflect_equality
struct
reflect_equality
{
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
3a4d36cf
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
{
};
};
...
...
src/
targets/gpu/
include/migraphx/
gpu/device/convert
.hpp
→
src/include/migraphx/
rewrite_gelu
.hpp
View file @
3a4d36cf
...
@@ -21,23 +21,27 @@
...
@@ -21,23 +21,27 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#include <string>
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
convert
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
);
struct
module
;
/**
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct
rewrite_gelu
{
std
::
string
name
()
const
{
return
"rewrite_gelu"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/schedule_model.hpp
View file @
3a4d36cf
...
@@ -208,7 +208,7 @@ struct schedule_model
...
@@ -208,7 +208,7 @@ struct schedule_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -274,7 +274,7 @@ struct schedule_model
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/stream_model.hpp
View file @
3a4d36cf
...
@@ -216,7 +216,7 @@ struct stream_model
...
@@ -216,7 +216,7 @@ struct stream_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -288,7 +288,7 @@ struct stream_model
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/streamutils.hpp
View file @
3a4d36cf
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <algorithm>
#include <algorithm>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -41,7 +42,7 @@ struct stream_range_container
...
@@ -41,7 +42,7 @@ struct stream_range_container
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
{
{
assert
(
sr
.
r
!=
nullptr
);
assert
(
sr
.
r
!=
nullptr
);
if
(
!
sr
.
r
->
empty
())
if
(
not
sr
.
r
->
empty
())
{
{
os
<<
sr
.
r
->
front
();
os
<<
sr
.
r
->
front
();
std
::
for_each
(
std
::
for_each
(
...
@@ -59,28 +60,35 @@ inline stream_range_container<Range> stream_range(const Range& r)
...
@@ -59,28 +60,35 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace
detail
{
namespace
detail
{
inline
void
stream_write_value_impl
(
rank
<
2
>
,
std
::
ostream
&
os
,
const
std
::
string
&
x
)
{
os
<<
x
;
}
template
<
class
T
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
,
void
())
{
os
<<
x
;
}
template
<
class
Range
>
template
<
class
T
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
void
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
{
{
os
<<
"{"
;
os
<<
"{"
;
os
<<
stream_range
(
r
);
os
<<
stream_range
(
r
);
os
<<
"}"
;
os
<<
"}"
;
}
}
template
<
class
T
>
template
<
class
Range
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
T
&
x
)
auto
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
{
{
os
<<
x
;
os
<<
"{"
;
os
<<
stream_range
(
r
);
os
<<
"}"
;
}
}
}
// namespace detail
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
{
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/
targets/gpu/
include/migraphx/
gpu/greater
.hpp
→
src/include/migraphx/
supported_segments
.hpp
View file @
3a4d36cf
...
@@ -21,22 +21,24 @@
...
@@ -21,22 +21,24 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_GREATER
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHX_SUPPORTED_SEGMENTS
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_GREATER
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHX_SUPPORTED_SEGMENTS
_HPP
#include <migraphx/gpu/oper.hpp>
#include <unordered_set>
#include <migraphx/gpu/device/greater.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
hip_greater
:
binary_device
<
hip_greater
,
device
::
greater
>
struct
supported_segment
{
{
std
::
unordered_set
<
instruction_ref
>
instructions
;
float
metric
;
};
};
}
// namespace gpu
using
supported_segments
=
std
::
vector
<
supported_segment
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#endif
src/include/migraphx/target.hpp
View file @
3a4d36cf
...
@@ -37,8 +37,10 @@
...
@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -64,12 +66,12 @@ struct target
...
@@ -64,12 +66,12 @@ struct target
*/
*/
context
get_context
()
const
;
context
get_context
()
const
;
/**
/**
* @brief
Check how well an
instruction
is
supported on a target
with the given metric
* @brief
Get the ranges of
instruction
s that are
supported on a target
* @param
ins Instruction
to check
if it's
supported
* @param
module Module
to check
for
supported
instructions
* @param metric Used to define how the
return value
should be
interpret
ed
* @param metric Used to define how the
quality of the support
should be
measur
ed
* @return
T
he
value based on the chosen metric. Negative numbers mean unsupported
* @return
t
he
supported segments of the graph
*/
*/
float
is_supported
(
T
&
,
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
target_
is_supported
(
T
&
,
co
nst
_module
_ref
mod
,
support_metric
m
etric
)
const
;
/**
/**
* @brief copy an argument to the current target.
* @brief copy an argument to the current target.
*
*
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
}
template
<
class
T
>
template
<
class
T
>
float
target_
is
_supported
(
T
&
,
i
nst
ruction
_ref
,
support_metric
)
supported_segments
target_
find
_supported
(
T
&
,
co
nst
_module
_ref
,
support_metric
)
{
{
return
0
;
return
{}
;
}
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
...
@@ -132,7 +134,7 @@ struct target
...
@@ -132,7 +134,7 @@ struct target
//
//
context
get_context
()
const
;
context
get_context
()
const
;
// (optional)
// (optional)
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
;
// (optional)
// (optional)
argument
copy_to
(
const
argument
&
input
)
const
;
argument
copy_to
(
const
argument
&
input
)
const
;
// (optional)
// (optional)
...
@@ -224,10 +226,10 @@ struct target
...
@@ -224,10 +226,10 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is
_supported
(
ins
,
m
);
return
(
*
this
).
private_detail_te_get_handle
().
find
_supported
(
mod
,
m
);
}
}
argument
copy_to
(
const
argument
&
input
)
const
argument
copy_to
(
const
argument
&
input
)
const
...
@@ -261,33 +263,33 @@ struct target
...
@@ -261,33 +263,33 @@ struct target
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
const
compile_options
&
options
)
const
=
0
;
const
compile_options
&
options
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
=
0
;
virtual
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
};
};
template
<
class
T
>
template
<
class
T
>
static
auto
private_detail_te_default_
is
_supported
(
char
,
static
auto
private_detail_te_default_
find
_supported
(
char
,
T
&&
private_detail_te_self
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
support_metric
m
)
->
decltype
(
private_detail_te_self
.
is
_supported
(
ins
,
m
))
->
decltype
(
private_detail_te_self
.
find
_supported
(
mod
,
m
))
{
{
return
private_detail_te_self
.
is
_supported
(
ins
,
m
);
return
private_detail_te_self
.
find
_supported
(
mod
,
m
);
}
}
template
<
class
T
>
template
<
class
T
>
static
float
private_detail_te_default_
is
_supported
(
float
,
static
supported_segments
private_detail_te_default_
find
_supported
(
float
,
T
&&
private_detail_te_self
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
support_metric
m
)
{
{
return
target_
is
_supported
(
private_detail_te_self
,
ins
,
m
);
return
target_
find
_supported
(
private_detail_te_self
,
mod
,
m
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -349,7 +351,7 @@ struct target
...
@@ -349,7 +351,7 @@ struct target
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -372,10 +374,11 @@ struct target
...
@@ -372,10 +374,11 @@ struct target
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
override
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
override
{
{
return
private_detail_te_default_is_supported
(
char
(
0
),
private_detail_te_value
,
ins
,
m
);
return
private_detail_te_default_find_supported
(
char
(
0
),
private_detail_te_value
,
mod
,
m
);
}
}
argument
copy_to
(
const
argument
&
input
)
const
override
argument
copy_to
(
const
argument
&
input
)
const
override
...
@@ -423,7 +426,7 @@ struct target
...
@@ -423,7 +426,7 @@ struct target
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/target_assignments.hpp
View file @
3a4d36cf
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <unordered_map>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
...
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
target_assignments
struct
target_assignments
{
{
void
add_assignment
(
instruction_ref
ins
,
const
std
::
string
&
target
);
using
iterator
=
std
::
unordered_map
<
instruction_ref
,
std
::
string
>::
const_iterator
;
using
value_type
=
std
::
pair
<
instruction_ref
,
std
::
string
>
;
auto
begin
()
const
{
return
assignments
.
cbegin
();
}
auto
size
()
const
{
return
assignments
.
size
();
}
auto
end
()
const
{
return
assignments
.
cend
();
}
auto
&
at
(
instruction_ref
ins
)
const
{
return
assignments
.
at
(
ins
);
}
auto
insert
(
iterator
it
,
const
std
::
pair
<
instruction_ref
,
std
::
string
>&
assignment
)
{
return
assignments
.
insert
(
it
,
assignment
);
}
auto
find
(
instruction_ref
ins
)
const
{
return
assignments
.
find
(
ins
);
}
auto
begin
()
const
{
return
assignments
.
begin
();
}
auto
end
()
const
{
return
assignments
.
end
();
}
private:
private:
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
...
...
src/include/migraphx/tensor_view.hpp
View file @
3a4d36cf
...
@@ -67,7 +67,7 @@ struct tensor_view
...
@@ -67,7 +67,7 @@ struct tensor_view
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
bool
empty
()
const
{
return
m_data
==
nullptr
||
m_shape
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_data
==
nullptr
or
m_shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
m_shape
.
elements
();
}
std
::
size_t
size
()
const
{
return
m_shape
.
elements
();
}
...
@@ -109,37 +109,37 @@ struct tensor_view
...
@@ -109,37 +109,37 @@ struct tensor_view
T
&
operator
[](
std
::
size_t
i
)
T
&
operator
[](
std
::
size_t
i
)
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
return
m_data
[
m_shape
.
index
(
i
)];
}
}
const
T
&
operator
[](
std
::
size_t
i
)
const
const
T
&
operator
[](
std
::
size_t
i
)
const
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
return
m_data
[
m_shape
.
index
(
i
)];
}
}
T
&
front
()
T
&
front
()
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
return
m_data
[
0
];
}
}
const
T
&
front
()
const
const
T
&
front
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
return
m_data
[
0
];
}
}
T
&
back
()
T
&
back
()
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
const
T
&
back
()
const
const
T
&
back
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
...
@@ -159,7 +159,7 @@ struct tensor_view
...
@@ -159,7 +159,7 @@ struct tensor_view
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
not
x
.
empty
())
{
{
os
<<
as_number
(
x
.
front
());
os
<<
as_number
(
x
.
front
());
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
...
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
...
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
{
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
if
(
not
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
return
false
;
}
}
return
true
;
return
true
;
...
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
...
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/tune_axis.hpp
View file @
3a4d36cf
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
{
if
(
axis
>=
n_dim
||
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
{
{
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
}
...
...
src/include/migraphx/value.hpp
View file @
3a4d36cf
...
@@ -184,6 +184,12 @@ struct value
...
@@ -184,6 +184,12 @@ struct value
{
{
}
}
explicit
binary
(
std
::
size_t
s
)
:
base
(
s
)
{}
explicit
binary
(
std
::
size_t
s
)
:
base
(
s
)
{}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
binary
&
obj
)
{
os
<<
"{binary_object: "
<<
obj
.
size
()
<<
"}"
;
return
os
;
}
};
};
value
()
=
default
;
value
()
=
default
;
...
...
src/instruction.cpp
View file @
3a4d36cf
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
return
true
;
return
true
;
}
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
not
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
not
(
i
==
ref
);
}
void
instruction
::
add_output
(
instruction_ref
ins
)
void
instruction
::
add_output
(
instruction_ref
ins
)
{
{
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
}
}
if
(
!
ins
->
inputs
().
empty
())
if
(
not
ins
->
inputs
().
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
}
}
// print module inputs
// print module inputs
if
(
!
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
std
::
string
delim
=
", ["
;
std
::
string
delim
=
", ["
;
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
if
(
this
->
need_normalization
())
if
(
this
->
need_normalization
())
{
{
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
if
(
!
normalize_attributes
(
o
,
s
.
max_lens
()))
if
(
not
normalize_attributes
(
o
,
s
.
max_lens
()))
return
this
->
get_operator
();
return
this
->
get_operator
();
}
}
return
o
;
return
o
;
...
...
Prev
1
2
3
4
5
6
7
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment