Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7a7497f9
Commit
7a7497f9
authored
Sep 17, 2021
by
Qianfeng Zhang
Browse files
Remove constexpr from initialized zeroVal and tiny fix in reduction_operator.hpp
parent
4fea4251
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
12 additions
and
12 deletions
+12
-12
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
...sor_operation/gridwise_generic_2d_reduction_blockwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
...ation/gridwise_generic_2d_reduction_direct_threadwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
...eration/gridwise_generic_2d_reduction_direct_warpwise.hpp
+3
-3
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
...or_operation/gridwise_generic_2d_reduction_multiblock.hpp
+2
-2
composable_kernel/include/utility/reduction_operator.hpp
composable_kernel/include/utility/reduction_operator.hpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
View file @
7a7497f9
...
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
// LDS
// LDS
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
__shared__
int
block_indices_buffer
[
BlockBufferSize
];
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
View file @
7a7497f9
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
(
void
)
ws_indices_global
;
(
void
)
ws_indices_global
;
(
void
)
indices_global
;
(
void
)
indices_global
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
{
(
void
)
ws_indices_global
;
(
void
)
ws_indices_global
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
{
(
void
)
origReduceLen
;
(
void
)
origReduceLen
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
View file @
7a7497f9
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
(
void
)
ws_indices_global
;
(
void
)
ws_indices_global
;
(
void
)
indices_global
;
(
void
)
indices_global
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
{
(
void
)
ws_indices_global
;
(
void
)
ws_indices_global
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
...
@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
{
(
void
)
origReduceLen
;
(
void
)
origReduceLen
;
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
View file @
7a7497f9
...
@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
...
@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(
void
)
alpha
;
// unused
(
void
)
alpha
;
// unused
(
void
)
beta
;
// unused
(
void
)
beta
;
// unused
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
// LDS
// LDS
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
...
@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
...
@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(
void
)
alpha
;
// unused
(
void
)
alpha
;
// unused
(
void
)
beta
;
// unused
(
void
)
beta
;
// unused
constexpr
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
// LDS
// LDS
__shared__
compType
p_in_block_values_buffer
[
BlockBufferSize
];
__shared__
compType
p_in_block_values_buffer
[
BlockBufferSize
];
...
...
composable_kernel/include/utility/reduction_operator.hpp
View file @
7a7497f9
...
@@ -82,7 +82,7 @@ struct Max
...
@@ -82,7 +82,7 @@ struct Max
{
{
using
dataType
=
T
;
using
dataType
=
T
;
__device__
static
constexpr
T
GetReductionZeroVal
()
{
return
NumericLimits
<
T
>::
lowest
();
};
__device__
static
constexpr
T
GetReductionZeroVal
()
{
return
NumericLimits
<
T
>::
Lowest
();
};
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment