Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a1952a8d
Commit
a1952a8d
authored
Sep 15, 2018
by
Paul
Browse files
Make compute optional
parent
f550da30
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
56 deletions
+55
-56
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+19
-1
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+0
-44
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+1
-10
src/include/migraph/rank.hpp
src/include/migraph/rank.hpp
+18
-0
tools/include/operation.hpp
tools/include/operation.hpp
+17
-1
No files found.
src/include/migraph/operation.hpp
View file @
a1952a8d
...
...
@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
...
...
@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
MIGRAPH_THROW
(
"Not computable: "
+
x
.
name
());
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
)
,
output_shape
,
input
);
return
compute
_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
/*
...
...
src/include/migraph/operators.hpp
View file @
a1952a8d
...
...
@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
convolution
...
...
@@ -115,11 +110,6 @@ struct convolution
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
...
...
@@ -169,11 +159,6 @@ struct im2col
auto
channels_col
=
kernel_height
*
kernel_width
*
input_channels
;
return
{
input
.
type
(),
{
output_height
*
output_width
,
channels_col
}};
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
pooling
...
...
@@ -211,11 +196,6 @@ struct pooling
}};
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
...
...
@@ -236,11 +216,6 @@ struct activation
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
...
@@ -305,10 +280,6 @@ struct contiguous
}
return
{
t
,
lens
};
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
reshape
...
...
@@ -349,12 +320,10 @@ struct reshape
MIGRAPH_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
...
...
@@ -382,11 +351,6 @@ struct gemm
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gemm
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
...
...
@@ -402,10 +366,6 @@ struct unary
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
identity
:
unary
...
...
@@ -553,10 +513,6 @@ struct binary
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
struct
add
:
binary
...
...
src/include/migraph/ranges.hpp
View file @
a1952a8d
...
...
@@ -3,19 +3,10 @@
#include <algorithm>
#include <initializer_list>
#include <migraph/rank.hpp>
namespace
migraph
{
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
namespace
detail
{
template
<
class
String
,
class
T
>
...
...
src/include/migraph/rank.hpp
0 → 100644
View file @
a1952a8d
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
namespace
migraph
{
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
}
// namespace migraph
#endif
tools/include/operation.hpp
View file @
a1952a8d
...
...
@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
...
...
@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
MIGRAPH_THROW
(
"Not computable: "
+
x
.
name
());
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
)
,
output_shape
,
input
);
return
compute
_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
<%
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment