Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2ea9fc44
Commit
2ea9fc44
authored
Oct 22, 2018
by
wsttiger
Browse files
Merge branch 'master' into concat
parents
c7d783dd
7d76401e
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
314 additions
and
84 deletions
+314
-84
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+5
-0
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+42
-2
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+106
-57
src/include/migraph/reflect.hpp
src/include/migraph/reflect.hpp
+50
-0
src/include/migraph/streamutils.hpp
src/include/migraph/streamutils.hpp
+23
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-1
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+1
-1
src/targets/gpu/include/migraph/gpu/convolution.hpp
src/targets/gpu/include/migraph/gpu/convolution.hpp
+7
-8
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+2
-2
test/operation.cpp
test/operation.cpp
+24
-4
tools/include/operation.hpp
tools/include/operation.hpp
+52
-9
No files found.
src/include/migraph/functional.hpp
View file @
2ea9fc44
...
@@ -87,6 +87,11 @@ constexpr void each_args(F f, Ts&&... xs)
...
@@ -87,6 +87,11 @@ constexpr void each_args(F f, Ts&&... xs)
swallow
{(
f
(
std
::
forward
<
Ts
>
(
xs
)),
0
)...};
swallow
{(
f
(
std
::
forward
<
Ts
>
(
xs
)),
0
)...};
}
}
template
<
class
F
>
constexpr
void
each_args
(
F
)
{
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
...
src/include/migraph/operation.hpp
View file @
2ea9fc44
...
@@ -8,7 +8,8 @@
...
@@ -8,7 +8,8 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
#include <migraph/auto_any_cast.hpp>
...
@@ -54,11 +55,34 @@ namespace operation_stream {
...
@@ -54,11 +55,34 @@ namespace operation_stream {
template
<
class
T
>
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
{
{
return
os
<<
x
.
name
();
os
<<
x
.
name
();
char
delim
=
'['
;
reflect_each
(
x
,
[
&
](
auto
&
y
,
auto
name
)
{
os
<<
delim
;
os
<<
name
<<
"="
;
stream_write_value
(
os
,
y
);
delim
=
','
;
});
if
(
delim
==
','
)
os
<<
"]"
;
return
os
;
}
}
}
// namespace operation_stream
}
// namespace operation_stream
namespace
operation_equal
{
template
<
class
T
,
class
U
>
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
{
if
(
x
.
name
()
!=
y
.
name
())
return
false
;
const
auto
&
yy
=
any_cast
<
T
>
(
y
);
return
reflect_tie
(
x
)
==
reflect_tie
(
yy
);
}
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
T
&
x
,
...
@@ -93,6 +117,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -93,6 +117,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
* };
*
*
*/
*/
...
@@ -178,6 +203,12 @@ struct operation
...
@@ -178,6 +203,12 @@ struct operation
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
}
}
friend
bool
operator
==
(
const
operation
&
x
,
const
operation
&
y
)
{
assert
(
x
.
private_detail_te_handle_mem_var
);
return
x
.
private_detail_te_get_handle
().
operator
==
(
y
);
}
private:
private:
struct
private_detail_te_handle_base_type
struct
private_detail_te_handle_base_type
{
{
...
@@ -190,6 +221,7 @@ struct operation
...
@@ -190,6 +221,7 @@ struct operation
virtual
argument
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -242,6 +274,12 @@ struct operation
...
@@ -242,6 +274,12 @@ struct operation
return
os
<<
private_detail_te_value
;
return
os
<<
private_detail_te_value
;
}
}
bool
operator
==
(
const
operation
&
y
)
const
override
{
using
migraph
::
operation_equal
::
operator
==
;
return
private_detail_te_value
==
y
;
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
@@ -307,6 +345,8 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -307,6 +345,8 @@ inline const ValueType& any_cast(const operation& x)
return
*
y
;
return
*
y
;
}
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
#endif
#endif
}
// namespace migraph
}
// namespace migraph
...
...
src/include/migraph/operators.hpp
View file @
2ea9fc44
...
@@ -35,7 +35,12 @@ struct batch_norm_inference
...
@@ -35,7 +35,12 @@ struct batch_norm_inference
bn_infer_mode_t
bn_mode
=
spatial
;
bn_infer_mode_t
bn_mode
=
spatial
;
bool
is_test
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
epsilon
,
"epsilon"
),
f
(
self
.
momentum
,
"momentum"
),
f
(
self
.
bn_mode
,
"bn_mode"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -56,6 +61,16 @@ struct convolution
...
@@ -56,6 +61,16 @@ struct convolution
valid
valid
};
};
padding_mode_t
padding_mode
=
default_
;
padding_mode_t
padding_mode
=
default_
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
}
std
::
string
name
()
const
{
return
"convolution"
;
}
std
::
string
name
()
const
{
return
"convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -110,16 +125,6 @@ struct convolution
...
@@ -110,16 +125,6 @@ struct convolution
MIGRAPH_THROW
(
"Invalid padding mode"
);
MIGRAPH_THROW
(
"Invalid padding mode"
);
}
}
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"padding={"
<<
stream_range
(
op
.
padding
)
<<
"}, "
;
os
<<
"stride={"
<<
stream_range
(
op
.
stride
)
<<
"}, "
;
os
<<
"dilation={"
<<
stream_range
(
op
.
dilation
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
};
};
struct
im2col
struct
im2col
...
@@ -133,6 +138,16 @@ struct im2col
...
@@ -133,6 +138,16 @@ struct im2col
same
,
same
,
valid
valid
};
};
padding_mode_t
padding_mode
=
default_
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
}
std
::
string
name
()
const
{
return
"im2col"
;
}
std
::
string
name
()
const
{
return
"im2col"
;
}
...
@@ -168,6 +183,16 @@ struct pooling
...
@@ -168,6 +183,16 @@ struct pooling
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{{
1
,
1
}};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
));
}
std
::
string
name
()
const
{
return
"pooling"
;
}
std
::
string
name
()
const
{
return
"pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
@@ -196,16 +221,6 @@ struct pooling
...
@@ -196,16 +221,6 @@ struct pooling
1
)),
1
)),
}};
}};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"padding={"
<<
stream_range
(
op
.
padding
)
<<
"}, "
;
os
<<
"stride={"
<<
stream_range
(
op
.
stride
)
<<
"}, "
;
os
<<
"lengths={"
<<
stream_range
(
op
.
lengths
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
};
};
struct
activation
struct
activation
...
@@ -227,6 +242,13 @@ struct activation
...
@@ -227,6 +242,13 @@ struct activation
struct
transpose
struct
transpose
{
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
dims
,
"dims"
));
}
std
::
string
name
()
const
{
return
"transpose"
;
}
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -258,13 +280,6 @@ struct transpose
...
@@ -258,13 +280,6 @@ struct transpose
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
transpose
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
};
};
struct
contiguous
struct
contiguous
...
@@ -326,6 +341,13 @@ struct slice
...
@@ -326,6 +341,13 @@ struct slice
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
starts
;
std
::
vector
<
int64_t
>
starts
;
std
::
vector
<
int64_t
>
ends
;
std
::
vector
<
int64_t
>
ends
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
}
std
::
string
name
()
const
{
return
"slice"
;
}
std
::
string
name
()
const
{
return
"slice"
;
}
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
...
@@ -398,6 +420,13 @@ struct slice
...
@@ -398,6 +420,13 @@ struct slice
struct
squeeze
struct
squeeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -438,6 +467,13 @@ struct squeeze
...
@@ -438,6 +467,13 @@ struct squeeze
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -469,6 +505,13 @@ struct unsqueeze
...
@@ -469,6 +505,13 @@ struct unsqueeze
struct
reshape
struct
reshape
{
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
dims
,
"dims"
));
}
std
::
string
name
()
const
{
return
"reshape"
;
}
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -508,19 +551,19 @@ struct reshape
...
@@ -508,19 +551,19 @@ struct reshape
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
};
};
struct
gemm
struct
gemm
{
{
float
alpha
=
1.0
;
float
alpha
=
1.0
;
float
beta
=
0.0
;
float
beta
=
0.0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
));
}
std
::
string
name
()
const
{
return
"gemm"
;
}
std
::
string
name
()
const
{
return
"gemm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -534,13 +577,6 @@ struct gemm
...
@@ -534,13 +577,6 @@ struct gemm
to_string_range
(
b
.
lens
())
+
"}"
);
to_string_range
(
b
.
lens
())
+
"}"
);
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gemm
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"]"
;
return
os
;
}
};
};
struct
unary
struct
unary
...
@@ -625,6 +661,13 @@ struct softmax
...
@@ -625,6 +661,13 @@ struct softmax
struct
flatten
struct
flatten
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
std
::
string
name
()
const
{
return
"flatten"
;
}
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
...
@@ -645,17 +688,17 @@ struct flatten
...
@@ -645,17 +688,17 @@ struct flatten
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
flatten
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
};
};
struct
broadcast
struct
broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
shape
broadcast_shape
;
shape
broadcast_shape
;
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
@@ -687,18 +730,10 @@ struct broadcast
...
@@ -687,18 +730,10 @@ struct broadcast
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
broadcast
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
};
};
struct
binary
struct
binary
{
{
uint64_t
broadcast
=
0
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
...
@@ -730,6 +765,13 @@ struct load
...
@@ -730,6 +765,13 @@ struct load
{
{
shape
s
;
shape
s
;
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
),
f
(
self
.
offset
,
"offset"
));
}
std
::
string
name
()
const
{
return
"load"
;
}
std
::
string
name
()
const
{
return
"load"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
...
@@ -745,6 +787,13 @@ struct load
...
@@ -745,6 +787,13 @@ struct load
struct
outline
struct
outline
{
{
shape
s
;
shape
s
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
));
}
std
::
string
name
()
const
{
return
"outline"
;
}
std
::
string
name
()
const
{
return
"outline"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
...
...
src/include/migraph/reflect.hpp
0 → 100644
View file @
2ea9fc44
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#include <migraph/functional.hpp>
#include <migraph/rank.hpp>
#include <functional>
namespace
migraph
{
namespace
detail
{
template
<
class
T
,
class
Selector
>
auto
reflect_impl
(
rank
<
1
>
,
T
&
x
,
Selector
f
)
->
decltype
(
T
::
reflect
(
x
,
f
))
{
return
T
::
reflect
(
x
,
std
::
move
(
f
));
}
template
<
class
T
,
class
Selector
>
auto
reflect_impl
(
rank
<
0
>
,
T
&
,
Selector
)
{
return
pack
();
}
}
// namespace detail
template
<
class
T
,
class
Selector
>
auto
reflect
(
T
&
x
,
Selector
f
)
{
return
detail
::
reflect_impl
(
rank
<
1
>
{},
x
,
std
::
move
(
f
));
}
template
<
class
T
>
auto
reflect_tie
(
T
&
x
)
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
std
::
ref
(
y
);
})(
[](
auto
&&
...
xs
)
{
return
std
::
tie
(
xs
.
get
()...);
});
}
template
<
class
T
,
class
F
>
void
reflect_each
(
T
&
x
,
F
f
)
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
...
ys
)
{
return
pack
(
std
::
ref
(
y
),
ys
...);
})(
[
&
](
auto
&&
...
xs
)
{
each_args
([
&
](
auto
p
)
{
p
([
&
](
auto
&&
y
,
auto
...
ys
)
{
f
(
y
.
get
(),
ys
...);
});
},
xs
...);
});
}
}
// namespace migraph
#endif
src/include/migraph/streamutils.hpp
View file @
2ea9fc44
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <ostream>
#include <ostream>
#include <algorithm>
#include <algorithm>
#include <migraph/rank.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -31,6 +32,28 @@ inline stream_range_container<Range> stream_range(const Range& r)
...
@@ -31,6 +32,28 @@ inline stream_range_container<Range> stream_range(const Range& r)
return
{
r
};
return
{
r
};
}
}
namespace
detail
{
template
<
class
Range
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
{
os
<<
stream_range
(
r
);
}
template
<
class
T
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
T
&
x
)
{
os
<<
x
;
}
}
// namespace detail
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
}
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/onnx/onnx.cpp
View file @
2ea9fc44
...
@@ -255,7 +255,8 @@ struct onnx_parser
...
@@ -255,7 +255,8 @@ struct onnx_parser
?
op
::
batch_norm_inference
::
spatial
?
op
::
batch_norm_inference
::
spatial
:
op
::
batch_norm_inference
::
per_activation
;
:
op
::
batch_norm_inference
::
per_activation
;
}
}
op
::
batch_norm_inference
op
{
epsilon
,
momentum
,
bn_mode
,
is_test
};
(
void
)
is_test
;
op
::
batch_norm_inference
op
{
epsilon
,
momentum
,
bn_mode
};
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
}
...
...
src/opt/memory_coloring_impl.cpp
View file @
2ea9fc44
...
@@ -192,7 +192,7 @@ void memory_coloring_impl::register_operand_alias()
...
@@ -192,7 +192,7 @@ void memory_coloring_impl::register_operand_alias()
operand_alias
[
"@param"
]
=
-
1
;
operand_alias
[
"@param"
]
=
-
1
;
operand_alias
[
"transpose"
]
=
0
;
operand_alias
[
"transpose"
]
=
0
;
operand_alias
[
"flatten"
]
=
0
;
operand_alias
[
"flatten"
]
=
0
;
operand_alias
[
"broadcast"
]
=
1
;
operand_alias
[
"broadcast"
]
=
0
;
operand_alias
[
"reshape"
]
=
0
;
operand_alias
[
"reshape"
]
=
0
;
operand_alias
[
"pass"
]
=
0
;
operand_alias
[
"pass"
]
=
0
;
}
}
...
...
src/targets/gpu/include/migraph/gpu/convolution.hpp
View file @
2ea9fc44
...
@@ -26,19 +26,18 @@ struct miopen_convolution
...
@@ -26,19 +26,18 @@ struct miopen_convolution
shared
<
convolution_descriptor
>
cd
;
shared
<
convolution_descriptor
>
cd
;
miopenConvFwdAlgorithm_t
algo
{};
miopenConvFwdAlgorithm_t
algo
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
// TODO: Add algo
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
);
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
miopen_convolution
&
self
)
{
os
<<
self
.
name
()
<<
"["
;
os
<<
self
.
op
<<
", "
;
os
<<
"algo="
<<
self
.
algo
;
os
<<
"]"
;
return
os
;
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
test/onnx/onnx_test.cpp
View file @
2ea9fc44
...
@@ -54,8 +54,8 @@ void pytorch_conv_bn_relu_maxpool()
...
@@ -54,8 +54,8 @@ void pytorch_conv_bn_relu_maxpool()
auto
l3
=
p
.
add_instruction
(
migraph
::
op
::
convolution
{},
l0
,
l1
);
auto
l3
=
p
.
add_instruction
(
migraph
::
op
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
op
::
broadcast
{
axis
,
l3
->
get_shape
()},
l2
);
auto
l4
=
p
.
add_instruction
(
migraph
::
op
::
broadcast
{
axis
,
l3
->
get_shape
()},
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
op
::
add
{},
l3
,
l4
);
auto
l5
=
p
.
add_instruction
(
migraph
::
op
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
op
::
batch_norm_inference
{},
l5
,
p3
,
p4
,
p5
,
p6
);
auto
l6
=
p
.
add_instruction
(
migraph
::
op
::
batch_norm_inference
{
1.0e-5
f
},
l5
,
p3
,
p4
,
p5
,
p6
);
auto
l7
=
p
.
add_instruction
(
migraph
::
op
::
activation
{
"relu"
},
l6
);
auto
l7
=
p
.
add_instruction
(
migraph
::
op
::
activation
{
"relu"
},
l6
);
p
.
add_instruction
(
migraph
::
op
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l7
);
p
.
add_instruction
(
migraph
::
op
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l7
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_bn_relu_maxpool.onnx"
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_bn_relu_maxpool.onnx"
);
...
...
test/operation.cpp
View file @
2ea9fc44
...
@@ -6,6 +6,11 @@
...
@@ -6,6 +6,11 @@
struct
simple_operation
struct
simple_operation
{
{
template
<
class
T
,
class
F
>
static
auto
reflect
(
T
&
x
,
F
f
)
{
return
migraph
::
pack
(
f
(
x
.
data
,
"data"
));
}
int
data
=
1
;
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>&
)
const
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>&
)
const
...
@@ -19,7 +24,7 @@ struct simple_operation
...
@@ -19,7 +24,7 @@ struct simple_operation
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
{
os
<<
"["
<<
op
.
name
()
<<
"]"
;
os
<<
op
.
name
()
<<
"["
<<
op
.
data
<<
"]"
;
return
os
;
return
os
;
}
}
};
};
...
@@ -44,9 +49,23 @@ void operation_copy_test()
...
@@ -44,9 +49,23 @@ void operation_copy_test()
migraph
::
operation
op1
=
s
;
// NOLINT
migraph
::
operation
op1
=
s
;
// NOLINT
migraph
::
operation
op2
=
op1
;
// NOLINT
migraph
::
operation
op2
=
op1
;
// NOLINT
// cppcheck-suppress duplicateExpression
// cppcheck-suppress duplicateExpression
EXPECT
(
s
.
name
()
==
op1
.
name
()
);
EXPECT
(
s
==
op1
);
// cppcheck-suppress duplicateExpression
// cppcheck-suppress duplicateExpression
EXPECT
(
op2
.
name
()
==
op1
.
name
());
EXPECT
(
op2
==
op1
);
}
void
operation_equal_test
()
{
simple_operation
s
{};
migraph
::
operation
op1
=
s
;
s
.
data
=
2
;
migraph
::
operation
op2
=
op1
;
// NOLINT
migraph
::
operation
op3
=
s
;
// NOLINT
EXPECT
(
s
!=
op1
);
EXPECT
(
op2
==
op1
);
EXPECT
(
op3
!=
op2
);
EXPECT
(
op3
!=
op1
);
}
}
struct
not_operation
struct
not_operation
...
@@ -70,7 +89,7 @@ void operation_print()
...
@@ -70,7 +89,7 @@ void operation_print()
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
op
;
ss
<<
op
;
std
::
string
s
=
ss
.
str
();
std
::
string
s
=
ss
.
str
();
EXPECT
(
s
==
"
[
simple]"
);
EXPECT
(
s
==
"simple
[1
]"
);
}
}
void
operation_default_print
()
void
operation_default_print
()
...
@@ -85,6 +104,7 @@ void operation_default_print()
...
@@ -85,6 +104,7 @@ void operation_default_print()
int
main
()
int
main
()
{
{
operation_copy_test
();
operation_copy_test
();
operation_equal_test
();
operation_any_cast
();
operation_any_cast
();
operation_print
();
operation_print
();
operation_default_print
();
operation_default_print
();
...
...
tools/include/operation.hpp
View file @
2ea9fc44
...
@@ -8,7 +8,8 @@
...
@@ -8,7 +8,8 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
#include <migraph/auto_any_cast.hpp>
...
@@ -54,11 +55,34 @@ namespace operation_stream {
...
@@ -54,11 +55,34 @@ namespace operation_stream {
template
<
class
T
>
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
{
{
return
os
<<
x
.
name
();
os
<<
x
.
name
();
char
delim
=
'['
;
reflect_each
(
x
,
[
&
](
auto
&
y
,
auto
name
)
{
os
<<
delim
;
os
<<
name
<<
"="
;
stream_write_value
(
os
,
y
);
delim
=
','
;
});
if
(
delim
==
','
)
os
<<
"]"
;
return
os
;
}
}
}
// namespace operation_stream
}
// namespace operation_stream
namespace
operation_equal
{
template
<
class
T
,
class
U
>
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
{
if
(
x
.
name
()
!=
y
.
name
())
return
false
;
const
auto
&
yy
=
any_cast
<
T
>
(
y
);
return
reflect_tie
(
x
)
==
reflect_tie
(
yy
);
}
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
T
&
x
,
...
@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
}
}
<%
<%
interface
(
'
operation
'
,
interface
(
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
'
operation
'
,
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
virtual
(
'
compute
'
,
)
returns
=
'
argument
'
,
%>
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
),
friend
(
'
operator
==
'
,
returns
=
'
bool
'
,
x
=
'
const
operation
&
'
,
y
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_equal
::
operator
==
'
))
%>
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment