Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bff0223b
Commit
bff0223b
authored
Jun 22, 2018
by
Scott Thornton
Browse files
Added test for MNIST
parent
79fe7d41
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
160 additions
and
6 deletions
+160
-6
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+4
-0
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+146
-0
src/program.cpp
src/program.cpp
+4
-0
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+6
-6
No files found.
src/onnx/CMakeLists.txt
View file @
bff0223b
...
@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg)
...
@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg)
add_executable
(
read_onnx read_onnx.cpp
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx rtg_onnx rtg_cpu
)
target_link_libraries
(
read_onnx rtg_onnx rtg_cpu
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist rtg_onnx rtg_cpu
)
src/onnx/mnist.cpp
0 → 100644
View file @
bff0223b
#include <cstdio>
#include <string>
#include <fstream>
#include <stdexcept>
#include <rtg/onnx.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/generate.hpp>
std
::
vector
<
float
>
read_mnist_images
(
std
::
string
full_path
,
int
&
number_of_images
,
int
&
image_size
)
{
auto
reverseInt
=
[](
int
i
)
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
c1
=
i
&
255
;
c2
=
(
i
>>
8
)
&
255
;
c3
=
(
i
>>
16
)
&
255
;
c4
=
(
i
>>
24
)
&
255
;
return
(
static_cast
<
int
>
(
c1
)
<<
24
)
+
(
static_cast
<
int
>
(
c2
)
<<
16
)
+
(
static_cast
<
int
>
(
c3
)
<<
8
)
+
c4
;
};
typedef
unsigned
char
uchar
;
std
::
ifstream
file
(
full_path
,
std
::
ios
::
binary
);
if
(
file
.
is_open
())
{
int
magic_number
=
0
,
n_rows
=
0
,
n_cols
=
0
;
file
.
read
((
char
*
)
&
magic_number
,
sizeof
(
magic_number
));
magic_number
=
reverseInt
(
magic_number
);
if
(
magic_number
!=
2051
)
throw
std
::
runtime_error
(
"Invalid MNIST image file!"
);
file
.
read
((
char
*
)
&
number_of_images
,
sizeof
(
number_of_images
)),
number_of_images
=
reverseInt
(
number_of_images
);
file
.
read
((
char
*
)
&
n_rows
,
sizeof
(
n_rows
)),
n_rows
=
reverseInt
(
n_rows
);
file
.
read
((
char
*
)
&
n_cols
,
sizeof
(
n_cols
)),
n_cols
=
reverseInt
(
n_cols
);
image_size
=
n_rows
*
n_cols
;
printf
(
"n_rows: %d n_cols: %d image_size: %d
\n\n
"
,
n_rows
,
n_cols
,
image_size
);
// uchar** _dataset = new uchar*[number_of_images];
// for(int i = 0; i < number_of_images; i++) {
// _dataset[i] = new uchar[image_size];
// file.read((char *)_dataset[i], image_size);
// }
std
::
vector
<
float
>
result
(
number_of_images
*
image_size
);
for
(
int
i
=
0
;
i
<
number_of_images
;
i
++
)
{
for
(
int
j
=
0
;
j
<
image_size
;
j
++
)
{
uchar
tmp
;
file
.
read
((
char
*
)
&
tmp
,
1
);
result
[
i
*
image_size
+
j
]
=
tmp
/
255.0
;
}
}
return
result
;
}
else
{
throw
std
::
runtime_error
(
"Cannot open file `"
+
full_path
+
"`!"
);
}
}
std
::
vector
<
int32_t
>
read_mnist_labels
(
std
::
string
full_path
,
int
&
number_of_labels
)
{
auto
reverseInt
=
[](
int
i
)
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
c1
=
i
&
255
;
c2
=
(
i
>>
8
)
&
255
;
c3
=
(
i
>>
16
)
&
255
;
c4
=
(
i
>>
24
)
&
255
;
return
(
static_cast
<
int
>
(
c1
)
<<
24
)
+
(
static_cast
<
int
>
(
c2
)
<<
16
)
+
(
static_cast
<
int
>
(
c3
)
<<
8
)
+
c4
;
};
typedef
unsigned
char
uchar
;
std
::
ifstream
file
(
full_path
,
std
::
ios
::
binary
);
if
(
file
.
is_open
())
{
int
magic_number
=
0
;
file
.
read
((
char
*
)
&
magic_number
,
sizeof
(
magic_number
));
magic_number
=
reverseInt
(
magic_number
);
if
(
magic_number
!=
2049
)
throw
std
::
runtime_error
(
"Invalid MNIST label file!"
);
file
.
read
((
char
*
)
&
number_of_labels
,
sizeof
(
number_of_labels
)),
number_of_labels
=
reverseInt
(
number_of_labels
);
std
::
vector
<
int32_t
>
result
(
number_of_labels
);
for
(
int
i
=
0
;
i
<
number_of_labels
;
i
++
)
{
uchar
tmp
;
file
.
read
((
char
*
)
&
tmp
,
1
);
result
[
i
]
=
tmp
;
}
return
result
;
}
else
{
throw
std
::
runtime_error
(
"Unable to open file `"
+
full_path
+
"`!"
);
}
}
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
>
1
)
{
std
::
string
datafile
=
argv
[
2
];
std
::
string
labelfile
=
argv
[
3
];
int
nimages
=
-
1
;
int
image_size
=
-
1
;
int
nlabels
=
-
1
;
std
::
vector
<
float
>
input
=
read_mnist_images
(
datafile
,
nimages
,
image_size
);
std
::
vector
<
int32_t
>
labels
=
read_mnist_labels
(
labelfile
,
nlabels
);
printf
(
"label: %d
\n\n
"
,
labels
[
0
]);
for
(
int
i
=
7
;
i
<
9
;
i
++
)
{
for
(
int
j
=
0
;
j
<
28
;
j
++
)
{
printf
(
"%8.5f "
,
input
[
i
*
28
+
j
]);
}
printf
(
"
\n
"
);
}
std
::
string
file
=
argv
[
1
];
auto
prog
=
rtg
::
parse_onnx
(
file
);
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
std
::
cout
<<
s
<<
std
::
endl
;
auto
input3
=
rtg
::
argument
{
s
,
input
.
data
()};
auto
out
=
prog
.
eval
({{
"Input3"
,
input3
}});
std
::
cout
<<
out
<<
std
::
endl
;
std
::
cout
<<
prog
<<
std
::
endl
;
}
}
src/program.cpp
View file @
bff0223b
...
@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
std
::
cout
<<
"Debug: "
<<
ins
.
op
.
name
()
<<
"
\n
"
;
if
(
result
.
get_shape
().
elements
()
>
0
and
result
.
get_shape
().
packed
()
and
std
::
isnan
(
result
.
at
<
float
>
()))
std
::
cout
<<
"Nan: "
<<
ins
.
op
.
name
()
<<
std
::
endl
;
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
bff0223b
...
@@ -60,7 +60,11 @@ struct max_pool
...
@@ -60,7 +60,11 @@ struct max_pool
static
std
::
string
name
()
{
return
"max"
;
}
static
std
::
string
name
()
{
return
"max"
;
}
static
double
start
()
{
return
std
::
numeric_limits
<
double
>::
lowest
();
}
static
double
start
()
{
return
std
::
numeric_limits
<
double
>::
lowest
();
}
static
double
apply
(
double
x
,
double
y
)
{
return
x
+
y
;
}
static
double
apply
(
double
x
,
double
y
)
{
double
m
=
std
::
max
(
x
,
y
);
return
(
m
);
}
static
double
final
(
double
x
,
double
)
{
return
(
x
);
}
static
double
final
(
double
x
,
double
)
{
return
(
x
);
}
};
};
...
@@ -70,11 +74,7 @@ struct avg_pool
...
@@ -70,11 +74,7 @@ struct avg_pool
static
std
::
string
name
()
{
return
"average"
;
}
static
std
::
string
name
()
{
return
"average"
;
}
static
double
start
()
{
return
0.0
;
}
static
double
start
()
{
return
0.0
;
}
static
double
apply
(
double
x
,
double
y
)
static
double
apply
(
double
x
,
double
y
)
{
return
x
+
y
;
}
{
double
m
=
std
::
max
(
x
,
y
);
return
(
m
);
}
static
double
final
(
double
x
,
double
y
)
{
return
x
/
y
;
}
static
double
final
(
double
x
,
double
y
)
{
return
x
/
y
;
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment