Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
22ee67a9
Commit
22ee67a9
authored
Apr 11, 2024
by
root
Browse files
add reduce_threadwise_multi_d
parent
b17ce193
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
212 additions
and
153 deletions
+212
-153
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
example/12_reduce/reduce_threadwise.cpp
example/12_reduce/reduce_threadwise.cpp
+15
-15
example/12_reduce/reduce_threadwise_impl.hpp
example/12_reduce/reduce_threadwise_impl.hpp
+37
-45
include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp
.../ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp
+69
-0
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
...tion/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
+87
-86
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp
...ion/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp
+3
-6
No files found.
cmake/EnableCompilerWarnings.cmake
View file @
22ee67a9
...
@@ -66,7 +66,7 @@ else()
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
#
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
...
...
example/12_reduce/reduce_threadwise.cpp
View file @
22ee67a9
...
@@ -235,8 +235,8 @@ int main(int argc, char* argv[])
...
@@ -235,8 +235,8 @@ int main(int argc, char* argv[])
else
else
{
{
// for testing half_t
// for testing half_t
pass
=
pass
=
pass
&&
pass
&&
reduce_threadwise_test
<
ck
::
half_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
reduce_threadwise_test
<
ck
::
half_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
},
1.0
f
,
0.0
f
);
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
},
1.0
f
,
0.0
f
);
// for testing float
// for testing float
...
...
example/12_reduce/reduce_threadwise_impl.hpp
View file @
22ee67a9
...
@@ -89,27 +89,25 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -89,27 +89,25 @@ int reduce_threadwise_impl(bool do_verification,
return
(
-
1
);
return
(
-
1
);
};
};
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
// using Add = tensor_operation::element_wise::Add;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
using
InElementwiseOperation
=
PassThrough
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
OutElementwiseOperation
=
PassThrough
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
using
InOutDataTypeInDevice
=
InOutDataType
;
using
InOutDataTypeInDevice
=
InOutDataType
;
using
DeviceReduceInstance
=
using
DeviceReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceReduceThreadWiseMultiD
<
InOutDataTypeInDevice
,
ck
::
tensor_operation
::
device
::
DeviceReduceThreadWiseMultiD
<
InOutDataTypeInDevice
,
ck
::
Tuple
<>
,
AccDataType
,
AccDataType
,
InOutDataTypeInDevice
,
InOutDataTypeInDevice
,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOperation
,
ReduceOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
OutElementwiseOperation
,
PropagateNan
,
OutputIndex
,
false
,
false
,
// HaveIndexInputIfOutputIndex
256
,
// BlockSize
256
,
// BlockSize
4
,
// MThreadSliceSize
4
,
// MThreadSliceSize
1
,
// KThreadSliceSize
1
,
// KThreadSliceSize
...
@@ -173,7 +171,6 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -173,7 +171,6 @@ int reduce_threadwise_impl(bool do_verification,
DeviceMem
in_dev
(
sizeof
(
InOutDataTypeInDevice
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
in_dev
(
sizeof
(
InOutDataTypeInDevice
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_dev
(
sizeof
(
InOutDataTypeInDevice
)
*
out
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_dev
(
sizeof
(
InOutDataTypeInDevice
)
*
out
.
mDesc
.
GetElementSpaceSize
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
...
@@ -187,11 +184,7 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -187,11 +184,7 @@ int reduce_threadwise_impl(bool do_verification,
DeviceMem
out_index_dev
(
indicesSizeInBytes
);
DeviceMem
out_index_dev
(
indicesSizeInBytes
);
InElementwiseOperation
in_elementwise_op
;
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
OutElementwiseOperation
out_elementwise_op
;
std
::
tie
(
in_elementwise_op
,
acc_elementwise_op
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
static_cast
<
int32_t
>
(
reduce_total_length
));
std
::
array
<
index_t
,
Rank
>
arrInLengths
;
std
::
array
<
index_t
,
Rank
>
arrInLengths
;
std
::
array
<
index_t
,
Rank
>
arrInStrides
;
std
::
array
<
index_t
,
Rank
>
arrInStrides
;
...
@@ -213,7 +206,7 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -213,7 +206,7 @@ int reduce_threadwise_impl(bool do_verification,
NumReduceDim
,
NumReduceDim
,
ReduceOperation
,
ReduceOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PassThrough
,
PropagateNan
,
PropagateNan
,
OutputIndex
>
;
OutputIndex
>
;
...
@@ -231,7 +224,7 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -231,7 +224,7 @@ int reduce_threadwise_impl(bool do_verification,
out_ref
.
mData
.
data
(),
out_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
(),
in_elementwise_op
,
in_elementwise_op
,
acc_elementwise_op
);
PassThrough
{}
);
if
(
!
reduce_ref
.
IsSupportedArgument
(
argument_ptr_ref
.
get
()))
if
(
!
reduce_ref
.
IsSupportedArgument
(
argument_ptr_ref
.
get
()))
{
{
...
@@ -249,17 +242,16 @@ int reduce_threadwise_impl(bool do_verification,
...
@@ -249,17 +242,16 @@ int reduce_threadwise_impl(bool do_verification,
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
arrInLengths
,
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
arrInLengths
,
arrInStrides
,
arrInStrides
,
{},
{},
arrOutLengths
,
arrOutLengths
,
arrOutStrides
,
arrOutStrides
,
reduceDims
,
reduceDims
,
static_cast
<
double
>
(
alpha
),
static_cast
<
double
>
(
beta
),
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
nullptr
,
{}
,
out_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_index_dev
.
GetDeviceBuffer
(),
in_elementwise_op
,
in_elementwise_op
,
acc
_elementwise_op
);
out
_elementwise_op
);
if
(
!
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
...
include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp
0 → 100644
View file @
22ee67a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataType
,
typename
DsDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceReduceMultiD
:
public
BaseOperator
{
static
constexpr
index_t
NumOutDim
=
(
Rank
-
NumReduceDim
==
0
)
?
1
:
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutDim
>
,
NumDTensor
>
DsLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutDim
>
,
NumDTensor
>
DsStrides
,
const
std
::
array
<
index_t
,
NumOutDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumOutDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
void
*
in_dev
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
ds_dev
,
void
*
out_dev
,
const
InElementwiseOperation
in_elementwise_op
,
const
OutElementwiseOperation
out_elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InDataType
,
typename
DsDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceReduceMultiDPtr
=
std
::
unique_ptr
<
DeviceReduceMultiD
<
InDataType
,
DsDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp
View file @
22ee67a9
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce
_multi_d
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
...
@@ -19,33 +19,29 @@ namespace tensor_operation {
...
@@ -19,33 +19,29 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
DsDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
OutElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
,
bool
TransformIndexKtoGlobal
,
bool
HaveIndexInputIfOutputIndex
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadSliceSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
index_t
OutDstVectorSize
>
struct
DeviceReduceThreadWiseMultiD
:
public
DeviceReduce
<
InDataType
,
struct
DeviceReduceThreadWiseMultiD
:
public
DeviceReduceMultiD
<
InDataType
,
DsDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOperation
,
ReduceOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
OutElementwiseOperation
>
PropagateNan
,
OutputIndex
>
{
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
...
@@ -57,10 +53,10 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -57,10 +53,10 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
using
IndexDataType
=
int32_t
;
using
IndexDataType
=
int32_t
;
static
constexpr
bool
HaveIndexInput
=
OutputIndex
&&
HaveIndexInputIfOutputIndex
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumSrcDim
=
Rank
;
static
constexpr
index_t
NumSrcDim
=
Rank
;
static
constexpr
index_t
NumDstDim
=
(
NumInvariantDim
==
0
)
?
1
:
NumInvariantDim
;
static
constexpr
index_t
NumDstDim
=
(
NumInvariantDim
==
0
)
?
1
:
NumInvariantDim
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
...
@@ -159,34 +155,69 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -159,34 +155,69 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
return
(
out_grid_desc_m_padded
);
return
(
out_grid_desc_m_padded
);
};
};
static
auto
MakeDsDescriptor
(
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
DsLengths
[
i
],
DsStrides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
using
InGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({},
{}));
using
OutGridDesc_M
=
decltype
(
MakeDst1dDescriptor
({},
{}));
using
DsGridDesc_M
=
decltype
(
MakeDsDescriptor
({},
{}));
using
GridwiseReduce
=
GridwiseReduction_mk_to_m_threadwise_multi_d
<
InDataType
,
DsDataType
,
OutDataType
,
AccDataType
,
InGridDesc_M_K
,
DsGridDesc_M
,
OutGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
OutElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
using
DsGridPointer
=
typename
GridwiseReduce
::
DsGridPointer
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
Argument
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
double
alpha
,
double
beta
,
const
InDataType
*
in_dev
,
const
InDataType
*
in_dev
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
ds_dev
,
OutDataType
*
out_dev
,
OutDataType
*
out_dev
,
IndexDataType
*
out_index_dev
,
const
InElementwiseOperation
in_elementwise_op
,
const
InElementwiseOperation
in_elementwise_op
,
const
AccElementwiseOperation
acc_elementwise_op
)
const
OutElementwiseOperation
out_elementwise_op
)
:
outLengths_
{
outLengths
},
:
DsLengths_
{
DsLengths
},
DsStrides_
{
DsStrides
},
outLengths_
{
outLengths
},
outStrides_
{
outStrides
},
outStrides_
{
outStrides
},
in_dev_
{
in_dev
},
in_dev_
{
in_dev
},
out_dev_
{
out_dev
},
out_dev_
{
out_dev
},
out_index_dev_
{
out_index_dev
},
in_elementwise_op_
{
in_elementwise_op
},
in_elementwise_op_
{
in_elementwise_op
},
acc
_elementwise_op_
{
acc
_elementwise_op
}
out
_elementwise_op_
{
out
_elementwise_op
}
{
{
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
inStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inStrides
,
reduceDims
);
inStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inStrides
,
reduceDims
);
alpha_
=
type_convert
<
AccDataType
>
(
alpha
);
beta_
=
type_convert
<
AccDataType
>
(
beta
);
std
::
tie
(
invariant_total_length
,
reduce_total_length
)
=
std
::
tie
(
invariant_total_length
,
reduce_total_length
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
inLengths_
);
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
inLengths_
);
...
@@ -201,22 +232,33 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -201,22 +232,33 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
gridSize
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
gridSize
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
M_BlockTileSize
;
M_BlockTileSize
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
ds_dev
[
i
]);
});
ds_grid_desc_m_
=
MakeDsDescriptor
(
DsLengths
,
DsStrides
);
}
}
std
::
array
<
index_t
,
Rank
>
inLengths_
;
std
::
array
<
index_t
,
Rank
>
inLengths_
;
std
::
array
<
index_t
,
Rank
>
inStrides_
;
std
::
array
<
index_t
,
Rank
>
inStrides_
;
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides_
;
std
::
array
<
index_t
,
NumDstDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDstDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDstDim
>
outStrides_
;
std
::
array
<
index_t
,
NumDstDim
>
outStrides_
;
AccDataType
alpha_
;
AccDataType
beta_
;
const
InDataType
*
in_dev_
;
const
InDataType
*
in_dev_
;
OutDataType
*
out_dev_
;
OutDataType
*
out_dev_
;
IndexDataType
*
out_index_dev_
;
DsGridPointer
p_ds_grid_
;
InElementwiseOperation
in_elementwise_op_
;
InElementwiseOperation
in_elementwise_op_
;
AccElementwiseOperation
acc_elementwise_op_
;
OutElementwiseOperation
out_elementwise_op_
;
DsGridDesc_M
ds_grid_desc_m_
;
index_t
invariant_lowest_length
;
index_t
invariant_lowest_length
;
index_t
reduce_lowest_length
;
index_t
reduce_lowest_length
;
...
@@ -236,44 +278,8 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -236,44 +278,8 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
const
auto
out_grid_desc_m
=
const
auto
out_grid_desc_m
=
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
arg
.
outLengths_
,
arg
.
outStrides_
);
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
arg
.
outLengths_
,
arg
.
outStrides_
);
const
auto
ds_grid_desc_m
=
generate_tuple
(
[
&
](
auto
i
)
{
ignore
=
i
;
return
DeviceReduceThreadWiseMultiD
::
MakeDst1dDescriptor
(
arg
.
outLengths_
,
arg
.
outStrides_
);
},
Number
<
1
>
{});
using
InGridDesc_M_K
=
decltype
(
in_grid_desc_m_k
);
using
OutGridDesc_M
=
decltype
(
out_grid_desc_m
);
using
DsGridDesc_M
=
decltype
(
ds_grid_desc_m
);
float
avg_time
=
0
;
float
avg_time
=
0
;
using
Add
=
tensor_operation
::
element_wise
::
Add
;
using
GridwiseReduce
=
GridwiseReduction_mk_to_m_threadwise_multi_d
<
InDataType
,
Tuple
<
OutDataType
>
,
OutDataType
,
AccDataType
,
InGridDesc_M_K
,
DsGridDesc_M
,
OutGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
Add
,
InMemoryDataOperationEnum
::
Set
,
PropagateNan
,
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
using
DsGridPointer
=
typename
GridwiseReduce
::
DsGridPointer
;
const
auto
kernel
=
kernel_reduce_threadwise_multi_d
<
GridwiseReduce
,
const
auto
kernel
=
kernel_reduce_threadwise_multi_d
<
GridwiseReduce
,
InDataType
,
InDataType
,
OutDataType
,
OutDataType
,
...
@@ -282,23 +288,21 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -282,23 +288,21 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
DsGridDesc_M
,
DsGridDesc_M
,
OutGridDesc_M
,
OutGridDesc_M
,
InElementwiseOperation
,
InElementwiseOperation
,
Add
,
OutElementwiseOperation
,
DsGridPointer
>
;
DsGridPointer
>
;
DsGridPointer
p_ds_grid_
;
avg_time
=
launch_and_time_kernel
(
stream_config
,
avg_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
arg
.
gridSize
),
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
in_grid_desc_m_k
,
in_grid_desc_m_k
,
ds_grid_desc_m
,
arg
.
ds_grid_desc_m
_
,
out_grid_desc_m
,
out_grid_desc_m
,
arg
.
in_elementwise_op_
,
arg
.
in_elementwise_op_
,
Add
{}
,
arg
.
out_elementwise_op_
,
arg
.
in_dev_
,
arg
.
in_dev_
,
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
out_dev_
);
arg
.
out_dev_
);
return
(
avg_time
);
return
(
avg_time
);
...
@@ -356,32 +360,29 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
...
@@ -356,32 +360,29 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
inLengths
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
index_t
,
Rank
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDstDim
>
,
NumDTensor
>
DsStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
const
void
*
in_dev
,
const
void
*
in_index
_dev
,
const
std
::
array
<
const
void
*
,
NumDTensor
>
ds
_dev
,
void
*
out_dev
,
void
*
out_dev
,
void
*
out_index_dev
,
const
InElementwiseOperation
in_elementwise_op
,
const
InElementwiseOperation
in_elementwise_op
,
const
Acc
ElementwiseOperation
acc
_elementwise_op
)
override
const
Out
ElementwiseOperation
out
_elementwise_op
)
override
{
{
(
void
)
in_index_dev
;
return
std
::
make_unique
<
Argument
>
(
inLengths
,
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
inStrides
,
DsLengths
,
DsStrides
,
outLengths
,
outLengths
,
outStrides
,
outStrides
,
reduceDims
,
reduceDims
,
alpha
,
beta
,
static_cast
<
const
InDataType
*>
(
in_dev
),
static_cast
<
const
InDataType
*>
(
in_dev
),
ds_dev
,
static_cast
<
OutDataType
*>
(
out_dev
),
static_cast
<
OutDataType
*>
(
out_dev
),
static_cast
<
IndexDataType
*>
(
out_index_dev
),
in_elementwise_op
,
in_elementwise_op
,
acc
_elementwise_op
);
out
_elementwise_op
);
};
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp
View file @
22ee67a9
...
@@ -55,7 +55,6 @@ template <typename InDataType,
...
@@ -55,7 +55,6 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
InMemoryDataOperationEnum
OutMemoryDataOperation
,
InMemoryDataOperationEnum
OutMemoryDataOperation
,
bool
PropagateNan
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadSliceSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
...
@@ -110,7 +109,7 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
...
@@ -110,7 +109,7 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
ThreadReduceSrcDesc_M_K
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
ReduceOperation
,
PropagateNan
>
;
false
>
;
const
auto
identityVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>();
const
auto
identityVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>();
...
@@ -189,8 +188,6 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
...
@@ -189,8 +188,6 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
auto
ds_global_buf
=
generate_tuple
(
auto
ds_global_buf
=
generate_tuple
(
[
&
](
auto
I
)
{
[
&
](
auto
I
)
{
// static_assert(ds_grid_desc_m[I].GetNumOfDimension() == 1, "");
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
I
],
ds_grid_desc_m
[
I
].
GetElementSpaceSize
());
p_ds_grid
[
I
],
ds_grid_desc_m
[
I
].
GetElementSpaceSize
());
},
},
...
@@ -208,9 +205,9 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
...
@@ -208,9 +205,9 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
Sequence
<
MThreadSliceSize
>
,
// SliceLengths
Sequence
<
MThreadSliceSize
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
0
,
// SrcVectorDim
OutDstVectorSize
,
1
,
1
,
// SrcScalarStrideInVector
1
,
// SrcScalarStrideInVector
fals
e
>
{
tru
e
>
{
ds_grid_desc_m
[
I
],
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
)};
ds_grid_desc_m
[
I
],
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
)};
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment