Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5a9f6308
Commit
5a9f6308
authored
Sep 15, 2021
by
Qianfeng Zhang
Browse files
Fix with regard to implementing GetZeroVal() in both kernel and host
parent
a18e6481
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
20 deletions
+33
-20
composable_kernel/include/utility/reduction_operator.hpp
composable_kernel/include/utility/reduction_operator.hpp
+33
-20
No files found.
composable_kernel/include/utility/reduction_operator.hpp
View file @
5a9f6308
...
@@ -82,7 +82,7 @@ struct Max
...
@@ -82,7 +82,7 @@ struct Max
{
{
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
std
::
numeric_limits
<
T
>::
min
();
};
__device__
static
T
GetZeroVal
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
{
...
@@ -127,16 +127,45 @@ struct Min
...
@@ -127,16 +127,45 @@ struct Min
static
constexpr
bool
indexable
=
true
;
static
constexpr
bool
indexable
=
true
;
};
};
template
<
class
T
>
struct
AMax
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
type_convert
<
T
>
{}(
0.0
f
);
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
if
(
a
<
b
)
a
=
b
;
}
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
if
(
a
<
b
)
{
a
=
b
;
changed
=
true
;
}
}
static
constexpr
bool
indexable
=
true
;
};
template
<
>
template
<
>
__device__
half_t
Max
<
half_t
>::
GetZeroVal
()
__device__
half_t
Max
<
half_t
>::
GetZeroVal
()
{
{
return
type_convert
<
half_t
>
{}(
std
::
numeric_limits
<
float
>::
min
());
const
unsigned
short
binary_lowest
=
0xFBFF
;
return
*
reinterpret_cast
<
const
half_t
*>
(
&
binary_lowest
);
};
};
template
<
>
template
<
>
__device__
half_t
Min
<
half_t
>::
GetZeroVal
()
__device__
half_t
Min
<
half_t
>::
GetZeroVal
()
{
{
return
type_convert
<
half_t
>
{}(
std
::
numeric_limits
<
float
>::
max
());
const
unsigned
short
binary_max
=
0x7BFF
;
return
*
reinterpret_cast
<
const
half_t
*>
(
&
binary_max
);
};
};
// Unary operators are usually called element-wisely before the reduction is executed on the
// Unary operators are usually called element-wisely before the reduction is executed on the
...
@@ -281,8 +310,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
...
@@ -281,8 +310,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
using
opType
=
reduce
::
Add
<
T
>
;
using
opType
=
reduce
::
Add
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Add
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
};
};
...
@@ -292,8 +319,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
...
@@ -292,8 +319,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
using
opType
=
reduce
::
Mul
<
T
>
;
using
opType
=
reduce
::
Mul
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Mul
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Mul
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Mul
<
T
>::
indexable
;
};
};
...
@@ -303,8 +328,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
...
@@ -303,8 +328,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
using
opType
=
reduce
::
Min
<
T
>
;
using
opType
=
reduce
::
Min
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Min
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Min
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Min
<
T
>::
indexable
;
};
};
...
@@ -314,19 +337,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
...
@@ -314,19 +337,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
using
opType
=
reduce
::
Max
<
T
>
;
using
opType
=
reduce
::
Max
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Max
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Max
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Max
<
T
>::
indexable
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
reduce_binary_operator
<
T
,
ReduceTensorOp_t
::
AMAX
>
struct
reduce_binary_operator
<
T
,
ReduceTensorOp_t
::
AMAX
>
{
{
using
opType
=
reduce
::
Max
<
T
>
;
using
opType
=
reduce
::
A
Max
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Max
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Max
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Max
<
T
>::
indexable
;
};
};
...
@@ -336,8 +355,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
...
@@ -336,8 +355,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
using
opType
=
reduce
::
Add
<
T
>
;
using
opType
=
reduce
::
Add
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Add
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
};
};
...
@@ -347,8 +364,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
...
@@ -347,8 +364,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
using
opType
=
reduce
::
Add
<
T
>
;
using
opType
=
reduce
::
Add
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Add
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
};
};
...
@@ -358,8 +373,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
...
@@ -358,8 +373,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
using
opType
=
reduce
::
Add
<
T
>
;
using
opType
=
reduce
::
Add
<
T
>
;
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
reduce
::
Add
<
T
>::
GetZeroVal
();
};
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
static
constexpr
bool
indexable
=
reduce
::
Add
<
T
>::
indexable
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment