Commit 09c946b2 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added MNIST driver

parent 003e2eb9
...@@ -29,12 +29,14 @@ struct tensor_view ...@@ -29,12 +29,14 @@ struct tensor_view
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
......
...@@ -136,7 +136,8 @@ int main(int argc, char const* argv[]) ...@@ -136,7 +136,8 @@ int main(int argc, char const* argv[])
std::string file = argv[1]; std::string file = argv[1];
auto prog = rtg::parse_onnx(file); auto prog = rtg::parse_onnx(file);
prog.compile(rtg::cpu::cpu_target{}); prog.compile(rtg::cpu::cpu_target{});
auto s = prog.get_parameter_shape("Input3"); // auto s = prog.get_parameter_shape("Input3");
auto s = rtg::shape{rtg::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl; std::cout << s << std::endl;
auto input3 = rtg::argument{s, input.data()}; auto input3 = rtg::argument{s, input.data()};
auto out = prog.eval({{"Input3", input3}}); auto out = prog.eval({{"Input3", input3}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment