Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a4bf3a98
Commit
a4bf3a98
authored
Oct 31, 2018
by
Paul
Browse files
Add half for cpu
parent
ce3f2db7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
84 additions
and
13 deletions
+84
-13
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+6
-5
src/include/migraph/half.hpp
src/include/migraph/half.hpp
+36
-0
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+7
-7
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+2
-0
src/include/migraph/type_traits.hpp
src/include/migraph/type_traits.hpp
+32
-0
No files found.
src/CMakeLists.txt
View file @
a4bf3a98
...
...
@@ -21,7 +21,7 @@ rocm_clang_tidy_check(migraph)
target_include_directories
(
migraph PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
find_path
(
HALF_INCLUDE_DIR half.hpp
)
target_include_directories
(
migraph PUBLIC
${
HALF_INCLUDE_DIR
}
)
target_include_directories
(
migraph
SYSTEM
PUBLIC
${
HALF_INCLUDE_DIR
}
)
add_subdirectory
(
onnx
)
add_subdirectory
(
targets/cpu
)
...
...
src/include/migraph/generate.hpp
View file @
a4bf3a98
...
...
@@ -3,23 +3,24 @@
#include <migraph/argument.hpp>
#include <migraph/literal.hpp>
#include <migraph/type_traits.hpp>
#include <random>
namespace
migraph
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_floating_point
<
T
>{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
if
(
z
==
0
)
return
0
;
return
T
(
0
)
;
const
auto
max
=
32
;
const
double
range
=
max
/
2
;
// NOLINT
double
result
=
(
z
%
max
)
/
range
;
result
-=
1
;
return
result
;
return
T
(
result
)
;
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
@@ -27,7 +28,7 @@ constexpr T normalize(unsigned long z)
return
half_max
-
(
z
%
max
);
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
std
::
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
src/include/migraph/half.hpp
0 → 100644
View file @
a4bf3a98
/*=============================================================================
Copyright (c) 2017 Paul Fultz II
half.hpp
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
namespace
migraph
{
using
half
=
half_float
::
half
;
namespace
detail
{
template
<
class
T
>
struct
deduce
{
using
type
=
T
;
};
template
<
>
struct
deduce
<
half_float
::
detail
::
expr
>
{
using
type
=
half
;
};
}
// namespace detail
template
<
class
T
>
using
deduce
=
typename
detail
::
deduce
<
T
>::
type
;
}
// namespace migraph
#endif
src/include/migraph/literal.hpp
View file @
a4bf3a98
...
...
@@ -20,10 +20,10 @@ struct literal : raw_data<literal>
{
literal
()
{}
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
make_shared_array
<
char
>
(
sizeof
(
T
))),
m_shape
(
shape
::
get_type
<
T
>
{})
template
<
class
U
,
class
T
=
deduce
<
U
>
>
literal
(
U
x
)
:
buffer
(
make_shared_array
<
char
>
(
sizeof
(
T
))),
m_shape
(
shape
::
get_type
<
T
>
{})
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
ly_copyable
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
get
()))
=
x
;
}
...
...
@@ -31,7 +31,7 @@ struct literal : raw_data<literal>
literal
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
ly_copyable
<
T
>
{},
"Literals can only be trivial types"
);
fill
(
x
.
begin
(),
x
.
end
());
}
...
...
@@ -39,7 +39,7 @@ struct literal : raw_data<literal>
literal
(
const
shape
&
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
ly_copyable
<
T
>
{},
"Literals can only be trivial types"
);
fill
(
x
.
begin
(),
x
.
end
());
}
...
...
@@ -101,7 +101,7 @@ literal transform(literal l, F f)
literal
result
;
l
.
visit
([
&
](
auto
x
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
x
)
::
value_type
>
;
std
::
vector
<
type
>
output
(
x
.
size
(),
0.0
);
std
::
vector
<
type
>
output
(
x
.
size
(),
type
(
0
)
);
std
::
transform
(
x
.
begin
(),
x
.
end
(),
output
.
begin
(),
f
);
result
=
literal
{
l
.
get_shape
(),
output
};
});
...
...
@@ -115,7 +115,7 @@ literal transform(literal l1, literal l2, F f)
literal
result
;
visit_all
(
l1
,
l2
)([
&
](
auto
x
,
auto
y
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
x
)
::
value_type
>
;
std
::
vector
<
type
>
output
(
x
.
size
(),
0.0
);
std
::
vector
<
type
>
output
(
x
.
size
(),
type
(
0
)
);
std
::
transform
(
x
.
begin
(),
x
.
end
(),
y
.
begin
(),
output
.
begin
(),
f
);
result
=
literal
{
l1
.
get_shape
(),
output
};
});
...
...
src/include/migraph/shape.hpp
View file @
a4bf3a98
...
...
@@ -8,6 +8,7 @@
#include <memory>
#include <migraph/errors.hpp>
#include <migraph/half.hpp>
namespace
migraph
{
...
...
@@ -19,6 +20,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
...
...
src/include/migraph/type_traits.hpp
0 → 100644
View file @
a4bf3a98
/*=============================================================================
Copyright (c) 2017 Paul Fultz II
type_traits.hpp
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraph/half.hpp>
namespace
migraph
{
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template<class X> \
struct trait : std::trait<X> \
{}; \
\
template<> \
struct trait<T> \
: std::true_type \
{};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
}
// namespace migraph
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment