Commit a4bf3a98 authored by Paul's avatar Paul
Browse files

Add half for cpu

parent ce3f2db7
...@@ -21,7 +21,7 @@ rocm_clang_tidy_check(migraph) ...@@ -21,7 +21,7 @@ rocm_clang_tidy_check(migraph)
target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
find_path(HALF_INCLUDE_DIR half.hpp) find_path(HALF_INCLUDE_DIR half.hpp)
target_include_directories(migraph PUBLIC ${HALF_INCLUDE_DIR}) target_include_directories(migraph SYSTEM PUBLIC ${HALF_INCLUDE_DIR})
add_subdirectory(onnx) add_subdirectory(onnx)
add_subdirectory(targets/cpu) add_subdirectory(targets/cpu)
......
...@@ -3,23 +3,24 @@ ...@@ -3,23 +3,24 @@
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/type_traits.hpp>
#include <random> #include <random>
namespace migraph { namespace migraph {
template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPH_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
if(z == 0) if(z == 0)
return 0; return T(0);
const auto max = 32; const auto max = 32;
const double range = max / 2; // NOLINT const double range = max / 2; // NOLINT
double result = (z % max) / range; double result = (z % max) / range;
result -= 1; result -= 1;
return result; return T(result);
} }
template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})> template <class T, MIGRAPH_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
...@@ -27,7 +28,7 @@ constexpr T normalize(unsigned long z) ...@@ -27,7 +28,7 @@ constexpr T normalize(unsigned long z)
return half_max - (z % max); return half_max - (z % max);
} }
template <class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})> template <class T, MIGRAPH_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
......
/*=============================================================================
Copyright (c) 2017 Paul Fultz II
half.hpp
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
namespace migraph {
using half = half_float::half;
namespace detail {
template<class T>
struct deduce
{
using type = T;
};
template<>
struct deduce<half_float::detail::expr>
{
using type = half;
};
} // namespace detail
template<class T>
using deduce = typename detail::deduce<T>::type;
} // namespace migraph
#endif
...@@ -20,10 +20,10 @@ struct literal : raw_data<literal> ...@@ -20,10 +20,10 @@ struct literal : raw_data<literal>
{ {
literal() {} literal() {}
template <class T> template <class U, class T=deduce<U>>
literal(T x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{}) literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.get())) = x; *(reinterpret_cast<T*>(buffer.get())) = x;
} }
...@@ -31,7 +31,7 @@ struct literal : raw_data<literal> ...@@ -31,7 +31,7 @@ struct literal : raw_data<literal>
literal(const shape& s, const std::vector<T>& x) literal(const shape& s, const std::vector<T>& x)
: buffer(make_shared_array<char>(s.bytes())), m_shape(s) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end()); fill(x.begin(), x.end());
} }
...@@ -39,7 +39,7 @@ struct literal : raw_data<literal> ...@@ -39,7 +39,7 @@ struct literal : raw_data<literal>
literal(const shape& s, const std::initializer_list<T>& x) literal(const shape& s, const std::initializer_list<T>& x)
: buffer(make_shared_array<char>(s.bytes())), m_shape(s) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end()); fill(x.begin(), x.end());
} }
...@@ -101,7 +101,7 @@ literal transform(literal l, F f) ...@@ -101,7 +101,7 @@ literal transform(literal l, F f)
literal result; literal result;
l.visit([&](auto x) { l.visit([&](auto x) {
using type = std::remove_cv_t<typename decltype(x)::value_type>; using type = std::remove_cv_t<typename decltype(x)::value_type>;
std::vector<type> output(x.size(), 0.0); std::vector<type> output(x.size(), type(0));
std::transform(x.begin(), x.end(), output.begin(), f); std::transform(x.begin(), x.end(), output.begin(), f);
result = literal{l.get_shape(), output}; result = literal{l.get_shape(), output};
}); });
...@@ -115,7 +115,7 @@ literal transform(literal l1, literal l2, F f) ...@@ -115,7 +115,7 @@ literal transform(literal l1, literal l2, F f)
literal result; literal result;
visit_all(l1, l2)([&](auto x, auto y) { visit_all(l1, l2)([&](auto x, auto y) {
using type = std::remove_cv_t<typename decltype(x)::value_type>; using type = std::remove_cv_t<typename decltype(x)::value_type>;
std::vector<type> output(x.size(), 0.0); std::vector<type> output(x.size(), type(0));
std::transform(x.begin(), x.end(), y.begin(), output.begin(), f); std::transform(x.begin(), x.end(), y.begin(), output.begin(), f);
result = literal{l1.get_shape(), output}; result = literal{l1.get_shape(), output};
}); });
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <memory> #include <memory>
#include <migraph/errors.hpp> #include <migraph/errors.hpp>
#include <migraph/half.hpp>
namespace migraph { namespace migraph {
...@@ -19,6 +20,7 @@ struct shape ...@@ -19,6 +20,7 @@ struct shape
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \ #define MIGRAPH_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
......
/*=============================================================================
Copyright (c) 2017 Paul Fultz II
type_traits.hpp
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraph/half.hpp>
namespace migraph {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template<class X> \
struct trait : std::trait<X> \
{}; \
\
template<> \
struct trait<T> \
: std::true_type \
{};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment