Commit cfbafecb authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'lenet-test' of https://github.com/ROCmSoftwarePlatform/RTGLib into lenet-test

parents 92051ab8 7359bd4d
......@@ -8,4 +8,4 @@ target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx onnx-proto rtg)
target_link_libraries(read_onnx onnx-proto rtg rtg_cpu)
......@@ -12,6 +12,9 @@
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
struct unknown
{
std::string op;
......@@ -227,6 +230,13 @@ struct onnx_parser
return result;
}
template <class T>
static rtg::literal from_repeated(rtg::shape::type_t t, const T& r)
{
std::size_t size = r.size();
return rtg::literal{{t, {size}}, r.begin(), r.end()};
}
static rtg::literal parse_value(const onnx::AttributeProto& attr)
{
switch(attr.type())
......@@ -238,10 +248,8 @@ struct onnx_parser
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS:
return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()};
case onnx::AttributeProto::INTS:
return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};
;
return from_repeated(rtg::shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(rtg::shape::int64_type, attr.ints());
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
......@@ -329,6 +337,22 @@ struct onnx_parser
}
};
// TODO: Move this to a seperate header
std::vector<float> get_tensor_data(rtg::shape s)
{
std::vector<float> result(s.elements());
std::mt19937 engine{0};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument get_tensor_argument(rtg::shape s)
{
auto v = get_tensor_data(s);
return {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
}
int main(int argc, char const* argv[])
{
if(argc > 1)
......@@ -339,6 +363,11 @@ int main(int argc, char const* argv[])
try
{
parser.parse_from(input);
parser.prog.compile(rtg::cpu::cpu_target{});
auto s = parser.prog.get_parameter_shape("Input3");
auto input3 = get_tensor_argument(s);
auto out = parser.prog.eval({{"Input3", input3}});
(void)out;
}
catch(...)
{
......
......@@ -586,6 +586,10 @@ struct cpu_apply
{
apply_activation(it);
}
else if(it->op.name() == "pooling")
{
apply_pooling(it);
}
else if(apply_map.count(it->op.name()) > 0)
{
apply_map.at(it->op.name())(it);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment