gaoqiong / composable_kernel / Commits

Commit 5890e300, authored Oct 25, 2021 by Jun Liu, committed by GitHub on Oct 25, 2021

[Composable Kernel] update develop branch code to ck_upstream

Merge pull request #1236 from ROCmSoftwarePlatform/develop

Parents: 8557901d, dfb80c4e

Changes: 24 files in the commit; this page shows 4 changed files with 469 additions and 64 deletions.
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp   +222 -0
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp   +13 -32
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp   +221 -0
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp   +13 -32
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp (new file, mode 0 → 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;

using srcDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

using toReduceDims  = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);

constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
                                             : NanPropagation_t::PROPAGATE_NAN;

constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
                                                       ? ReduceTensorIndices_t::NO_INDICES
                                                       : ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable    = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable

extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
    (void)BlkGroupSize;

    void* p_src2dDesc = ws_global;
    void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;

    const auto tupleDstLengths = make_tuple(1);
    const auto tupleDstStrides = make_tuple(1);

    auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

    const index_t invariantLen = dstDesc.GetLength(Number<0>{});
    const index_t toReduceLen  = BlkGroupSize;

    auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));

    constexpr auto copySliceLen = GredThreadBufferLength;

    if constexpr(src2d_need_padding)
    {
        const auto srcPad1 = GridSize * BlockSize - invariantLen;
        const auto srcPad2 =
            ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;

        auto src2dDesc_2 =
            transform_tensor_descriptor(src2dDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
                                                   make_pad_transform(toReduceLen, 0, srcPad2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
    }

    if constexpr(dst1d_need_padding)
    {
        const auto dstPad = GridSize * BlockSize - invariantLen;

        auto dst1dDesc_2 =
            transform_tensor_descriptor(dstDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
    }
};

struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths = make_tuple(8);
    static constexpr auto ref_dstDesc =
        make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

    static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
    static constexpr index_t ref_toReduceLen  = 8;

    static constexpr auto ref_src2dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

    using refType_src2dDesc = decltype(ref_src2dDesc);
    using refType_dst1dDesc = decltype(ref_dstDesc);

    // used by the DirectThreadWise and DirectWarpWise method
    using refType_src2dDesc_padded_12 =
        decltype(transform_tensor_descriptor(ref_src2dDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
                                                        make_pad_transform(ref_toReduceLen, 0, 2)),
                                             make_tuple(Sequence<0>{}, Sequence<1>{}),
                                             make_tuple(Sequence<0>{}, Sequence<1>{})));

    using refType_dst1dDesc_padded =
        decltype(transform_tensor_descriptor(ref_dstDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
                                             make_tuple(Sequence<0>{}),
                                             make_tuple(Sequence<0>{})));
};

using refType_src2dDesc           = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc           = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded    = typename get_ref_desc_types::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
    else
        return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
    else
        return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};

extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                     float alpha,
                                                     const void* __restrict__ p_src_global,
                                                     float beta,
                                                     void* __restrict__ p_dst_global,
                                                     const void CONSTANT* ws_global,
                                                     long ws_buf2_bytes_offset,
                                                     void* __restrict__ indices_global)
{
    (void)p_src_global;

    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
    void* ws_buf1_global    = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

    using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
                                                                           srcDataType,
                                                                           dstDataType,
                                                                           compType,
                                                                           decltype(src2dDesc),
                                                                           decltype(dst1dDesc),
                                                                           op,
                                                                           nanPropaOpt,
                                                                           reduceIndicesOpt,
                                                                           false,
                                                                           true,
                                                                           GredThreadBufferLength>;

    void* const ws_buf2_global =
        ws_buf2_bytes_offset > 0
            ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
            : nullptr;

    constexpr int RunId = need_indices ? 3 : 1;

    gridwise_2d_reduce::template Run<RunId>(
        src2dDesc,
        dst1dDesc,
        origReduceLen,
        alpha,
        static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
        beta,
        static_cast<dstDataType* const __restrict__>(p_dst_global),
        static_cast<const int* const __restrict__>(ws_buf2_global),
        static_cast<int* const __restrict__>(indices_global));
};
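The prepare and reduce kernels in this file communicate only through the ws_global workspace: thread 0 of the prepare kernel writes the (possibly padded) 2-D source descriptor at byte offset 0 and the 1-D destination descriptor at byte offset 2048, and the reduce kernel reads those back, taking its input values from offset 4096 (ws_buf1_global) and, when ws_buf2_bytes_offset > 0, an index buffer ws_buf2_bytes_offset bytes further in. The sketch below is not part of the commit; it simply restates that layout as plain C++ constants (the WorkspaceLayout name is hypothetical).

#include <cstddef>

// Illustrative only: the byte offsets used by gridwise_generic_reduce_2_prepare
// and gridwise_generic_reduce_2 when carving up ws_global.
struct WorkspaceLayout
{
    static constexpr std::size_t src2dDescOffset = 0;    // 2-D source descriptor written by thread 0 of the prepare kernel
    static constexpr std::size_t dst1dDescOffset = 2048; // 1-D destination descriptor
    static constexpr std::size_t valuesOffset    = 4096; // ws_buf1_global: the values the reduce kernel consumes instead of p_src_global
    // When ws_buf2_bytes_offset > 0, ws_buf2_global (the accompanying index buffer)
    // starts at valuesOffset + ws_buf2_bytes_offset.
};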
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp → composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp (renamed)
@@ -42,12 +42,8 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 
-constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
 
-using toReduceDims  = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
-
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                              ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
 
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
-
 constexpr bool indexable    = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -152,12 +138,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                                    make_pad_transform(toReduceLen, 0, srcPad2)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
@@ -169,17 +155,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                         make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
 
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
     }
 };
 
-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t dstDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_tupleDstLengths =
@@ -217,16 +203,11 @@ struct get_ref_desc_types
                                              make_tuple(Sequence<0>{})));
 };
 
-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
-using refType_src2dDesc_padded_12 =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc_padded_12;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc_padded;
+using refType_src2dDesc           = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
+using refType_dst1dDesc           = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
+using refType_src2dDesc_padded_12 = typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
+using refType_dst1dDesc_padded    = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
 
 template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -251,15 +232,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)p_src_global;
 
-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
-    void* ws_buf1_global    = static_cast<char*>(ws_global) + 4096;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
+    void* ws_buf1_global    = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
 
     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
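One note on the get_ref_desc_types changes above: the struct exists only to manufacture, at compile time, the exact descriptor types that gridwise_generic_reduce_2_prepare stores into the workspace, so that gridwise_generic_reduce_2 can read the raw bytes back through a cast; after this split the second-call reference types depend only on dstDims, hence the shrunken template parameter list. Below is a minimal sketch of that store-by-value / name-by-decltype pattern in plain C++, with hypothetical names (make_toy_descriptor, refType_toyDesc) standing in for the CK helpers; it is not code from the commit.

#include <tuple>

// Illustrative only: the real descriptor types are deep template instantiations,
// so the wrapper builds a reference object from dummy lengths and uses decltype
// to name the type that was stored into the workspace.
template <typename... Ls>
constexpr auto make_toy_descriptor(Ls... lengths) // stands in for make_naive_tensor_descriptor_packed
{
    return std::make_tuple(lengths...);
}

constexpr auto ref_toy_desc = make_toy_descriptor(8, 8);
using refType_toyDesc       = decltype(ref_toy_desc);

// Mirrors get_reduction_src2d_descriptor: recover the stored object from raw bytes.
inline const refType_toyDesc& load_toy_desc(const void* p_bytes)
{
    return *reinterpret_cast<const refType_toyDesc*>(p_bytes);
}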
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp (new file, mode 0 → 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;

using srcDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);

constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
                                             : NanPropagation_t::PROPAGATE_NAN;

constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
                                                       ? ReduceTensorIndices_t::NO_INDICES
                                                       : ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable    = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable

extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
    (void)BlkGroupSize;

    void* p_src2dDesc = ws_global;
    void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;

    const auto tupleDstLengths = make_tuple(1);
    const auto tupleDstStrides = make_tuple(1);

    auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

    const index_t invariantLen = dstDesc.GetLength(Number<0>{});
    const index_t toReduceLen  = BlkGroupSize;

    auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));

    constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;

    if constexpr(src2d_need_padding)
    {
        const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
        const auto srcPad2 =
            ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;

        auto src2dDesc_2 =
            transform_tensor_descriptor(src2dDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
                                                   make_pad_transform(toReduceLen, 0, srcPad2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
    }

    if constexpr(dst1d_need_padding)
    {
        const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;

        auto dst1dDesc_2 =
            transform_tensor_descriptor(dstDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
    }
};

struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths = make_tuple(8);
    static constexpr auto ref_dstDesc =
        make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

    static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
    static constexpr index_t ref_toReduceLen  = 8;

    static constexpr auto ref_src2dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

    using refType_src2dDesc = decltype(ref_src2dDesc);
    using refType_dst1dDesc = decltype(ref_dstDesc);

    // used by the DirectThreadWise and DirectWarpWise method
    using refType_src2dDesc_padded_12 =
        decltype(transform_tensor_descriptor(ref_src2dDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
                                                        make_pad_transform(ref_toReduceLen, 0, 2)),
                                             make_tuple(Sequence<0>{}, Sequence<1>{}),
                                             make_tuple(Sequence<0>{}, Sequence<1>{})));

    using refType_dst1dDesc_padded =
        decltype(transform_tensor_descriptor(ref_dstDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
                                             make_tuple(Sequence<0>{}),
                                             make_tuple(Sequence<0>{})));
};

using refType_src2dDesc           = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc           = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded    = typename get_ref_desc_types::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
    else
        return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
    else
        return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};

extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                     float alpha,
                                                     const void* __restrict__ p_src_global,
                                                     float beta,
                                                     void* __restrict__ p_dst_global,
                                                     const void CONSTANT* ws_global,
                                                     long ws_buf2_bytes_offset,
                                                     void* __restrict__ indices_global)
{
    (void)p_src_global;

    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
    void* ws_buf1_global    = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

    using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
                                                                         srcDataType,
                                                                         dstDataType,
                                                                         compType,
                                                                         decltype(src2dDesc),
                                                                         decltype(dst1dDesc),
                                                                         op,
                                                                         nanPropaOpt,
                                                                         reduceIndicesOpt,
                                                                         false,
                                                                         true,
                                                                         GredAccessesPerThreadInWarp>;

    void* const ws_buf2_global =
        ws_buf2_bytes_offset > 0
            ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
            : nullptr;

    constexpr int RunId = need_indices ? 3 : 1;

    gridwise_2d_reduce::template Run<RunId>(
        src2dDesc,
        dst1dDesc,
        origReduceLen,
        alpha,
        static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
        beta,
        static_cast<dstDataType* const __restrict__>(p_dst_global),
        static_cast<const int* const __restrict__>(ws_buf2_global),
        static_cast<int* const __restrict__>(indices_global));
};
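Like the threadwise wrapper earlier in this commit, this warpwise wrapper pads the 2-D workspace view so that each reducing unit covers whole slices, but the unit is now a warp: rows are padded up to GridSize * BlockSize / warpSize warps (instead of GridSize * BlockSize threads), and columns are padded up to a multiple of warpSize * GredAccessesPerThreadInWarp (instead of GredThreadBufferLength). Below is a small standalone sketch of that arithmetic; every number in it is an assumption chosen for illustration, not a value taken from the commit.

#include <cstdio>

int main()
{
    // Hypothetical problem and launch configuration.
    const int invariantLen = 1000; // output elements handled by the second call
    const int toReduceLen  = 7;    // BlkGroupSize: partial results per output element
    const int BlockSize    = 256;
    const int warpSize     = 64;

    // Threadwise second call: one thread per output element.
    const int GredThreadBufferLength = 8;
    const int GridSizeThread         = 4; // 4 * 256 = 1024 threads >= 1000 outputs
    const int padRowsThread = GridSizeThread * BlockSize - invariantLen;
    const int padColsThread =
        (toReduceLen + GredThreadBufferLength - 1) / GredThreadBufferLength * GredThreadBufferLength -
        toReduceLen;

    // Warpwise second call: one warp per output element.
    const int GredAccessesPerThreadInWarp = 2;
    const int copySliceLen               = warpSize * GredAccessesPerThreadInWarp; // 128
    const int GridSizeWarp               = 250; // 250 * 256 / 64 = 1000 warps
    const int padRowsWarp = GridSizeWarp * BlockSize / warpSize - invariantLen;
    const int padColsWarp = (toReduceLen + copySliceLen - 1) / copySliceLen * copySliceLen - toReduceLen;

    std::printf("threadwise: pad rows %d, pad cols %d\n", padRowsThread, padColsThread); // 24, 1
    std::printf("warpwise:   pad rows %d, pad cols %d\n", padRowsWarp, padColsWarp);     // 0, 121
}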
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp → composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp (renamed)
@@ -42,12 +42,8 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 
-constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
 
-using toReduceDims  = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
-
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                              ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
 
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
-
 constexpr bool indexable    = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -153,12 +139,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                                    make_pad_transform(toReduceLen, 0, srcPad2)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
@@ -170,17 +156,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                         make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
 
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
     }
    else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
     }
 };
 
-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t dstDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_tupleDstLengths =
@@ -218,16 +204,11 @@ struct get_ref_desc_types
                                              make_tuple(Sequence<0>{})));
 };
 
-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
-using refType_src2dDesc_padded_12 =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc_padded_12;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc_padded;
+using refType_src2dDesc           = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
+using refType_dst1dDesc           = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
+using refType_src2dDesc_padded_12 = typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
+using refType_dst1dDesc_padded    = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
 
 template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -252,15 +233,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)p_src_global;
 
-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
-    void* ws_buf1_global    = static_cast<char*>(ws_global) + 4096;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
+    void* ws_buf1_global    = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
 
    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);