Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#ifndef MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP
#define MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
#define MIGRAPHX_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#else
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#endif
#else
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#endif
#if MIGRAPHX_HAS_FILESYSTEM
#include <filesystem>
#elif MIGRAPHX_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif MIGRAPHX_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/type_traits.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -19,7 +20,7 @@ using common_type = typename std::common_type<Ts...>::type; ...@@ -19,7 +20,7 @@ using common_type = typename std::common_type<Ts...>::type;
struct float_equal_fn struct float_equal_fn
{ {
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
static bool apply(T x, T y) static bool apply(T x, T y)
{ {
return std::isfinite(x) and std::isfinite(y) and return std::isfinite(x) and std::isfinite(y) and
...@@ -27,7 +28,7 @@ struct float_equal_fn ...@@ -27,7 +28,7 @@ struct float_equal_fn
std::nextafter(x, std::numeric_limits<T>::max()) >= y; std::nextafter(x, std::numeric_limits<T>::max()) >= y;
} }
template <class T, MIGRAPHX_REQUIRES(not std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(not is_floating_point<T>{})>
static bool apply(T x, T y) static bool apply(T x, T y)
{ {
return x == y; return x == y;
......
...@@ -125,16 +125,10 @@ auto fix(F f) ...@@ -125,16 +125,10 @@ auto fix(F f)
return fix<void>(f); return fix<void>(f);
} }
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
template <class F, class T> template <class F, class T>
auto fold_impl(F&&, T&& x) auto fold_impl(F&&, T&& x)
{ {
return x; return std::forward<T>(x);
} }
template <class F, class T, class U, class... Ts> template <class F, class T, class U, class... Ts>
...@@ -149,6 +143,22 @@ auto fold(F f) ...@@ -149,6 +143,22 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
inline auto pack_join() { return pack(); }
template <class... Ps>
auto pack_join(Ps... ps)
{
return fold([](auto p1, auto p2) {
return p1([=](auto... xs) { return p2([=](auto... ys) { return pack(xs..., ys...); }); });
})(ps...);
}
template <class F, class Proj> template <class F, class Proj>
auto by(F f, Proj proj) auto by(F f, Proj proj)
{ {
...@@ -216,6 +226,11 @@ struct id ...@@ -216,6 +226,11 @@ struct id
} }
}; };
template <class... Ts>
void nop(Ts&&...)
{
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct fuse_pointwise
{
std::string name() const { return "fuse_pointwise"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -25,18 +25,26 @@ constexpr T normalize(unsigned long z) ...@@ -25,18 +25,26 @@ constexpr T normalize(unsigned long z)
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max() / 64; const auto max = 1UL << (sizeof(T) * 5);
const auto half_max = max / 2; const auto half_max = max / 2;
return half_max - (z % max); return half_max - (z % max);
} }
template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})> template <class T,
MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{} and
not std::is_same<T, bool>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max() / 64; const auto max = 1UL << (sizeof(T) * 5);
return z % max; return z % max;
} }
template <class T, MIGRAPHX_REQUIRES(std::is_same<T, bool>{})>
constexpr bool normalize(unsigned long z)
{
return static_cast<bool>(z % 2);
}
template <class T> template <class T>
struct xorshf96_generator struct xorshf96_generator
{ {
...@@ -80,16 +88,16 @@ struct xorshift_generator ...@@ -80,16 +88,16 @@ struct xorshift_generator
template <class T> template <class T>
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed}); std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed});
return result; return result;
} }
template <class T> template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), [=] { return value; }); std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result; return result;
} }
......
...@@ -23,11 +23,13 @@ struct deduce ...@@ -23,11 +23,13 @@ struct deduce
using type = T; using type = T;
}; };
#ifdef HAS_HALF_V1
template <> template <>
struct deduce<half_float::detail::expr> struct deduce<half_float::detail::expr>
{ {
using type = half; using type = half;
}; };
#endif
} // namespace detail } // namespace detail
template <class T> template <class T>
...@@ -36,4 +38,24 @@ using deduce = typename detail::deduce<T>::type; ...@@ -36,4 +38,24 @@ using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T>
{
};
template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T>
{
};
template <>
struct common_type<migraphx::half, migraphx::half>
{
using type = migraphx::half;
};
} // namespace std
#endif #endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct inline_module
{
std::string name() const { return "inline_module"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP
#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* insert pads if attribute of padding is asymmetrical
*/
struct insert_pad
{
std::string name() const { return "insert_pad"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/erase.hpp> #include <migraphx/erase.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -14,7 +15,11 @@ namespace migraphx { ...@@ -14,7 +15,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs);
struct instruction struct instruction
{ {
...@@ -22,6 +27,11 @@ struct instruction ...@@ -22,6 +27,11 @@ struct instruction
instruction(operation o, shape r, std::vector<instruction_ref> args); instruction(operation o, shape r, std::vector<instruction_ref> args);
instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules);
instruction(literal l); instruction(literal l);
void replace(operation o); void replace(operation o);
...@@ -32,7 +42,7 @@ struct instruction ...@@ -32,7 +42,7 @@ struct instruction
friend bool operator==(const instruction& i, instruction_ref ref); friend bool operator==(const instruction& i, instruction_ref ref);
bool valid(instruction_ref start) const; bool valid(instruction_ref start, bool check_order = false) const;
bool valid() const; bool valid() const;
...@@ -45,6 +55,8 @@ struct instruction ...@@ -45,6 +55,8 @@ struct instruction
const std::vector<instruction_ref>& inputs() const; const std::vector<instruction_ref>& inputs() const;
const std::vector<module_ref>& module_inputs() const;
const std::vector<instruction_ref>& outputs() const; const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y); friend bool operator==(const instruction& x, const instruction& y);
...@@ -65,13 +77,25 @@ struct instruction ...@@ -65,13 +77,25 @@ struct instruction
migraphx::erase(output, ins); migraphx::erase(output, ins);
} }
static void replace_refs(instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods);
static void backreference(instruction_ref ref); static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins); static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static void replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
bool can_eval() const; bool can_eval() const;
argument eval(bool check_eval = true) const; argument eval(bool check_eval = true) const;
...@@ -80,39 +104,52 @@ struct instruction ...@@ -80,39 +104,52 @@ struct instruction
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false); static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
void set_normalized(bool value = true);
bool is_normalized() const;
bool need_normalization() const;
operation normalized_operator() const;
void debug_print() const;
static void print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names);
private: private:
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args); void replace(operation o, const shape& r, std::vector<instruction_ref> args);
// internal
void replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args);
// internal // internal
void replace(std::vector<instruction_ref> args); void replace(std::vector<instruction_ref> args);
// internal
void replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args);
// internal // internal
void replace_argument(instruction_ref old, instruction_ref new_ins); void replace_argument(instruction_ref old, instruction_ref new_ins);
// internal
void replace_mod_argument(module_ref old, module_ref new_mod);
void replace(const shape& r); void replace(const shape& r);
operation op; operation op;
shape result; shape result{};
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
std::vector<module_ref> module_args;
literal lit; literal lit;
bool normalized = false;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(&*x);
}
};
} // namespace std
#endif #endif
...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const migraphx::instruction_ref& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return migraphx::as_address(x) == migraphx::as_address(y);
}
};
} // namespace std
#endif #endif
#ifndef MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <iterator>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F, class Iterator = std::ptrdiff_t>
struct basic_iota_iterator
{
Iterator index;
F f;
using difference_type = std::ptrdiff_t;
using reference = decltype(f(std::declval<Iterator>()));
using value_type = typename std::remove_reference<reference>::type;
using pointer = typename std::add_pointer<value_type>::type;
using iterator_category = std::random_access_iterator_tag;
basic_iota_iterator& operator+=(int n)
{
index += n;
return *this;
}
basic_iota_iterator& operator-=(int n)
{
index -= n;
return *this;
}
basic_iota_iterator& operator++()
{
index++;
return *this;
}
basic_iota_iterator& operator--()
{
index--;
return *this;
}
basic_iota_iterator operator++(int) // NOLINT
{
basic_iota_iterator it = *this;
index++;
return it;
}
basic_iota_iterator operator--(int) // NOLINT
{
basic_iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
reference operator*() const { return f(index); }
};
template <class T, class F>
inline basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x += y;
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator+(std::ptrdiff_t x,
basic_iota_iterator<F, Iterator> y)
{
return y + x;
}
template <class F, class Iterator>
inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
basic_iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x -= y;
}
template <class F, class Iterator>
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
inline bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
inline bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
inline bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
inline bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
inline bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
using iota_iterator = basic_iota_iterator<id>;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
{
return !it._M_dereferenceable();
}
template <class Iterator, class EndIterator>
auto is_end(rank<1>, Iterator it, EndIterator last)
{
return it == last;
}
template <class Iterator, class EndIterator>
bool is_end(Iterator it, EndIterator last)
{
return is_end(rank<2>{}, it, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP
...@@ -59,18 +59,24 @@ struct iterator_for_range ...@@ -59,18 +59,24 @@ struct iterator_for_range
struct iterator struct iterator
{ {
using difference_type = std::ptrdiff_t;
using reference = decltype(std::declval<base_iterator>());
using value_type = std::remove_reference_t<reference>;
using pointer = std::add_pointer_t<value_type>;
using iterator_category = std::input_iterator_tag;
base_iterator i; base_iterator i;
auto operator*() const { return Selector::deref(i); } auto operator*() const { return Selector::deref(i); }
base_iterator operator++() { return ++i; } base_iterator operator++() { return ++i; }
bool operator==(const iterator& rhs) const { return i == rhs.i; }
bool operator!=(const iterator& rhs) const { return i != rhs.i; } bool operator!=(const iterator& rhs) const { return i != rhs.i; }
}; };
iterator begin() iterator begin() const
{ {
assert(base != nullptr); assert(base != nullptr);
return {Selector::begin(base)}; return {Selector::begin(base)};
} }
iterator end() iterator end() const
{ {
assert(base != nullptr); assert(base != nullptr);
return {Selector::end(base)}; return {Selector::end(base)};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_JSON_HPP
#define MIGRAPHX_GUARD_RTGLIB_JSON_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val);
value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class lifetime
{
local,
global,
borrow
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
...@@ -52,7 +52,8 @@ struct literal : raw_data<literal> ...@@ -52,7 +52,8 @@ struct literal : raw_data<literal>
fill(start, end); fill(start, end);
} }
literal(const shape& s, const char* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s) template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{ {
std::copy(x, x + s.bytes(), buffer.get()); std::copy(x, x + s.bytes(), buffer.get());
} }
...@@ -65,11 +66,13 @@ struct literal : raw_data<literal> ...@@ -65,11 +66,13 @@ struct literal : raw_data<literal>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
std::vector<literal> get_sub_objects() const { return {}; }
/// Convert the data to an argument /// Convert the data to an argument
argument get_argument() const argument get_argument() const
{ {
std::vector<char> b(buffer.get(), buffer.get() + m_shape.bytes()); auto b = make_shared_array<char>(buffer.get(), buffer.get() + m_shape.bytes());
return {m_shape, [b]() mutable { return b.data(); }}; return {m_shape, [b]() { return b.get(); }};
} }
private: private:
...@@ -90,7 +93,7 @@ struct literal : raw_data<literal> ...@@ -90,7 +93,7 @@ struct literal : raw_data<literal>
m_shape.visit_type([&](auto as) { m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get())); auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++; it++;
}); });
}); });
...@@ -125,6 +128,9 @@ literal transform(literal l1, literal l2, F f) ...@@ -125,6 +128,9 @@ literal transform(literal l1, literal l2, F f)
return result; return result;
} }
void migraphx_to_value(value& v, const literal& l);
void migraphx_from_value(const value& v, literal& l);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP
#include <migraphx/program.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct file_options
{
std::string format = "msgpack";
};
program load(const std::string& filename, const file_options& options = file_options{});
program load_buffer(const std::vector<char>& buffer, const file_options& options = file_options{});
program
load_buffer(const char* buffer, std::size_t size, const file_options& options = file_options{});
void save(const program& p,
const std::string& filename,
const file_options& options = file_options{});
std::vector<char> save_buffer(const program& p, const file_options& options = file_options{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -10,7 +10,15 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,15 @@ inline namespace MIGRAPHX_INLINE_NS {
template <typename T> template <typename T>
std::shared_ptr<T> make_shared_array(size_t size) std::shared_ptr<T> make_shared_array(size_t size)
{ {
return std::shared_ptr<T>(new T[size], std::default_delete<T[]>()); // NOLINT return std::shared_ptr<T>(new T[size](), std::default_delete<T[]>()); // NOLINT
}
template <class T, class Iterator>
std::shared_ptr<T> make_shared_array(Iterator start, Iterator last)
{
auto result = make_shared_array<T>(std::distance(start, last));
std::copy(start, last, result.get());
return result;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment