Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bff0223b
Commit
bff0223b
authored
Jun 22, 2018
by
Scott Thornton
Browse files
Added test for MNIST
parent
79fe7d41
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
160 additions
and
6 deletions
+160
-6
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+4
-0
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+146
-0
src/program.cpp
src/program.cpp
+4
-0
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+6
-6
No files found.
src/onnx/CMakeLists.txt
View file @
bff0223b
...
@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg)
...
@@ -13,3 +13,7 @@ target_link_libraries(rtg_onnx onnx-proto rtg)
add_executable
(
read_onnx read_onnx.cpp
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx rtg_onnx rtg_cpu
)
target_link_libraries
(
read_onnx rtg_onnx rtg_cpu
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist rtg_onnx rtg_cpu
)
src/onnx/mnist.cpp
0 → 100644
View file @
bff0223b
#include <cstdio>
#include <string>
#include <fstream>
#include <stdexcept>
#include <rtg/onnx.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/generate.hpp>
std
::
vector
<
float
>
read_mnist_images
(
std
::
string
full_path
,
int
&
number_of_images
,
int
&
image_size
)
{
auto
reverseInt
=
[](
int
i
)
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
c1
=
i
&
255
;
c2
=
(
i
>>
8
)
&
255
;
c3
=
(
i
>>
16
)
&
255
;
c4
=
(
i
>>
24
)
&
255
;
return
(
static_cast
<
int
>
(
c1
)
<<
24
)
+
(
static_cast
<
int
>
(
c2
)
<<
16
)
+
(
static_cast
<
int
>
(
c3
)
<<
8
)
+
c4
;
};
typedef
unsigned
char
uchar
;
std
::
ifstream
file
(
full_path
,
std
::
ios
::
binary
);
if
(
file
.
is_open
())
{
int
magic_number
=
0
,
n_rows
=
0
,
n_cols
=
0
;
file
.
read
((
char
*
)
&
magic_number
,
sizeof
(
magic_number
));
magic_number
=
reverseInt
(
magic_number
);
if
(
magic_number
!=
2051
)
throw
std
::
runtime_error
(
"Invalid MNIST image file!"
);
file
.
read
((
char
*
)
&
number_of_images
,
sizeof
(
number_of_images
)),
number_of_images
=
reverseInt
(
number_of_images
);
file
.
read
((
char
*
)
&
n_rows
,
sizeof
(
n_rows
)),
n_rows
=
reverseInt
(
n_rows
);
file
.
read
((
char
*
)
&
n_cols
,
sizeof
(
n_cols
)),
n_cols
=
reverseInt
(
n_cols
);
image_size
=
n_rows
*
n_cols
;
printf
(
"n_rows: %d n_cols: %d image_size: %d
\n\n
"
,
n_rows
,
n_cols
,
image_size
);
// uchar** _dataset = new uchar*[number_of_images];
// for(int i = 0; i < number_of_images; i++) {
// _dataset[i] = new uchar[image_size];
// file.read((char *)_dataset[i], image_size);
// }
std
::
vector
<
float
>
result
(
number_of_images
*
image_size
);
for
(
int
i
=
0
;
i
<
number_of_images
;
i
++
)
{
for
(
int
j
=
0
;
j
<
image_size
;
j
++
)
{
uchar
tmp
;
file
.
read
((
char
*
)
&
tmp
,
1
);
result
[
i
*
image_size
+
j
]
=
tmp
/
255.0
;
}
}
return
result
;
}
else
{
throw
std
::
runtime_error
(
"Cannot open file `"
+
full_path
+
"`!"
);
}
}
std
::
vector
<
int32_t
>
read_mnist_labels
(
std
::
string
full_path
,
int
&
number_of_labels
)
{
auto
reverseInt
=
[](
int
i
)
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
c1
=
i
&
255
;
c2
=
(
i
>>
8
)
&
255
;
c3
=
(
i
>>
16
)
&
255
;
c4
=
(
i
>>
24
)
&
255
;
return
(
static_cast
<
int
>
(
c1
)
<<
24
)
+
(
static_cast
<
int
>
(
c2
)
<<
16
)
+
(
static_cast
<
int
>
(
c3
)
<<
8
)
+
c4
;
};
typedef
unsigned
char
uchar
;
std
::
ifstream
file
(
full_path
,
std
::
ios
::
binary
);
if
(
file
.
is_open
())
{
int
magic_number
=
0
;
file
.
read
((
char
*
)
&
magic_number
,
sizeof
(
magic_number
));
magic_number
=
reverseInt
(
magic_number
);
if
(
magic_number
!=
2049
)
throw
std
::
runtime_error
(
"Invalid MNIST label file!"
);
file
.
read
((
char
*
)
&
number_of_labels
,
sizeof
(
number_of_labels
)),
number_of_labels
=
reverseInt
(
number_of_labels
);
std
::
vector
<
int32_t
>
result
(
number_of_labels
);
for
(
int
i
=
0
;
i
<
number_of_labels
;
i
++
)
{
uchar
tmp
;
file
.
read
((
char
*
)
&
tmp
,
1
);
result
[
i
]
=
tmp
;
}
return
result
;
}
else
{
throw
std
::
runtime_error
(
"Unable to open file `"
+
full_path
+
"`!"
);
}
}
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
>
1
)
{
std
::
string
datafile
=
argv
[
2
];
std
::
string
labelfile
=
argv
[
3
];
int
nimages
=
-
1
;
int
image_size
=
-
1
;
int
nlabels
=
-
1
;
std
::
vector
<
float
>
input
=
read_mnist_images
(
datafile
,
nimages
,
image_size
);
std
::
vector
<
int32_t
>
labels
=
read_mnist_labels
(
labelfile
,
nlabels
);
printf
(
"label: %d
\n\n
"
,
labels
[
0
]);
for
(
int
i
=
7
;
i
<
9
;
i
++
)
{
for
(
int
j
=
0
;
j
<
28
;
j
++
)
{
printf
(
"%8.5f "
,
input
[
i
*
28
+
j
]);
}
printf
(
"
\n
"
);
}
std
::
string
file
=
argv
[
1
];
auto
prog
=
rtg
::
parse_onnx
(
file
);
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
std
::
cout
<<
s
<<
std
::
endl
;
auto
input3
=
rtg
::
argument
{
s
,
input
.
data
()};
auto
out
=
prog
.
eval
({{
"Input3"
,
input3
}});
std
::
cout
<<
out
<<
std
::
endl
;
std
::
cout
<<
prog
<<
std
::
endl
;
}
}
src/program.cpp
View file @
bff0223b
...
@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -141,6 +141,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
std
::
cout
<<
"Debug: "
<<
ins
.
op
.
name
()
<<
"
\n
"
;
if
(
result
.
get_shape
().
elements
()
>
0
and
result
.
get_shape
().
packed
()
and
std
::
isnan
(
result
.
at
<
float
>
()))
std
::
cout
<<
"Nan: "
<<
ins
.
op
.
name
()
<<
std
::
endl
;
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
bff0223b
...
@@ -60,7 +60,11 @@ struct max_pool
...
@@ -60,7 +60,11 @@ struct max_pool
static
std
::
string
name
()
{
return
"max"
;
}
static
std
::
string
name
()
{
return
"max"
;
}
static
double
start
()
{
return
std
::
numeric_limits
<
double
>::
lowest
();
}
static
double
start
()
{
return
std
::
numeric_limits
<
double
>::
lowest
();
}
static
double
apply
(
double
x
,
double
y
)
{
return
x
+
y
;
}
static
double
apply
(
double
x
,
double
y
)
{
double
m
=
std
::
max
(
x
,
y
);
return
(
m
);
}
static
double
final
(
double
x
,
double
)
{
return
(
x
);
}
static
double
final
(
double
x
,
double
)
{
return
(
x
);
}
};
};
...
@@ -70,11 +74,7 @@ struct avg_pool
...
@@ -70,11 +74,7 @@ struct avg_pool
static
std
::
string
name
()
{
return
"average"
;
}
static
std
::
string
name
()
{
return
"average"
;
}
static
double
start
()
{
return
0.0
;
}
static
double
start
()
{
return
0.0
;
}
static
double
apply
(
double
x
,
double
y
)
static
double
apply
(
double
x
,
double
y
)
{
return
x
+
y
;
}
{
double
m
=
std
::
max
(
x
,
y
);
return
(
m
);
}
static
double
final
(
double
x
,
double
y
)
{
return
x
/
y
;
}
static
double
final
(
double
x
,
double
y
)
{
return
x
/
y
;
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment