Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
06fb0905
Commit
06fb0905
authored
Jul 05, 2018
by
Scott Thornton
Browse files
Added MNIST test for cpu target
parents
0a59f103
cff16121
Changes
69
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
341 additions
and
202 deletions
+341
-202
src/include/migraph/instruction_ref.hpp
src/include/migraph/instruction_ref.hpp
+13
-0
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+8
-8
src/include/migraph/manage_ptr.hpp
src/include/migraph/manage_ptr.hpp
+6
-5
src/include/migraph/onnx.hpp
src/include/migraph/onnx.hpp
+12
-0
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+33
-16
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+101
-43
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+12
-10
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+4
-4
src/include/migraph/raw_data.hpp
src/include/migraph/raw_data.hpp
+40
-14
src/include/migraph/requires.hpp
src/include/migraph/requires.hpp
+26
-0
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+18
-21
src/include/migraph/shape_for_each.hpp
src/include/migraph/shape_for_each.hpp
+6
-6
src/include/migraph/streamutils.hpp
src/include/migraph/streamutils.hpp
+4
-4
src/include/migraph/stringutils.hpp
src/include/migraph/stringutils.hpp
+4
-4
src/include/migraph/target.hpp
src/include/migraph/target.hpp
+30
-4
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+11
-11
src/include/rtg/fallthrough.hpp
src/include/rtg/fallthrough.hpp
+0
-14
src/include/rtg/onnx.hpp
src/include/rtg/onnx.hpp
+0
-12
src/include/rtg/requires.hpp
src/include/rtg/requires.hpp
+0
-21
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+13
-5
No files found.
src/include/
rtg
/instruction_ref.hpp
→
src/include/
migraph
/instruction_ref.hpp
View file @
06fb0905
#ifndef
RTG
_GUARD_INSTRUCTION_REF_HPP
#define
RTG
_GUARD_INSTRUCTION_REF_HPP
#ifndef
MIGRAPH
_GUARD_INSTRUCTION_REF_HPP
#define
MIGRAPH
_GUARD_INSTRUCTION_REF_HPP
#include <list>
namespace
rtg
{
namespace
migraph
{
struct
instruction
;
using
instruction_ref
=
std
::
list
<
instruction
>::
iterator
;
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/literal.hpp
→
src/include/
migraph
/literal.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_LITERAL_HPP
#define
RTG_GUARD_RTG
LIB_LITERAL_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_LITERAL_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_LITERAL_HPP
#include <
rtg
/shape.hpp>
#include <
rtg
/argument.hpp>
#include <
rtg
/tensor_view.hpp>
#include <
rtg
/raw_data.hpp>
#include <
migraph
/shape.hpp>
#include <
migraph
/argument.hpp>
#include <
migraph
/tensor_view.hpp>
#include <
migraph
/raw_data.hpp>
namespace
rtg
{
namespace
migraph
{
/**
* @brief Represents a raw literal
...
...
@@ -68,6 +68,6 @@ struct literal : raw_data<literal>
shape
m_shape
;
};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/manage_ptr.hpp
→
src/include/
migraph
/manage_ptr.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
_MANAGE_PTR_HPP
#define
RTG_GUARD_RTG
_MANAGE_PTR_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
_MANAGE_PTR_HPP
#define
MIGRAPH_GUARD_MIGRAPH
_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
namespace
rtg
{
namespace
migraph
{
template
<
class
F
,
F
f
>
// NOLINT
struct
manage_deleter
...
...
@@ -49,8 +49,9 @@ shared<T> share(T p)
return
shared
<
T
>
{
std
::
move
(
p
)};
}
}
// namespace
rtg
}
// namespace
migraph
#define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#define MIGRAPH_MANAGE_PTR(T, F) \
migraph::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
src/include/migraph/onnx.hpp
0 → 100644
View file @
06fb0905
#ifndef GUARD_MIGRAPHLIB_ONNX_HPP
#define GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
namespace
migraph
{
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace migraph
#endif
src/include/
rtg
/operation.hpp
→
src/include/
migraph
/operation.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_OPERAND_HPP
#define
RTG_GUARD_RTG
LIB_OPERAND_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_OPERAND_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_OPERAND_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
namespace
rtg
{
namespace
migraph
{
namespace
operation_stream
{
...
...
@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const;
* argument compute(
context& ctx,
shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
...
...
@@ -83,6 +84,14 @@ struct operation
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -95,10 +104,11 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
...
@@ -114,10 +124,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -156,15 +166,15 @@ struct operation
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
return
private_detail_te_value
.
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
using
rtg
::
operation_stream
::
operator
<<
;
using
migraph
::
operation_stream
::
operator
<<
;
return
os
<<
private_detail_te_value
;
}
...
...
@@ -181,13 +191,20 @@ struct operation
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
...
...
@@ -226,6 +243,6 @@ inline const ValueType& any_cast(const operation& x)
return
*
y
;
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/operators.hpp
→
src/include/
migraph
/operators.hpp
View file @
06fb0905
#ifndef
RTG
_GUARD_OPERATORS_HPP
#define
RTG
_GUARD_OPERATORS_HPP
#ifndef
MIGRAPH
_GUARD_OPERATORS_HPP
#define
MIGRAPH
_GUARD_OPERATORS_HPP
#include <array>
#include <
rtg
/operation.hpp>
#include <
rtg
/stringutils.hpp>
#include <
rtg
/streamutils.hpp>
#include <
migraph
/operation.hpp>
#include <
migraph
/stringutils.hpp>
#include <
migraph
/streamutils.hpp>
#include <cmath>
namespace
rtg
{
namespace
migraph
{
struct
check_shapes
{
const
std
::
vector
<
shape
>*
shapes
;
const
std
::
string
name
;
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
shapes
(
&
s
)
{}
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
)
:
shapes
(
&
s
),
name
(
op
.
name
())
{
}
std
::
string
prefix
()
const
{
if
(
name
.
empty
())
return
""
;
else
return
name
+
": "
;
}
const
check_shapes
&
has
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
size
()
!=
n
)
RTG_THROW
(
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
shapes
->
size
()));
MIGRAPH_THROW
(
prefix
()
+
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
shapes
->
size
()));
return
*
this
;
}
...
...
@@ -30,7 +44,7 @@ struct check_shapes
if
(
!
shapes
->
empty
())
{
if
(
shapes
->
front
().
lens
().
size
()
!=
n
)
RTG_THROW
(
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
MIGRAPH_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
}
...
...
@@ -38,28 +52,28 @@ struct check_shapes
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
RTG_THROW
(
"Shapes do not match"
);
MIGRAPH_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
RTG_THROW
(
"Types do not match"
);
MIGRAPH_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
();
}))
RTG_THROW
(
"Dimensions do not match"
);
MIGRAPH_THROW
(
prefix
()
+
"Dimensions do not match"
);
return
*
this
;
}
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
().
size
();
}))
RTG_THROW
(
"Dimensions do not match"
);
MIGRAPH_THROW
(
prefix
()
+
"Dimensions do not match"
);
return
*
this
;
}
...
...
@@ -83,7 +97,10 @@ struct check_shapes
struct
not_computable
{
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
convolution
...
...
@@ -101,7 +118,7 @@ struct convolution
std
::
string
name
()
const
{
return
"convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
...
...
@@ -149,11 +166,14 @@ struct convolution
}
else
{
RTG
_THROW
(
"Invalid padding mode"
);
MIGRAPH
_THROW
(
"Invalid padding mode"
);
}
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
...
...
@@ -175,7 +195,7 @@ struct pooling
std
::
string
name
()
const
{
return
"pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
...
...
@@ -197,7 +217,10 @@ struct pooling
}};
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
...
...
@@ -216,11 +239,14 @@ struct activation
std
::
string
name
()
const
{
return
"activation"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
...
@@ -234,20 +260,20 @@ struct transpose
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
{
RTG
_THROW
(
"Permutation has wrong number of axes"
);
MIGRAPH
_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
RTG
_THROW
(
"Invalid permutation"
);
MIGRAPH
_THROW
(
"Invalid permutation"
);
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
...
...
@@ -258,7 +284,10 @@ struct transpose
}
return
{
t
,
output_lens
,
output_strides
};
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
contiguous
...
...
@@ -266,16 +295,19 @@ struct contiguous
std
::
string
name
()
const
{
return
"contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
lens
=
inputs
.
at
(
0
).
lens
();
auto
t
=
inputs
.
at
(
0
).
type
();
if
(
lens
.
size
()
<
2
)
{
RTG
_THROW
(
"Number of dimensions should exceed 1"
);
MIGRAPH
_THROW
(
"Number of dimensions should exceed 1"
);
}
return
{
t
,
lens
};
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
reshape
...
...
@@ -284,8 +316,7 @@ struct reshape
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
RTG_THROW
(
"Wrong number of arguments"
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
...
...
@@ -299,11 +330,15 @@ struct reshape
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
RTG_THROW
(
"Wrong number of elements"
);
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPH_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
...
...
@@ -319,17 +354,20 @@ struct gemm
std
::
string
name
()
const
{
return
"gemm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
()
.
same_ndims
().
only_dims
(
2
)
;
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
();
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
a
.
lens
()[
1
]
!=
b
.
lens
()[
0
])
RTG
_THROW
(
"Inner dimensions do not match"
);
MIGRAPH
_THROW
(
"Inner dimensions do not match"
);
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gemm
&
op
)
{
...
...
@@ -346,7 +384,10 @@ struct unary
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
identity
:
unary
...
...
@@ -435,19 +476,19 @@ struct broadcast
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
if
(
axis
!=
0
)
RTG
_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
MIGRAPH
_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
return
{
t
,
result
.
lens
(),
std
::
move
(
bcast_strides
)};
}
else
{
assert
(
result
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
result
.
lens
().
begin
()
+
axis
))
RTG
_THROW
(
"when broadcasting success sizes must match"
);
MIGRAPH
_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
result
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
at
(
1
).
data
)};
}
...
...
@@ -461,7 +502,10 @@ struct binary
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
return
inputs
.
at
(
0
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
add
:
binary
...
...
@@ -490,12 +534,26 @@ struct outline
std
::
string
name
()
const
{
return
"outline"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
0
);
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
return
{
s
,
nullptr
};
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
return
{
s
,
nullptr
};
}
};
template
<
class
T
>
struct
check_context
{
std
::
string
name
()
const
{
return
"check_context"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
shape
,
std
::
vector
<
argument
>
)
const
{
T
*
x
=
any_cast
<
T
>
(
&
ctx
);
if
(
x
==
nullptr
)
MIGRAPH_THROW
(
std
::
string
(
"Unexpected context type: "
)
+
ctx
.
type_id
().
name
());
return
{};
}
};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/program.hpp
→
src/include/
migraph
/program.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_PROGRAM_HPP
#define
RTG_GUARD_RTG
LIB_PROGRAM_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_PROGRAM_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_PROGRAM_HPP
#include <list>
#include <unordered_map>
#include <
rtg
/operation.hpp>
#include <
rtg
/literal.hpp>
#include <
rtg
/builtin.hpp>
#include <
rtg
/instruction_ref.hpp>
#include <
rtg
/target.hpp>
#include <
migraph
/operation.hpp>
#include <
migraph
/literal.hpp>
#include <
migraph
/builtin.hpp>
#include <
migraph
/instruction_ref.hpp>
#include <
migraph
/target.hpp>
#include <algorithm>
#include <iostream>
namespace
rtg
{
namespace
migraph
{
struct
program_impl
;
...
...
@@ -27,6 +27,8 @@ struct program
program
&
operator
=
(
program
&&
)
noexcept
;
~
program
()
noexcept
;
using
parameter_map
=
std
::
unordered_map
<
std
::
string
,
argument
>
;
template
<
class
...
Ts
>
instruction_ref
add_instruction
(
operation
op
,
Ts
...
args
)
{
...
...
@@ -64,7 +66,7 @@ struct program
shape
get_parameter_shape
(
std
::
string
name
);
argument
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
argument
eval
(
parameter_map
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
...
...
@@ -81,6 +83,6 @@ struct program
std
::
unique_ptr
<
program_impl
>
impl
;
};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/ranges.hpp
→
src/include/
migraph
/ranges.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_RANGES_HPP
#define
RTG_GUARD_RTG
LIB_RANGES_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_RANGES_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_RANGES_HPP
namespace
rtg
{
namespace
migraph
{
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
...
...
@@ -15,6 +15,6 @@ void copy(Range&& r, Iterator it)
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/raw_data.hpp
→
src/include/
migraph
/raw_data.hpp
View file @
06fb0905
#ifndef
RTG
_GUARD_RAW_DATA_HPP
#define
RTG
_GUARD_RAW_DATA_HPP
#ifndef
MIGRAPH
_GUARD_RAW_DATA_HPP
#define
MIGRAPH
_GUARD_RAW_DATA_HPP
#include <
rtg
/tensor_view.hpp>
#include <
rtg
/requires.hpp>
#include <
migraph
/tensor_view.hpp>
#include <
migraph
/requires.hpp>
namespace
rtg
{
namespace
migraph
{
struct
raw_data_base
{
...
...
@@ -89,13 +89,27 @@ struct raw_data : raw_data_base
assert
(
self
->
single
());
return
self
->
template
at
<
T
>();
}
template
<
class
T
>
using
is_data_ptr
=
bool_c
<
(
std
::
is_void
<
T
>
{}
or
std
::
is_same
<
char
,
std
::
remove_cv_t
<
T
>>
{}
or
std
::
is_same
<
unsigned
char
,
std
::
remove_cv_t
<
T
>>
{})
>
;
template
<
class
T
>
using
get_data_type
=
std
::
conditional_t
<
is_data_ptr
<
T
>
{},
float
,
T
>
;
template
<
class
T
>
bool
matches
()
const
{
return
is_data_ptr
<
T
>
{}
||
self
->
get_shape
().
type
()
==
migraph
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
template
<
class
T
>
operator
T
*
()
{
using
type
=
std
::
remove_cv_t
<
T
>
;
assert
((
std
::
is_void
<
T
>
{}
or
std
::
is_same
<
char
,
type
>
{}
or
std
::
is_same
<
unsigned
char
,
type
>
{}
or
self
->
get_shape
().
type
()
==
rtg
::
shape
::
get_type
<
T
>
{}));
assert
(
matches
<
T
>
());
return
reinterpret_cast
<
type
*>
(
self
->
data
());
}
};
...
...
@@ -109,15 +123,26 @@ struct raw_data : raw_data_base
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
s
.
type
()
!=
rtg
::
shape
::
get_type
<
T
>
{})
RTG
_THROW
(
"Incorrect data type for raw data"
);
if
(
s
.
type
()
!=
migraph
::
shape
::
get_type
<
T
>
{})
MIGRAPH
_THROW
(
"Incorrect data type for raw data"
);
return
make_view
(
s
,
reinterpret_cast
<
T
*>
(
buffer
));
}
/// Cast the data pointer
template
<
class
T
>
T
*
cast
()
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
assert
(
s
.
type
()
==
migraph
::
shape
::
get_type
<
T
>
{});
return
reinterpret_cast
<
T
*>
(
buffer
);
}
};
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
MIGRAPH_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
==
(
const
T
&
x
,
const
U
&
y
)
{
auto
&&
xshape
=
x
.
get_shape
();
...
...
@@ -139,7 +164,8 @@ bool operator==(const T& x, const U& y)
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
MIGRAPH_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
...
...
@@ -170,13 +196,13 @@ auto visit_all(T&& x, Ts&&... xs)
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
RTG
_THROW
(
"Types must be the same"
);
MIGRAPH
_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail
::
visit_all_impl
(
s
,
v
,
x
,
xs
...);
};
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/migraph/requires.hpp
0 → 100644
View file @
06fb0905
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
namespace
migraph
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
{
};
template
<
bool
B
>
using
bool_c
=
std
::
integral_constant
<
bool
,
B
>
;
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#else
#define MIGRAPH_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif
}
// namespace migraph
#endif
src/include/
rtg
/shape.hpp
→
src/include/
migraph
/shape.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_SHAPE_HPP
#define
RTG_GUARD_RTG
LIB_SHAPE_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_SHAPE_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_SHAPE_HPP
#include <vector>
#include <cassert>
#include <ostream>
#include <numeric>
#include <
rtg
/errors.hpp>
#include <
migraph
/errors.hpp>
namespace
rtg
{
namespace
migraph
{
struct
shape
{
// Add new types here
// clang-format off
#define
RTG
_SHAPE_VISIT_TYPES(m) \
#define
MIGRAPH
_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
...
...
@@ -28,25 +28,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define
RTG
_SHAPE_ENUM_TYPES(x, t) x,
#define
MIGRAPH
_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
{
any_type
,
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_ENUM_TYPES
)
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_ENUM_TYPES
)
};
#undef
RTG
_SHAPE_ENUM_TYPES
#undef
MIGRAPH
_SHAPE_ENUM_TYPES
template
<
class
T
,
class
=
void
>
struct
get_type
:
std
::
integral_constant
<
type_t
,
any_type
>
{
};
#define RTG_SHAPE_GET_TYPE(x, t) \
struct
get_type
;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
RTG
_SHAPE_VISIT_TYPES
(
RTG
_SHAPE_GET_TYPE
)
#undef
RTG
_SHAPE_GET_TYPE
MIGRAPH
_SHAPE_VISIT_TYPES
(
MIGRAPH
_SHAPE_GET_TYPE
)
#undef
MIGRAPH
_SHAPE_GET_TYPE
shape
();
shape
(
type_t
t
);
...
...
@@ -74,6 +71,7 @@ struct shape
std
::
size_t
index
(
std
::
size_t
i
)
const
;
bool
packed
()
const
;
bool
broadcasted
()
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
...
...
@@ -124,13 +122,12 @@ struct shape
{
switch
(
this
->
m_type
)
{
case
any_type
:
RTG_THROW
(
"Cannot visit the any_type"
);
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
RTG
_SHAPE_VISIT_TYPES
(
RTG
_SHAPE_VISITOR_CASE
)
#undef
RTG
_SHAPE_VISITOR_CASE
MIGRAPH
_SHAPE_VISIT_TYPES
(
MIGRAPH
_SHAPE_VISITOR_CASE
)
#undef
MIGRAPH
_SHAPE_VISITOR_CASE
}
RTG
_THROW
(
"Unknown type"
);
MIGRAPH
_THROW
(
"Unknown type"
);
}
private:
...
...
@@ -144,6 +141,6 @@ struct shape
std
::
string
type_string
()
const
;
};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/shape_for_each.hpp
→
src/include/
migraph
/shape_for_each.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_SHAPE_FOR_EACH_HPP
#define
RTG_GUARD_RTG
LIB_SHAPE_FOR_EACH_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_SHAPE_FOR_EACH_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_SHAPE_FOR_EACH_HPP
#include <
rtg
/shape.hpp>
#include <
migraph
/shape.hpp>
#include <algorithm>
namespace
rtg
{
namespace
migraph
{
template
<
class
F
>
void
shape_for_each
(
const
rtg
::
shape
&
s
,
F
f
)
void
shape_for_each
(
const
migraph
::
shape
&
s
,
F
f
)
{
// Ensure calls to f use const ref to vector
auto
call
=
[
&
f
](
const
std
::
vector
<
std
::
size_t
>&
i
)
{
f
(
i
);
};
...
...
@@ -26,6 +26,6 @@ void shape_for_each(const rtg::shape& s, F f)
}
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/streamutils.hpp
→
src/include/
migraph
/streamutils.hpp
View file @
06fb0905
#ifndef
RTG
_GUARD_STREAMUTILS_HPP
#define
RTG
_GUARD_STREAMUTILS_HPP
#ifndef
MIGRAPH
_GUARD_STREAMUTILS_HPP
#define
MIGRAPH
_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
namespace
rtg
{
namespace
migraph
{
template
<
class
T
>
struct
stream_range_container
...
...
@@ -31,6 +31,6 @@ inline stream_range_container<Range> stream_range(const Range& r)
return
{
r
};
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/stringutils.hpp
→
src/include/
migraph
/stringutils.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_STRINGUTILS_HPP
#define
RTG_GUARD_RTG
LIB_STRINGUTILS_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_STRINGUTILS_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
namespace
rtg
{
namespace
migraph
{
inline
std
::
string
replace_string
(
std
::
string
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
...
...
@@ -77,6 +77,6 @@ inline std::string to_string(const Range& r)
return
ss
.
str
();
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/target.hpp
→
src/include/
migraph
/target.hpp
View file @
06fb0905
#ifndef
RTG_GUARD_RTG
LIB_TARGET_HPP
#define
RTG_GUARD_RTG
LIB_TARGET_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_TARGET_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_TARGET_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/context.hpp>
namespace
rtg
{
namespace
migraph
{
struct
program
;
...
...
@@ -18,6 +19,7 @@ struct program;
* {
* std::string name() const;
* void apply(program & p) const;
* context get_context() const;
* };
*
*/
...
...
@@ -71,6 +73,14 @@ struct target
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -83,6 +93,12 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
context
get_context
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
private:
struct
private_detail_te_handle_base_type
{
...
...
@@ -92,6 +108,7 @@ struct target
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -126,6 +143,8 @@ struct target
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
...
...
@@ -139,13 +158,20 @@ struct target
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
...
...
@@ -184,6 +210,6 @@ inline const ValueType& any_cast(const target& x)
return
*
y
;
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/tensor_view.hpp
→
src/include/
migraph
/tensor_view.hpp
View file @
06fb0905
#ifndef
RTG
_GUARD_TENSOR_VIEW_HPP
#define
RTG
_GUARD_TENSOR_VIEW_HPP
#ifndef
MIGRAPH
_GUARD_TENSOR_VIEW_HPP
#define
MIGRAPH
_GUARD_TENSOR_VIEW_HPP
#include <
rtg
/shape.hpp>
#include <
rtg
/float_equal.hpp>
#include <
rtg
/requires.hpp>
#include <
migraph
/shape.hpp>
#include <
migraph
/float_equal.hpp>
#include <
migraph
/requires.hpp>
#include <iostream>
namespace
rtg
{
namespace
migraph
{
template
<
class
T
>
struct
tensor_view
...
...
@@ -26,27 +26,27 @@ struct tensor_view
const
T
*
data
()
const
{
return
this
->
m_data
;
}
template
<
class
...
Ts
,
RTG
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
...
Ts
,
RTG
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
T
&
operator
()(
Ts
...
xs
)
{
assert
(
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})
<
m_shape
.
bytes
()
/
sizeof
(
T
));
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
Iterator
,
RTG
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
const
T
&
operator
()(
Iterator
start
,
Iterator
last
)
const
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
}
template
<
class
Iterator
,
RTG
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
T
&
operator
()(
Iterator
start
,
Iterator
last
)
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
...
...
@@ -164,6 +164,6 @@ tensor_view<T> make_view(shape s, T* data)
return
{
s
,
data
};
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/rtg/fallthrough.hpp
deleted
100644 → 0
View file @
0a59f103
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace
rtg
{
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
}
// namespace rtg
#endif
src/include/rtg/onnx.hpp
deleted
100644 → 0
View file @
0a59f103
#ifndef GUARD_RTGLIB_ONNX_HPP
#define GUARD_RTGLIB_ONNX_HPP
#include <rtg/program.hpp>
namespace
rtg
{
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace rtg
#endif
src/include/rtg/requires.hpp
deleted
100644 → 0
View file @
0a59f103
#ifndef RTG_GUARD_RTGLIB_REQUIRES_HPP
#define RTG_GUARD_RTGLIB_REQUIRES_HPP
#include <type_traits>
namespace
rtg
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
{
};
#ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void
#else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
#endif
}
// namespace rtg
#endif
src/onnx/CMakeLists.txt
View file @
06fb0905
...
...
@@ -5,15 +5,23 @@ add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
onnx-proto PRIVATE -w
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
add_library
(
rtg_onnx onnx.cpp
)
rocm_clang_tidy_check
(
rtg_onnx
)
target_link_libraries
(
rtg_onnx onnx-proto rtg
)
add_library
(
migraph_onnx onnx.cpp
)
rocm_clang_tidy_check
(
migraph_onnx
)
target_link_libraries
(
migraph_onnx PRIVATE onnx-proto
)
target_link_libraries
(
migraph_onnx PUBLIC migraph
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx
rtg_onnx rtg_cpu
)
target_link_libraries
(
read_onnx
migraph_onnx
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist rtg_onnx rtg_cpu
)
target_link_libraries
(
mnist migraph_cpu migraph_onnx
)
if
(
MIGRAPH_ENABLE_MIOPEN
)
add_executable
(
verify_onnx verify_onnx.cpp
)
rocm_clang_tidy_check
(
verify_onnx
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_miopen
)
endif
()
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment