Commit 95e70f20 authored by Paul's avatar Paul
Browse files

Add pooling and run on the cpu

parent 9e19f2d7
...@@ -8,4 +8,4 @@ target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) ...@@ -8,4 +8,4 @@ target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
add_executable(read_onnx read_onnx.cpp) add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx onnx-proto rtg) target_link_libraries(read_onnx onnx-proto rtg rtg_cpu)
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
struct unknown struct unknown
{ {
std::string op; std::string op;
...@@ -334,6 +337,22 @@ struct onnx_parser ...@@ -334,6 +337,22 @@ struct onnx_parser
} }
}; };
// TODO: Move this to a seperate header
std::vector<float> get_tensor_data(rtg::shape s)
{
std::vector<float> result(s.elements());
std::mt19937 engine{0};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument get_tensor_argument(rtg::shape s)
{
auto v = get_tensor_data(s);
return {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
}
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
...@@ -344,6 +363,11 @@ int main(int argc, char const* argv[]) ...@@ -344,6 +363,11 @@ int main(int argc, char const* argv[])
try try
{ {
parser.parse_from(input); parser.parse_from(input);
parser.prog.compile(rtg::cpu::cpu_target{});
auto s = parser.prog.get_parameter_shape("Input3");
auto input3 = get_tensor_argument(s);
auto out = parser.prog.eval({{"Input3", input3}});
(void)out;
} }
catch(...) catch(...)
{ {
......
...@@ -586,6 +586,10 @@ struct cpu_apply ...@@ -586,6 +586,10 @@ struct cpu_apply
{ {
apply_activation(it); apply_activation(it);
} }
else if(it->op.name() == "pooling")
{
apply_pooling(it);
}
else if(apply_map.count(it->op.name()) > 0) else if(apply_map.count(it->op.name()) > 0)
{ {
apply_map.at(it->op.name())(it); apply_map.at(it->op.name())(it);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment