Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f0019df3
Commit
f0019df3
authored
Sep 17, 2021
by
Qianfeng Zhang
Browse files
Add half_t support to NumericLimits and make constexpr GetZeroVal() of binary operator
parent
eac1753d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
33 additions
and
42 deletions
+33
-42
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
...sor_operation/gridwise_generic_2d_reduction_blockwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
...ation/gridwise_generic_2d_reduction_direct_threadwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
...eration/gridwise_generic_2d_reduction_direct_warpwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
...or_operation/gridwise_generic_2d_reduction_multiblock.hpp
+2
-2
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+17
-10
composable_kernel/include/utility/reduction_operator.hpp
composable_kernel/include/utility/reduction_operator.hpp
+5
-21
No files found.
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
View file @
f0019df3
...
...
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
// LDS
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
View file @
f0019df3
...
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
(
void
)
ws_indices_global
;
(
void
)
indices_global
;
const
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
(
void
)
ws_indices_global
;
const
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
(
void
)
origReduceLen
;
const
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
View file @
f0019df3
...
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
(
void
)
ws_indices_global
;
(
void
)
indices_global
;
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
(
void
)
ws_indices_global
;
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
...
@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
(
void
)
origReduceLen
;
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
View file @
f0019df3
...
...
@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(
void
)
alpha
;
// unused
(
void
)
beta
;
// unused
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
// LDS
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
...
...
@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(
void
)
alpha
;
// unused
(
void
)
beta
;
// unused
auto
zeroVal
=
opReduce
::
GetZeroVal
();
constexpr
auto
zeroVal
=
opReduce
::
GetZeroVal
();
// LDS
__shared__
compType
p_in_block_values_buffer
[
BlockBufferSize
];
...
...
composable_kernel/include/utility/data_type.hpp
View file @
f0019df3
...
...
@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
};
template
<
typename
T
>
struct
NumericLimits
;
struct
NumericLimits
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
};
template
<
>
struct
NumericLimits
<
int32_t
>
struct
NumericLimits
<
half_t
>
{
__host__
__device__
static
constexpr
int32_t
Min
()
{
return
std
::
numeric_limits
<
int32_t
>::
min
();
}
static
constexpr
unsigned
short
binary_min
=
0x0400
;
static
constexpr
unsigned
short
binary_max
=
0x7BFF
;
static
constexpr
unsigned
short
binary_lowest
=
0xFBFF
;
__host__
__device__
static
constexpr
int32_t
Max
()
{
return
std
::
numeric_limits
<
int32_t
>::
max
();
}
__host__
__device__
static
constexpr
half_t
Min
()
{
return
as_type
<
half_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
half_t
Max
()
{
return
as_type
<
half_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
half_t
Lowest
()
{
return
as_type
<
half_t
>
(
binary_lowest
);
}
};
}
// namespace ck
...
...
composable_kernel/include/utility/reduction_operator.hpp
View file @
f0019df3
...
...
@@ -58,7 +58,7 @@ struct Add
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
type_convert
<
T
>
{}
(
0.0f
);
};
__device__
static
constexpr
T
GetZeroVal
()
{
return
static_cast
<
T
>
(
0.0f
);
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
;
}
...
...
@@ -70,7 +70,7 @@ struct Mul
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
type_convert
<
T
>
{}
(
1.0f
);
};
__device__
static
constexpr
T
GetZeroVal
()
{
return
static_cast
<
T
>
(
1.0f
);
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
*
b
;
}
...
...
@@ -82,7 +82,7 @@ struct Max
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
};
__device__
static
constexpr
T
GetZeroVal
()
{
return
NumericLimits
<
T
>::
lowest
();
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
...
...
@@ -107,7 +107,7 @@ struct Min
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
std
::
numeric_limits
<
T
>::
max
();
};
__device__
static
constexpr
T
GetZeroVal
()
{
return
NumericLimits
<
T
>::
Max
();
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
...
...
@@ -132,7 +132,7 @@ struct AMax
{
using
dataType
=
T
;
__device__
static
T
GetZeroVal
()
{
return
type_convert
<
T
>
{}
(
0.0f
);
};
__device__
static
constexpr
T
GetZeroVal
()
{
return
static_cast
<
T
>
(
0.0f
);
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
...
...
@@ -152,22 +152,6 @@ struct AMax
static
constexpr
bool
indexable
=
true
;
};
template
<
>
__device__
half_t
Max
<
half_t
>::
GetZeroVal
()
{
const
unsigned
short
binary_lowest
=
0xFBFF
;
return
*
reinterpret_cast
<
const
half_t
*>
(
&
binary_lowest
);
};
template
<
>
__device__
half_t
Min
<
half_t
>::
GetZeroVal
()
{
const
unsigned
short
binary_max
=
0x7BFF
;
return
*
reinterpret_cast
<
const
half_t
*>
(
&
binary_max
);
};
// Unary operators are usually called element-wisely before the reduction is executed on the
// elements.
// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment