"host/driver_offline/src/gemm_driver_offline.cpp" did not exist on "627d8ef35a6da8ad268b5197e3045ccdfb4ac684"
Commit cfbafecb authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'lenet-test' of https://github.com/ROCmSoftwarePlatform/RTGLib into lenet-test

parents 92051ab8 7359bd4d
...@@ -8,4 +8,4 @@ target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) ...@@ -8,4 +8,4 @@ target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
add_executable(read_onnx read_onnx.cpp) add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx onnx-proto rtg) target_link_libraries(read_onnx onnx-proto rtg rtg_cpu)
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
struct unknown struct unknown
{ {
std::string op; std::string op;
...@@ -227,6 +230,13 @@ struct onnx_parser ...@@ -227,6 +230,13 @@ struct onnx_parser
return result; return result;
} }
template <class T>
static rtg::literal from_repeated(rtg::shape::type_t t, const T& r)
{
std::size_t size = r.size();
return rtg::literal{{t, {size}}, r.begin(), r.end()};
}
static rtg::literal parse_value(const onnx::AttributeProto& attr) static rtg::literal parse_value(const onnx::AttributeProto& attr)
{ {
switch(attr.type()) switch(attr.type())
...@@ -238,10 +248,8 @@ struct onnx_parser ...@@ -238,10 +248,8 @@ struct onnx_parser
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t()); case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {}; case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: case onnx::AttributeProto::FLOATS:
return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()}; return from_repeated(rtg::shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: case onnx::AttributeProto::INTS: return from_repeated(rtg::shape::int64_type, attr.ints());
return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};
;
case onnx::AttributeProto::STRINGS: return {}; case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {}; case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
...@@ -329,6 +337,22 @@ struct onnx_parser ...@@ -329,6 +337,22 @@ struct onnx_parser
} }
}; };
// TODO: Move this to a seperate header
std::vector<float> get_tensor_data(rtg::shape s)
{
std::vector<float> result(s.elements());
std::mt19937 engine{0};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument get_tensor_argument(rtg::shape s)
{
auto v = get_tensor_data(s);
return {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
}
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
...@@ -339,6 +363,11 @@ int main(int argc, char const* argv[]) ...@@ -339,6 +363,11 @@ int main(int argc, char const* argv[])
try try
{ {
parser.parse_from(input); parser.parse_from(input);
parser.prog.compile(rtg::cpu::cpu_target{});
auto s = parser.prog.get_parameter_shape("Input3");
auto input3 = get_tensor_argument(s);
auto out = parser.prog.eval({{"Input3", input3}});
(void)out;
} }
catch(...) catch(...)
{ {
......
...@@ -586,6 +586,10 @@ struct cpu_apply ...@@ -586,6 +586,10 @@ struct cpu_apply
{ {
apply_activation(it); apply_activation(it);
} }
else if(it->op.name() == "pooling")
{
apply_pooling(it);
}
else if(apply_map.count(it->op.name()) > 0) else if(apply_map.count(it->op.name()) > 0)
{ {
apply_map.at(it->op.name())(it); apply_map.at(it->op.name())(it);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment