gaoqiong / composable_kernel · Commits · cab8f2e5
"include/vscode:/vscode.git/clone" did not exist on "b8635a25b2ce87f70433b32e00858c6ae0f39fde"
Commit cab8f2e5, authored Mar 14, 2022 by Jing Zhang

    clean

Parents: c20aabc3, 9a17e7fb
Changes: 86
Showing 20 changed files with 492 additions and 310 deletions.
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp (+6, -6)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce.hpp (+34, -35)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp (+19, -19)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp (+9, -9)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp (+22, -22)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp (+4, -4)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp (+29, -29)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp (+34, -31)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp (+19, -19)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp (+9, -9)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp (+27, -27)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp (+9, -9)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp (+27, -27)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp (+53, -0)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp (+53, -0)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp (+74, -0)
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp (+19, -19)
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp (+9, -9)
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp (+27, -27)
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp (+9, -9)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp

@@ -11,13 +11,13 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
-ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
+ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
 ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
-ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
-ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0);
+ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
 ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
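The hunk above is the shape of the change repeated across every reduce-instance header in this commit: the trailing macro arguments that used to enumerate the reduced dimensions (e.g. 4, 0, 1, 2 meant rank 4, reduce dims 0, 1 and 2) now carry only the rank and the number of reduced dimensions (4, 3), matching the updated column legend. A minimal, self-contained sketch of what that buys at the template level (illustrative names, not the library's declarations):

    #include <array>
    #include <cstdio>

    template <int... Is>
    struct Sequence { static constexpr int size = sizeof...(Is); };

    // Before: the exact dimension set is a type parameter, so reducing dims
    // {0,1,2} and {1,2,3} of a rank-4 tensor required two instantiations.
    template <int Rank, typename ReduceDims>
    void reduce_old() { std::printf("old: %d of %d dims\n", ReduceDims::size, Rank); }

    // After: only the count is compiled in; the concrete dimensions become
    // runtime data, so one instantiation serves every 3-dim subset.
    template <int Rank, int NumReduceDim>
    void reduce_new(const std::array<int, NumReduceDim>& dims)
    {
        std::printf("new: %d of %d dims, first=%d\n", NumReduceDim, Rank, dims[0]);
    }

    int main()
    {
        reduce_old<4, Sequence<0, 1, 2>>(); // dimension list baked into the type
        reduce_new<4, 3>({{0, 1, 2}});      // count baked in, dims at run time
        reduce_new<4, 3>({{1, 2, 3}});      // same instantiation, other dims
    }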
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce.hpp

@@ -55,7 +55,7 @@ template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           int Rank,
-          typename ReduceDims,
+          int NumReduceDim,
           ReduceTensorOp_t ReduceOpId,
           NanPropagation_t NanOpt,
           ReduceTensorIndices_t IndicesOpt>
@@ -93,7 +93,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
                   AccDataType,
                   OutDataType,
                   Rank,
-                  ReduceDims,
+                  NumReduceDim,
                   ReduceOperation,
                   InElementwiseOperation,
                   AccElementwiseOperation,
@@ -113,21 +113,21 @@ void add_device_reduce_instance_multiblock_partial_reduce(
     });
 };
 #define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
     template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
                                                                        compT, \
                                                                        outT, \
                                                                        Rank, \
-                                                                       Sequence<__VA_ARGS__>, \
+                                                                       NumReduceDim, \
                                                                        ReduceOpId, \
                                                                        NanOpt, \
                                                                        IndicesOpt>( \
         std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
             device_op_instances)
 #define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID( \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
     ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(inT, \
                                                compT, \
                                                outT, \
@@ -135,28 +135,27 @@ void add_device_reduce_instance_multiblock_partial_reduce(
                                                static_cast<NanPropagation_t>(NanOpt), \
                                                static_cast<ReduceTensorIndices_t>(IndicesOpt), \
                                                Rank, \
-                                               __VA_ARGS__)
+                                               NumReduceDim)
 #define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    extern template void \
-    add_device_reduce_instance_multiblock_partial_reduce<inT, \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
+    extern template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
                                                          compT, \
                                                          outT, \
                                                          Rank, \
-                                                         Sequence<__VA_ARGS__>, \
+                                                         NumReduceDim, \
                                                          ReduceOpId, \
                                                          NanOpt, \
                                                          IndicesOpt>( \
         std::vector< \
             DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
                                 InElementwiseOperation, \
                             typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
                                 AccElementwiseOperation>> & \
             device_op_instances)
 #define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID( \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
     ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(inT, \
                                                    compT, \
                                                    outT, \
@@ -164,7 +163,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
                                                    static_cast<NanPropagation_t>(NanOpt), \
                                                    static_cast<ReduceTensorIndices_t>(IndicesOpt), \
                                                    Rank, \
-                                                   __VA_ARGS__)
+                                                   NumReduceDim)
 } // namespace device_reduce_instance
 } // namespace device
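Isolated from the diff noise, the macro-level change is mechanical: the variadic tail that was forwarded as Sequence<__VA_ARGS__> becomes a named final parameter forwarded as a plain integer. A compilable toy version of the before/after (simplified; the real macros emit explicit template/extern template instantiations of the add_device_reduce_instance functions):

    #include <cstdio>

    template <int... Is>
    struct Sequence { static constexpr int size = sizeof...(Is); };

    template <int Rank, typename ReduceDims>
    void inst_old() { std::printf("old: %d dims\n", ReduceDims::size); }

    template <int Rank, int NumReduceDim>
    void inst_new() { std::printf("new: %d dims\n", NumReduceDim); }

    // Before: the dimension list rode in on __VA_ARGS__.
    #define ADD_INST_OLD(Rank, ...) inst_old<Rank, Sequence<__VA_ARGS__>>()
    // After: a fixed-arity macro passes the count straight through.
    #define ADD_INST_NEW(Rank, NumReduceDim) inst_new<Rank, NumReduceDim>()

    int main()
    {
        ADD_INST_OLD(4, 0, 1, 2); // expands to inst_old<4, Sequence<0, 1, 2>>()
        ADD_INST_NEW(4, 3);       // expands to inst_new<4, 3>()
    }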
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp

@@ -11,25 +11,25 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp

@@ -11,16 +11,16 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp

@@ -11,29 +11,29 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp

@@ -11,10 +11,10 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp

@@ -11,37 +11,37 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
 // Will be moved to use MultiBlockAtomicAdd
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
-ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0);
+ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
 ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp

@@ -57,7 +57,7 @@ template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           int Rank,
-          typename ReduceDims,
+          int NumReduceDim,
           ReduceTensorOp_t ReduceOpId,
           NanPropagation_t NanOpt,
           ReduceTensorIndices_t IndicesOpt>
@@ -89,7 +89,7 @@ void add_device_reduce_instance_threadwise(
                   AccDataType,
                   OutDataType,
                   Rank,
-                  ReduceDims,
+                  NumReduceDim,
                   ReduceOperation,
                   InElementwiseOperation,
                   AccElementwiseOperation,
@@ -108,34 +108,36 @@ void add_device_reduce_instance_threadwise(
     });
 };
-#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    template void add_device_reduce_instance_threadwise<inT, \
-                                                        compT, \
-                                                        outT, \
-                                                        Rank, \
-                                                        Sequence<__VA_ARGS__>, \
-                                                        ReduceOpId, \
-                                                        NanOpt, \
-                                                        IndicesOpt>( \
-        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
+#define ADD_THREADWISE_INST_BY_TYPE( \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
+    template void add_device_reduce_instance_threadwise<inT, \
+                                                        compT, \
+                                                        outT, \
+                                                        Rank, \
+                                                        NumReduceDim, \
+                                                        ReduceOpId, \
+                                                        NanOpt, \
+                                                        IndicesOpt>( \
+        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
-#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    ADD_THREADWISE_INST_BY_TYPE(inT, \
-                                compT, \
-                                outT, \
-                                static_cast<ReduceTensorOp_t>(ReduceOpId), \
-                                static_cast<NanPropagation_t>(NanOpt), \
-                                static_cast<ReduceTensorIndices_t>(IndicesOpt), \
-                                Rank, \
-                                __VA_ARGS__)
+#define ADD_THREADWISE_INST_BY_ID( \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
+    ADD_THREADWISE_INST_BY_TYPE(inT, \
+                                compT, \
+                                outT, \
+                                static_cast<ReduceTensorOp_t>(ReduceOpId), \
+                                static_cast<NanPropagation_t>(NanOpt), \
+                                static_cast<ReduceTensorIndices_t>(IndicesOpt), \
+                                Rank, \
+                                NumReduceDim)
 #define ADD_THREADWISE_INST_REF_BY_TYPE( \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
     extern template void add_device_reduce_instance_threadwise<inT, \
                                                                compT, \
                                                                outT, \
                                                                Rank, \
-                                                               Sequence<__VA_ARGS__>, \
+                                                               NumReduceDim, \
                                                                ReduceOpId, \
                                                                NanOpt, \
                                                                IndicesOpt>( \
@@ -145,15 +147,16 @@ void add_device_reduce_instance_threadwise(
                 AccElementwiseOperation>> & \
             device_op_instances)
-#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
-                                    compT, \
-                                    outT, \
-                                    static_cast<ReduceTensorOp_t>(ReduceOpId), \
-                                    static_cast<NanPropagation_t>(NanOpt), \
-                                    static_cast<ReduceTensorIndices_t>(IndicesOpt), \
-                                    Rank, \
-                                    __VA_ARGS__)
+#define ADD_THREADWISE_INST_REF_BY_ID( \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
+    ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
+                                    compT, \
+                                    outT, \
+                                    static_cast<ReduceTensorOp_t>(ReduceOpId), \
+                                    static_cast<NanPropagation_t>(NanOpt), \
+                                    static_cast<ReduceTensorIndices_t>(IndicesOpt), \
+                                    Rank, \
+                                    NumReduceDim)
 } // namespace device_reduce_instance
 } // namespace device
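The _REF_BY_* macros differ from their _BY_* counterparts only by extern: the headers use the REF forms to declare the instantiations, and the matching device_reduce_instance_*.cpp files use the plain forms to define them once. A minimal sketch of that extern-template split (hypothetical reduce_fn; in the library the instantiated entity is add_device_reduce_instance_threadwise and friends):

    #include <cstdio>

    template <int Rank, int NumReduceDim>
    int reduce_fn() { return Rank - NumReduceDim; }

    // Header side (what a REF_BY_TYPE expansion amounts to): promise that the
    // instantiation exists somewhere, so this point of use does not emit it.
    extern template int reduce_fn<4, 3>();

    // Source side (what a BY_TYPE expansion amounts to): emit it exactly once.
    template int reduce_fn<4, 3>();

    int main()
    {
        std::printf("%d\n", reduce_fn<4, 3>()); // uses the single instantiation
        return 0;
    }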
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp

@@ -11,25 +11,25 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp

@@ -11,16 +11,16 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
-ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp

@@ -11,34 +11,34 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp

@@ -11,16 +11,16 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp

@@ -11,34 +11,34 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
-ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0);
+ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
0 → 100644
View file @
cab8f2e5
#include <stdlib.h>
#include "config.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_grouped_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization_t
::
Default
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
    // clang-format on
    >;

void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
}

} // namespace device_grouped_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
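A registration entry point like the one above is meant to be called from client code to populate a polymorphic list of tuning configurations. The sketch below is illustrative only: it uses just the names visible in this file (the nested namespaces, `DeviceGroupedGemmPtr`, `PassThrough`, and the registration function), and it assumes `DeviceGroupedGemmPtr` lives in the enclosing `device` namespace, as its unqualified use here suggests; the filtering step is mentioned only in a comment and is not part of this commit.

// Illustrative caller (not part of this commit): collect the km_kn_mn
// instances and report how many tuning configurations were registered.
#include <cstdio>
#include <vector>

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::device_grouped_gemm_instance;

    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>> instances;
    add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(instances);

    // A real caller would next filter these with an argument-support check and
    // time each survivor to pick the fastest configuration for its problem.
    std::printf("registered %zu km_kn_mn grouped-GEMM instances\n", instances.size());
    return 0;
}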
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
#include <stdlib.h>
#include "config.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;

// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
    // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
    // clang-format on
    >;

void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
}

} // namespace device_grouped_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
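Because the rows above are dense positional parameter lists, here is the first km_nk_mn instance restated with the column headers attached as comments. This is purely a reading aid: every name comes from the header comment rows, and the grouping is our interpretation of how those columns line up with the template arguments.

// First km_nk_mn instance from the tuple above, annotated (reading aid only):
DeviceGroupedGemmXdl<
    F16, F16, F16, F32,                     // AData, BData, CData, AccData types
    Col, Col, Row,                          // ALayout, BLayout, CLayout for a[k, m] * b[n, k] = c[m, n]
    PassThrough, PassThrough, PassThrough,  // A, B, C elementwise operations
    GemmDefault,                            // GEMM specialization
    256, 256, 128, 4, 8,                    // BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1
    32, 32, 4, 2,                           // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
    S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,    // ABlockTransfer: cluster lengths K0_M_K1, cluster arrange order, src access order
    1, 4, 8, true,                          // ABlockTransfer: src vector dim, src scalars/vector, dst scalars/vector_K1, ABlockLds AddExtraM
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,    // BBlockTransfer: same fields for B; S<1, 0, 2> because B is column-major here
    2, 8, 8, true,                          // BBlockTransfer: src vector dim, src scalars/vector, dst scalars/vector_K1, BBlockLds AddExtraN
    7, 1>                                   // CThreadTransfer: src/dst vector dim, dst scalars per vector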
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
#include <stdlib.h>
#include "config.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
static constexpr auto GemmMNPadding =
    ck::tensor_operation::device::GemmSpecialization_t::MNPadding;

// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    // clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
    // clang-format on
    >;
// irregular tile size
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
    // clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
        DeviceGroupedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>
    // clang-format on
    >;
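The two rows above are the only ones tagged GemmMNPadding, and the only ones with a non-power-of-two N tile (144) on a 16x16 XDL shape. Our reading of those numbers, stated as compile-time checks below, is that a single wave spans the full 144-wide N tile (16 x 9) while the M tile is split across the block's four 64-lane waves; this decomposition is inferred from the parameters, not stated anywhere in the diff.

// Inferred tile arithmetic for the irregular (GemmMNPadding) rows above:
static_assert(144 == 16 * 9, "NPerBlock == NPerXDL * NXdlPerWave (one wave spans N, our reading)");
static_assert(128 == 16 * 2 * 4, "MPerBlock == MPerXDL * MXdlPerWave * waves-per-block (our reading)");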
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
    add_device_operation_instances(
        instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
}

} // namespace device_grouped_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp
@@ -6,25 +6,25 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
-ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
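The substance of this diff, and of the three reduce diffs that follow, is a change to the trailing arguments of the instantiation macro: an explicit list of reduced dimensions becomes a count of reduced dimensions, matching the header comment's rename from `ReduceDims` to `NumReduceDim`. One pair from above, side by side:

// before: Rank = 4, reduced dimensions spelled out as (0, 1, 2)
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
// after:  Rank = 4, only the number of reduced dimensions (3) is passed
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3);       // for MIN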
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp
@@ -6,16 +6,16 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
-ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp
@@ -6,34 +6,34 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
-ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp
@@ -6,16 +6,16 @@ namespace device {
 namespace device_reduce_instance {
 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
-ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0);
+ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
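For orientation, the `// for ...` comments across these four reduce diffs pair each numeric ReduceOpId argument with an operation. Collected into one illustrative enum below; the names and the gaps (IDs 1 and 6 are never exercised here) are inferred from these files alone, and the library's actual enum is not shown in this commit.

// Inferred from the inline "// for ..." comments above; illustrative only.
enum class ReduceOpIdInferred
{
    ADD   = 0,
    MIN   = 2,
    MAX   = 3,
    AMAX  = 4,
    AVG   = 5,
    NORM2 = 7
};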