Commit 84f88fcc authored by Paul's avatar Paul
Browse files

Refactor header files

parent 1fca08da
add_library(rtg
generate.cpp
program.cpp
shape.cpp
)
......
#include <rtg/generate.hpp>
namespace rtg {
rtg::argument generate_argument(rtg::shape s, std::mt19937::result_type seed)
{
rtg::argument result;
s.visit_type([&](auto as)
{
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
});
return result;
}
} // namespace rtg
#ifndef RTG_GUARD_RTGLIB_GENERATE_HPP
#define RTG_GUARD_RTGLIB_GENERATE_HPP
#include <rtg/argument.hpp>
#include <random>
namespace rtg {
template<class T>
std::vector<T> generate_tensor_data(rtg::shape s, std::mt19937::result_type seed = 0)
{
std::vector<T> result(s.elements());
std::mt19937 engine{seed};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument generate_argument(rtg::shape s, std::mt19937::result_type seed = 0);
} // namespace rtg
#endif
#ifndef RTG_GUARD_RTGLIB_RANGES_HPP
#define RTG_GUARD_RTGLIB_RANGES_HPP
namespace rtg {
template <class C, class T>
bool contains(C&& c, T&& x)
{
return c.find(x) != c.end();
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
std::copy(r.begin(), r.end(), it);
}
} // namespace rtg
#endif
......@@ -10,6 +10,7 @@
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
#include <rtg/ranges.hpp>
namespace rtg {
......@@ -32,18 +33,6 @@ struct unknown
}
};
template <class C, class T>
bool contains(C&& c, T&& x)
{
return c.find(x) != c.end();
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
std::copy(r.begin(), r.end(), it);
}
struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
......
......@@ -2,23 +2,7 @@
#include <rtg/onnx.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
// TODO: Move this to a seperate header
std::vector<float> get_tensor_data(rtg::shape s)
{
std::vector<float> result(s.elements());
std::mt19937 engine{0};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument get_tensor_argument(rtg::shape s)
{
auto v = get_tensor_data(s);
return {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
}
#include <rtg/generate.hpp>
int main(int argc, char const* argv[])
{
......@@ -28,7 +12,7 @@ int main(int argc, char const* argv[])
auto prog = rtg::parse_onnx(file);
prog.compile(rtg::cpu::cpu_target{});
auto s = prog.get_parameter_shape("Input3");
auto input3 = get_tensor_argument(s);
auto input3 = generate_argument(s);
auto out = prog.eval({{"Input3", input3}});
(void)out;
std::cout << prog << std::endl;
......
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
#include <rtg/generate.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/miopen/miopen_target.hpp>
#include <rtg/manage_ptr.hpp>
#include <miopen/miopen.h>
#include <random>
#include "test.hpp"
#include "verify.hpp"
......@@ -64,24 +63,10 @@ rtg::program create_program()
return p;
}
std::vector<float> get_tensor_data(rtg::shape s)
{
std::vector<float> result(s.elements());
std::mt19937 engine{0};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
rtg::argument get_tensor_argument_cpu(rtg::shape s)
{
auto v = get_tensor_data(s);
return {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
}
// TODO: Move to header
rtg::argument get_tensor_argument_gpu(rtg::shape s)
{
auto v = get_tensor_data(s);
auto v = rtg::generate_tensor_data<float>(s);
auto p = rtg::share(write(v));
return {s, [p]() mutable { return reinterpret_cast<char*>(p.get()); }};
}
......@@ -90,8 +75,8 @@ std::vector<float> cpu()
{
std::vector<float> result;
auto p = create_program();
auto x = get_tensor_argument_cpu({rtg::shape::float_type, {4, 3, 3, 3}});
auto w = get_tensor_argument_cpu({rtg::shape::float_type, {4, 3, 3, 3}});
auto x = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
auto w = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
p.compile(rtg::cpu::cpu_target{});
auto r = p.eval({{"x", x}, {"w", w}});
auto output = r.get<float>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment