Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b28bd72d
Commit
b28bd72d
authored
Aug 16, 2018
by
Paul
Browse files
Formatting
parent
42a952cb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
38 deletions
+40
-38
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+5
-5
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+34
-32
No files found.
src/include/migraph/operators.hpp
View file @
b28bd72d
...
@@ -429,14 +429,14 @@ struct flatten
...
@@ -429,14 +429,14 @@ struct flatten
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
>
lens
.
size
())
if
(
axis
>
lens
.
size
())
{
{
MIGRAPH_THROW
(
"axis for flatten must be less than tensor rank"
);
MIGRAPH_THROW
(
"axis for flatten must be less than tensor rank"
);
}
}
auto
x
=
std
::
accumulate
(
auto
x
=
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
auto
y
=
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
test/cpu_ops_test.cpp
View file @
b28bd72d
...
@@ -604,7 +604,7 @@ void transpose_test()
...
@@ -604,7 +604,7 @@ void transpose_test()
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
std
::
vector
<
size_t
>
new_lens
=
{
1
,
3
,
2
,
2
};
std
::
vector
<
size_t
>
new_lens
=
{
1
,
3
,
2
,
2
};
EXPECT
(
bool
{
output
.
get_shape
().
lens
()
==
new_lens
});
EXPECT
(
bool
{
output
.
get_shape
().
lens
()
==
new_lens
});
});
});
}
}
...
...
test/op_shape_test.cpp
View file @
b28bd72d
...
@@ -5,51 +5,54 @@
...
@@ -5,51 +5,54 @@
#include <sstream>
#include <sstream>
#include "test.hpp"
#include "test.hpp"
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
expect_shape
(
migraph
::
shape
expected
,
migraph
::
operation
op
,
Ts
...
xs
)
void
expect_shape
(
migraph
::
shape
expected
,
migraph
::
operation
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
instruction_ref
>
args
;
std
::
vector
<
migraph
::
instruction_ref
>
args
;
for
(
auto
&&
s
:
shapes
)
for
(
auto
&&
s
:
shapes
)
args
.
push_back
(
p
.
add_outline
(
s
));
args
.
push_back
(
p
.
add_outline
(
s
));
p
.
add_instruction
(
op
,
args
);
p
.
add_instruction
(
op
,
args
);
if
(
p
.
get_shape
()
!=
expected
)
{
if
(
p
.
get_shape
()
!=
expected
)
{
std
::
cout
<<
"FAILED: Incorrect shape for "
<<
op
.
name
()
<<
": "
;
std
::
cout
<<
"FAILED: Incorrect shape for "
<<
op
.
name
()
<<
": "
;
std
::
cout
<<
expected
<<
" != "
<<
p
.
get_shape
()
<<
std
::
endl
;
std
::
cout
<<
expected
<<
" != "
<<
p
.
get_shape
()
<<
std
::
endl
;
for
(
auto
&&
s
:
shapes
)
for
(
auto
&&
s
:
shapes
)
std
::
cout
<<
" "
<<
s
<<
std
::
endl
;
std
::
cout
<<
" "
<<
s
<<
std
::
endl
;
}
}
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
operation
op
,
Ts
...
xs
)
void
throws_shape
(
migraph
::
operation
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
instruction_ref
>
args
;
std
::
vector
<
migraph
::
instruction_ref
>
args
;
for
(
auto
&&
s
:
shapes
)
for
(
auto
&&
s
:
shapes
)
args
.
push_back
(
p
.
add_outline
(
s
));
args
.
push_back
(
p
.
add_outline
(
s
));
bool
thrown
=
test
::
throws
([
&
]
{
p
.
add_instruction
(
op
,
args
);
});
bool
thrown
=
test
::
throws
([
&
]
{
p
.
add_instruction
(
op
,
args
);
});
if
(
not
thrown
)
{
if
(
not
thrown
)
{
std
::
cout
<<
"FAILED: No error found for "
<<
op
.
name
()
<<
": "
;
std
::
cout
<<
"FAILED: No error found for "
<<
op
.
name
()
<<
": "
;
for
(
auto
&&
s
:
shapes
)
for
(
auto
&&
s
:
shapes
)
std
::
cout
<<
" "
<<
s
<<
std
::
endl
;
std
::
cout
<<
" "
<<
s
<<
std
::
endl
;
}
}
}
}
template
<
class
...
>
template
<
class
...
>
struct
always_false
struct
always_false
:
std
::
false_type
:
std
::
false_type
{
{
};
};
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
shape
,
Ts
...)
void
throws_shape
(
migraph
::
shape
,
Ts
...)
{
{
static_assert
(
always_false
<
Ts
...
>
{},
"An expected shape should not be passed to throws_shape function"
);
static_assert
(
always_false
<
Ts
...
>
{},
"An expected shape should not be passed to throws_shape function"
);
}
}
void
batch_norm_inference_shape
()
void
batch_norm_inference_shape
()
{
{
const
size_t
channels
=
3
;
const
size_t
channels
=
3
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
channels
,
3
,
3
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
channels
,
3
,
3
}};
...
@@ -59,7 +62,7 @@ void batch_norm_inference_shape()
...
@@ -59,7 +62,7 @@ void batch_norm_inference_shape()
throws_shape
(
migraph
::
batch_norm_inference
{},
s
,
vars
,
vars
,
vars
,
vars
,
vars
);
throws_shape
(
migraph
::
batch_norm_inference
{},
s
,
vars
,
vars
,
vars
,
vars
,
vars
);
}
}
void
convolution_shape
()
void
convolution_shape
()
{
{
migraph
::
shape
output
{
migraph
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraph
::
shape
output
{
migraph
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
...
@@ -88,7 +91,7 @@ void contiguous_shape()
...
@@ -88,7 +91,7 @@ void contiguous_shape()
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
expect_shape
(
output
,
migraph
::
contiguous
{},
input
);
expect_shape
(
output
,
migraph
::
contiguous
{},
input
);
throws_shape
(
migraph
::
contiguous
{},
input
,
input
);
throws_shape
(
migraph
::
contiguous
{},
input
,
input
);
migraph
::
shape
single
{
migraph
::
shape
::
float_type
,
{
2
}};
migraph
::
shape
single
{
migraph
::
shape
::
float_type
,
{
2
}};
throws_shape
(
migraph
::
contiguous
{},
single
);
throws_shape
(
migraph
::
contiguous
{},
single
);
}
}
...
@@ -96,11 +99,8 @@ void contiguous_shape()
...
@@ -96,11 +99,8 @@ void contiguous_shape()
void
reshape_shape
()
void
reshape_shape
()
{
{
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{
for
(
auto
&&
new_shape
:
{
8
,
3
,
1
,
1
},
std
::
vector
<
std
::
vector
<
int64_t
>>
{{
8
,
3
,
1
,
1
},
{
1
,
3
,
4
,
2
},
{
1
,
3
,
4
,
2
}})
{
1
,
3
,
4
,
2
},
{
1
,
3
,
4
,
2
}
})
{
{
std
::
vector
<
std
::
size_t
>
lens
(
new_shape
.
size
());
std
::
vector
<
std
::
size_t
>
lens
(
new_shape
.
size
());
std
::
copy
(
new_shape
.
begin
(),
new_shape
.
end
(),
lens
.
begin
());
std
::
copy
(
new_shape
.
begin
(),
new_shape
.
end
(),
lens
.
begin
());
...
@@ -108,10 +108,7 @@ void reshape_shape()
...
@@ -108,10 +108,7 @@ void reshape_shape()
expect_shape
(
output
,
migraph
::
reshape
{
new_shape
},
input
);
expect_shape
(
output
,
migraph
::
reshape
{
new_shape
},
input
);
}
}
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{{
8
,
3
,
2
,
2
},
{
1
,
3
,
-
1
,
-
1
}})
{
8
,
3
,
2
,
2
},
{
1
,
3
,
-
1
,
-
1
}
})
{
{
throws_shape
(
migraph
::
reshape
{
new_shape
},
input
);
throws_shape
(
migraph
::
reshape
{
new_shape
},
input
);
}
}
...
@@ -120,15 +117,20 @@ void reshape_shape()
...
@@ -120,15 +117,20 @@ void reshape_shape()
void
flatten_shape
()
void
flatten_shape
()
{
{
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
4
,
6
,
8
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
4
,
6
,
8
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
*
4
*
6
*
8
}},
migraph
::
flatten
{
0
},
input
);
expect_shape
(
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
4
*
6
*
8
}},
migraph
::
flatten
{
1
},
input
);
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
*
4
*
6
*
8
}},
migraph
::
flatten
{
0
},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
,
6
*
8
}},
migraph
::
flatten
{
2
},
input
);
expect_shape
(
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
*
6
,
8
}},
migraph
::
flatten
{
3
},
input
);
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
4
*
6
*
8
}},
migraph
::
flatten
{
1
},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
*
6
*
8
,
1
}},
migraph
::
flatten
{
4
},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
,
6
*
8
}},
migraph
::
flatten
{
2
},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
*
6
,
8
}},
migraph
::
flatten
{
3
},
input
);
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
*
4
*
6
*
8
,
1
}},
migraph
::
flatten
{
4
},
input
);
throws_shape
(
migraph
::
flatten
{
5
},
input
);
throws_shape
(
migraph
::
flatten
{
5
},
input
);
}
}
int
main
()
int
main
()
{
{
batch_norm_inference_shape
();
batch_norm_inference_shape
();
convolution_shape
();
convolution_shape
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment