Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5b21a77f
Commit
5b21a77f
authored
May 21, 2018
by
Paul
Browse files
Add check_shapes helper class
parent
e15f5d2a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
16 deletions
+72
-16
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+72
-16
No files found.
src/include/rtg/operators.hpp
View file @
5b21a77f
...
@@ -8,6 +8,73 @@
...
@@ -8,6 +8,73 @@
namespace
rtg
{
namespace
rtg
{
struct
check_shapes
{
const
std
::
vector
<
shape
>*
shapes
;
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
shapes
(
&
s
)
{}
const
check_shapes
&
has
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
size
()
!=
n
)
RTG_THROW
(
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
shapes
->
size
()));
return
*
this
;
}
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
!
shapes
->
empty
())
{
if
(
shapes
->
front
().
lens
().
size
()
!=
n
)
RTG_THROW
(
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
}
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
RTG_THROW
(
"Shapes do not match"
);
return
*
this
;
}
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
RTG_THROW
(
"Types do not match"
);
return
*
this
;
}
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
();
}))
RTG_THROW
(
"Dimensions do not match"
);
return
*
this
;
}
template
<
class
F
>
bool
same
(
F
f
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
empty
())
return
true
;
auto
&&
key
=
f
(
shapes
->
front
());
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
template
<
class
Predicate
>
bool
all_of
(
Predicate
p
)
const
{
assert
(
shapes
!=
nullptr
);
return
std
::
all_of
(
shapes
->
begin
(),
shapes
->
end
(),
p
);
}
};
struct
not_computable
struct
not_computable
{
{
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
...
@@ -21,17 +88,10 @@ struct convolution
...
@@ -21,17 +88,10 @@ struct convolution
std
::
string
name
()
const
{
return
"convolution"
;
}
std
::
string
name
()
const
{
return
"convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
size
()
!=
2
)
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
().
only_dims
(
4
);
RTG_THROW
(
"Wrong number of arguments"
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
input
.
type
()
!=
weights
.
type
())
RTG_THROW
(
"Type doesn't match"
);
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
RTG_THROW
(
"Dimensions don't match"
);
if
(
input
.
lens
().
size
()
!=
4
)
RTG_THROW
(
"Only 4d convolution supported"
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
return
{
t
,
{
{
...
@@ -74,12 +134,9 @@ struct pooling
...
@@ -74,12 +134,9 @@ struct pooling
std
::
string
name
()
const
{
return
"pooling"
;
}
std
::
string
name
()
const
{
return
"pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
RTG_THROW
(
"Wrong number of arguments"
);
const
shape
&
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
!=
4
)
RTG_THROW
(
"Only 4d pooling supported"
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
return
{
t
,
{
{
...
@@ -117,8 +174,7 @@ struct activation
...
@@ -117,8 +174,7 @@ struct activation
std
::
string
name
()
const
{
return
"activation"
;
}
std
::
string
name
()
const
{
return
"activation"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
check_shapes
{
inputs
}.
has
(
1
);
RTG_THROW
(
"Wrong number of arguments"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment