Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a1952a8d
Commit
a1952a8d
authored
Sep 15, 2018
by
Paul
Browse files
Make compute optional
parent
f550da30
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
56 deletions
+55
-56
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+19
-1
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+0
-44
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+1
-10
src/include/migraph/rank.hpp
src/include/migraph/rank.hpp
+18
-0
tools/include/operation.hpp
tools/include/operation.hpp
+17
-1
No files found.
src/include/migraph/operation.hpp
View file @
a1952a8d
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
#include <migraph/auto_any_cast.hpp>
...
@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
MIGRAPH_THROW
(
"Not computable: "
+
x
.
name
());
}
template
<
class
T
>
template
<
class
T
>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
)
,
output_shape
,
input
);
return
compute
_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
}
/*
/*
...
...
src/include/migraph/operators.hpp
View file @
a1952a8d
...
@@ -41,11 +41,6 @@ struct batch_norm_inference
...
@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
};
struct
convolution
struct
convolution
...
@@ -115,11 +110,6 @@ struct convolution
...
@@ -115,11 +110,6 @@ struct convolution
}
}
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
...
@@ -169,11 +159,6 @@ struct im2col
...
@@ -169,11 +159,6 @@ struct im2col
auto
channels_col
=
kernel_height
*
kernel_width
*
input_channels
;
auto
channels_col
=
kernel_height
*
kernel_width
*
input_channels
;
return
{
input
.
type
(),
{
output_height
*
output_width
,
channels_col
}};
return
{
input
.
type
(),
{
output_height
*
output_width
,
channels_col
}};
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
};
struct
pooling
struct
pooling
...
@@ -211,11 +196,6 @@ struct pooling
...
@@ -211,11 +196,6 @@ struct pooling
}};
}};
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
...
@@ -236,11 +216,6 @@ struct activation
...
@@ -236,11 +216,6 @@ struct activation
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
@@ -305,10 +280,6 @@ struct contiguous
...
@@ -305,10 +280,6 @@ struct contiguous
}
}
return
{
t
,
lens
};
return
{
t
,
lens
};
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
};
struct
reshape
struct
reshape
...
@@ -349,12 +320,10 @@ struct reshape
...
@@ -349,12 +320,10 @@ struct reshape
MIGRAPH_THROW
(
"Wrong number of elements for reshape"
);
MIGRAPH_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
...
@@ -382,11 +351,6 @@ struct gemm
...
@@ -382,11 +351,6 @@ struct gemm
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gemm
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gemm
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
...
@@ -402,10 +366,6 @@ struct unary
...
@@ -402,10 +366,6 @@ struct unary
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
};
struct
identity
:
unary
struct
identity
:
unary
...
@@ -553,10 +513,6 @@ struct binary
...
@@ -553,10 +513,6 @@ struct binary
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
};
};
struct
add
:
binary
struct
add
:
binary
...
...
src/include/migraph/ranges.hpp
View file @
a1952a8d
...
@@ -3,19 +3,10 @@
...
@@ -3,19 +3,10 @@
#include <algorithm>
#include <algorithm>
#include <initializer_list>
#include <initializer_list>
#include <migraph/rank.hpp>
namespace
migraph
{
namespace
migraph
{
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
namespace
detail
{
namespace
detail
{
template
<
class
String
,
class
T
>
template
<
class
String
,
class
T
>
...
...
src/include/migraph/rank.hpp
0 → 100644
View file @
a1952a8d
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
namespace
migraph
{
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
{
};
template
<
>
struct
rank
<
0
>
{
};
}
// namespace migraph
#endif
tools/include/operation.hpp
View file @
a1952a8d
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
#include <migraph/auto_any_cast.hpp>
...
@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
MIGRAPH_THROW
(
"Not computable: "
+
x
.
name
());
}
template
<
class
T
>
template
<
class
T
>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
)
,
output_shape
,
input
);
return
compute
_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
}
<%
<%
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment